]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml : Q4 cleanup - remove 4-bit dot product code (#1061)
authorStephan Walter <redacted>
Wed, 19 Apr 2023 16:06:37 +0000 (16:06 +0000)
committerGitHub <redacted>
Wed, 19 Apr 2023 16:06:37 +0000 (19:06 +0300)
* Q4 cleanup

* Remove unused AVX512 Q4_0 code

CMakeLists.txt
Makefile
ggml.c

index 8eadea4fd4862709cb823605b605366dd59337a7..d7aa051da4ac641dac653326837e6a18fdca630d 100644 (file)
@@ -174,7 +174,6 @@ if (LLAMA_ALL_WARNINGS)
             -Wshadow
             -Wstrict-prototypes
             -Wpointer-arith
-            -Wno-unused-function
         )
         set(cxx_flags
             -Wall
index deb0d00090f5aff5db9f24d81dfc6bb14161ddf3..d9a2d836babeed44cf1c4da47f588f3574c497cd 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -36,7 +36,7 @@ CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC
 LDFLAGS  =
 
 # warnings
-CFLAGS   += -Wall -Wextra -Wpedantic -Wcast-qual -Wdouble-promotion -Wshadow -Wstrict-prototypes -Wpointer-arith -Wno-unused-function
+CFLAGS   += -Wall -Wextra -Wpedantic -Wcast-qual -Wdouble-promotion -Wshadow -Wstrict-prototypes -Wpointer-arith
 CXXFLAGS += -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function -Wno-multichar
 
 # OS specific
diff --git a/ggml.c b/ggml.c
index 13c1548fee895a37cb416fc1e9923b7e32b4f14a..7728794743c711acbc0043746541fb3f76dd81a0 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -1562,7 +1562,13 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
         .quantize_row_q_dot       = quantize_row_q8_0,
         .vec_dot_q                = ggml_vec_dot_q4_2_q8_0,
     },
-    // TODO: GGML_TYPE_Q8_0
+    [GGML_TYPE_Q8_0] = {
+        .dequantize_row_q         = NULL,   // TODO
+        .quantize_row_q           = quantize_row_q8_0,
+        .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q8_0_reference,
+        .quantize_row_q_dot       = quantize_row_q8_0,
+        .vec_dot_q                = NULL,   // TODO
+    },
 };
 
 // For internal test use
@@ -2128,191 +2134,6 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float
     *s = sumf;
 }
 
-#if __AVX512F__ && QK4_0 == 32
-static inline __m512i bytes_from_q4_0_twoblocks_avx512( const __m512i blocks ) {
-    // The 64 bytes of `blocks` contain two consecutive Q4_0 blocks loaded from memory:
-    // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
-    // |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32|
-    // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
-    // |                                                                        :. =_ () [] <> () Zz Yy|
-    // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
-    // |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
-    // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
-    // |Xx Ww Vv Uu Tt Ss Rr Qq             Pp Oo Nn Mm Ll Kk Jj Ii Hh Gg Ff Ee Dd Cc Bb Aa            |
-    // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+
-    //
-    // Bytes 04..19 (block #0) and 24..39 (block #1) both contain 32 nibbles (4-bit unsigned integers).
-    // We have exactly 64 nibbles, so we want to place each nibble into a separate byte.
-    // Bytes 00..03 and 20..23 contain scales, which are irrelevant to this function.
-    // Bytes 40..63 are masked when loading the data, so they are zeroed out.
-#ifdef __AVX512VBMI__
-    const __m512i byte_perm = _mm512_set_epi8(
-        39, 38, 39, 38, 37, 36, 37, 36, 35, 34, 35, 34, 33, 32, 33, 32,
-        31, 30, 31, 30, 29, 28, 29, 28, 27, 26, 27, 26, 25, 24, 25, 24,
-        19, 18, 19, 18, 17, 16, 17, 16, 15, 14, 15, 14, 13, 12, 13, 12,
-        11, 10, 11, 10,  9,  8,  9,  8,  7,  6,  7,  6,  5,  4,  5,  4
-    );
-    const __m512i permuted = _mm512_permutexvar_epi8( byte_perm, blocks );
-    // After applying VPERMB, `permuted` looks like this:
-    // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
-    // |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32|
-    // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
-    // |:. =_ :. =_ () [] () [] <> () <> () Zz Yy Zz Yy Xx Ww Xx Ww Vv Uu Vv Uu Tt Ss Tt Ss Rr Qq Rr Qq|
-    // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
-    // |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
-    // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
-    // |Pp Oo Pp Oo Nn Mm Nn Mm Ll Kk Ll Kk Jj Ii Jj Ii Hh Gg Hh Gg Ff Ee Ff Ee Dd Cc Dd Cc Bb Aa Bb Aa|
-    // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
-#else
-    const __m512i word_perm = _mm512_set_epi16(
-        19, 19, 18, 18, 17, 17, 16, 16, 15, 15, 14, 14, 13, 13, 12, 12,
-         9,  9,  8,  8,  7,  7,  6,  6,  5,  5,  4,  4,  3,  3,  2,  2
-    );
-    const __m512i permuted = _mm512_permutexvar_epi16( word_perm, blocks );
-    // This is the fallback path for CPUs that don't support VPERMB. Since we permute 16-bit groups only,
-    // VPERMB can be replaced with VPERMW. We could always use VPERMW, but at least on Tiger Lake and
-    // Ice Lake VPERMW followed by a right shift is quite noticeably slower than VPERMB.
-#endif
-
-    // Shift every odd-numbered 16-bit group to the right by 4 bits.
-    const __mmask32 shift_mask = 0xaaaaaaaa;
-    const __m512i shifted = _mm512_mask_srai_epi16( permuted, shift_mask, permuted, 4 );
-    // After applying VPSRAW, `shifted` looks like this (the "empty" nibbles are filled with zeroes):
-    // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
-    // |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32
-    // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
-    // | : .= :. =_  ( )[ () []  < >( <> ()  Z zY Zz Yy  X xW Xx Ww  V vU Vv Uu  T tS Tt Ss  R rQ Rr Qq
-    // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
-    // |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
-    // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
-    // | P pO Pp Oo  N nM Nn Mm  L lK Ll Kk  J jI Jj Ii  H hG Hh Gg  F fE Ff Ee  D dC Dd Cc  B bA Bb Aa|
-    // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
-
-    // Now we just need to zero out the higher nibble in each byte, and we're done.
-    const __m512i low_nibble_mask = _mm512_set1_epi8( 0xf );
-    return _mm512_and_si512( low_nibble_mask, shifted );
-    // The final result looks like this:
-    // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
-    // |63 62 61 60 59 58 57 56 55 54 53 52 51 50 49 48 47 46 45 44 43 42 41 40 39 38 37 36 35 34 33 32|
-    // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
-    // | :  =  .  _  (  [  )  ]  <  (  >  )  Z  Y  z  y  X  W  x  w  V  U  v  u  T  S  t  s  R  Q  r  q|
-    // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
-    // |31 30 29 28 27 26 25 24 23 22 21 20 19 18 17 16 15 14 13 12 11 10 09 08 07 06 05 04 03 02 01 00|
-    // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
-    // | P  O  p  o  N  M  n  m  L  K  l  k  J  I  j  i  H  G  h  g  F  E  f  e  D  C  d  c  B  A  b  a|
-    // +-----------+-----------+-----------+-----------+-----------+-----------+-----------+-----------+
-}
-
-static inline __m512 dot_q4_0_twoblocks_avx512(
-    __m512 acc,
-    const block_q4_0 * restrict x,
-    const block_q4_0 * restrict y,
-    int i
-) {
-    // A pair of Q4_0 blocks spans 40 bytes, while an AVX-512 register has 64. The remaining 24 bytes
-    // can potentially be unaddressable, so we make sure to mask them out before the load, even though
-    // we don't use them at all. This might hurt the performance slightly, since the compiler is forced
-    // to use e.g. `VMOVDQU64 REG, MASK, [ADDR] + VPERMB ..., REG` instead of just `VPERMB ..., [ADDR]`.
-    const __mmask8 load_mask = 0x1f;
-    const __m512i blocks_0 = _mm512_maskz_loadu_epi64( load_mask, &x[i] );
-    const __m512i blocks_1 = _mm512_maskz_loadu_epi64( load_mask, &y[i] );
-
-    // We want to multiply the scales, so we interpret both registers as 16 32-bit floats:
-    // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
-    // | 15 | 14 | 13 | 12 | 11 | 10 | 09 | 08 | 07 | 06 | 05 | 04 | 03 | 02 | 01 | 00 |
-    // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
-    // blocks_0_float
-    // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
-    // |    |    |    |    |    |    | xx | xx | xx | xx |  B | xx | xx | xx | xx |  A |
-    // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
-    // blocks_1_float
-    // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
-    // |    |    |    |    |    |    | xx | xx | xx | xx |  D | xx | xx | xx | xx |  C |
-    // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
-    const __m512 blocks_0_float = _mm512_castsi512_ps( blocks_0 );
-    const __m512 blocks_1_float = _mm512_castsi512_ps( blocks_1 );
-    // We absolutely shouldn't touch the floats marked with `xx`: they contain some
-    // random data, which might very well underflow. At least on Intel, this leads
-    // to a huge penalty that can't be ignored (easily 100x or more) unless you
-    // compile your code with something like `-ffast-math` to enable FTZ/DAZ flags.
-    // (and ggml can't assume that you do)...
-    const __mmask16 scale_mul_mask = 0x21;
-#ifdef __clang__
-    // ...however, clang decides to optimize the multiplication mask away:
-    // https://godbolt.org/z/P8PqdsfvW
-    // gcc and MSVC do the sane thing. This horrible workaround forces clang to emit the mask.
-    __m512i scales;
-    __asm__(
-        "vmulps %1, %2, %0%{%3%}"
-        : "=v" ( scales )
-        : "vm" ( blocks_0_float ), "v" ( blocks_1_float ), "Yk" ( scale_mul_mask )
-    );
-#else
-    const __m512 scales = _mm512_maskz_mul_ps( scale_mul_mask, blocks_0_float, blocks_1_float );
-#endif
-    const __m512i scale_perm = _mm512_set_epi32(
-        5, 5, 5, 5, 5, 5, 5, 5,
-        0, 0, 0, 0, 0, 0, 0, 0
-    );
-    const __m512 permuted_scales = _mm512_permutexvar_ps( scale_perm, scales );
-    // After VMULPS and VPERMPS, `permuted_scales` looks like this:
-    // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
-    // | 15 | 14 | 13 | 12 | 11 | 10 | 09 | 08 | 07 | 06 | 05 | 04 | 03 | 02 | 01 | 00 |
-    // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
-    // | B*D| B*D| B*D| B*D| B*D| B*D| B*D| B*D| A*C| A*C| A*C| A*C| A*C| A*C| A*C| A*C|
-    // +----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+----+
-
-    const __m512i bytes_0 = bytes_from_q4_0_twoblocks_avx512( blocks_0 );
-    const __m512i bytes_1 = bytes_from_q4_0_twoblocks_avx512( blocks_1 );
-
-    // Now we want to compute dot products of 4-element byte vectors and store them in
-    // 32-bit integers. That is (only one 4-element vector is shown for clarity):
-    //     +----+----+----+----+
-    // ... | 03 | 02 | 01 | 00 |
-    //     +----+----+----+----+
-    // bytes_0
-    //     +----+----+----+----+
-    // ... |  D |  C |  B |  A |
-    //     +----+----+----+----+
-    // bytes_1
-    //     +----+----+----+----+
-    // ... |  H |  G |  F |  E |
-    //     +----+----+----+----+
-    // final_res_int
-    //     +----+----+----+----+
-    // ... |  A*E+B*F+C*G+D*H  |
-    //     +----+----+----+----+
-    const __m512i plus_8 = _mm512_set1_epi8( 8 );
-    const __m512i bytes_1_minus_8 = _mm512_sub_epi8( bytes_1, plus_8 );
-
-#ifdef __AVX512VNNI__
-    // We have VPDPBUSDS in AVX512-VNNI, which does exactly what we want, but with a catch:
-    // the *left* operand is supposed to be unsigned, while Q4_0 quantization subtracts 8
-    // from each nibble, so they can be negative. So, instead of `(bytes_0 - 8) * (bytes_1 - 8)`,
-    // we compute `bytes_0 * (bytes_1 - 8) + bytes_1 * (-8) + 64`. VPDPBUSDS uses an accumulator,
-    // which means we only need 2 instructions.
-    const __m512i dot_init = _mm512_set1_epi32( 4 * 64 );
-    const __m512i minus_8 = _mm512_set1_epi8( -8 );
-    const __m512i prod_0 = _mm512_dpbusds_epi32( dot_init, bytes_1, minus_8 );
-    const __m512i final_res_int = _mm512_dpbusds_epi32( prod_0, bytes_0, bytes_1_minus_8 );
-#else
-    // As a fallback, we have VPMADDUBSW in AVX512-BW, which uses 16-bit products instead of 32-bit ones.
-    // It has the same catch as VPDPBUSDS: the left operand should be unsigned.
-    // This is essentially the AVX-512 version of the AVX-2 trick used by GH user Const-me
-    //   ref: https://gist.github.com/Const-me/4d30e1fc767ab314596e16e90f53b6f4#file-matmultest-cpp-L119
-    const __m512i one = _mm512_set1_epi16( 1 );
-    const __m512i prod_0 = _mm512_maddubs_epi16( bytes_0, bytes_1_minus_8 );
-    const __m512i prod_1 = _mm512_maddubs_epi16( plus_8, bytes_1_minus_8 );
-    const __m512i diff = _mm512_sub_epi16( prod_0, prod_1 );
-    const __m512i final_res_int = _mm512_madd_epi16( diff, one );
-#endif
-
-    // Finally, we multiply the permuted scales and the 32-bit dot products, then accumulate.
-    const __m512 final_res_float = _mm512_cvtepi32_ps( final_res_int );
-    return _mm512_fmadd_ps( permuted_scales, final_res_float, acc );
-}
-#endif
-
 inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) {
     ggml_float sumf = 0.0;
 
@@ -2349,352 +2170,6 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
     *s = sumf;
 }
 
-static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
-    const int nb = n / QK4_0;
-
-    assert(n % QK4_0 == 0);
-    assert(nb % 2 == 0);
-
-    const block_q4_0 * restrict x = vx;
-    const block_q4_0 * restrict y = vy;
-
-    float sumf = 0.0;
-
-#if defined(__ARM_NEON)
-    float sum0 = 0.0f;
-    float sum1 = 0.0f;
-
-    for (int i = 0; i < nb; i += 2) {
-        const block_q4_0 * restrict x0 = &x[i + 0];
-        const block_q4_0 * restrict y0 = &y[i + 0];
-        const block_q4_0 * restrict x1 = &x[i + 1];
-        const block_q4_0 * restrict y1 = &y[i + 1];
-
-        const uint8x16_t m4b = vdupq_n_u8(0xf);
-        const int8x16_t  s8b = vdupq_n_s8(0x8);
-
-        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
-        const uint8x16_t v1_0 = vld1q_u8(y0->qs);
-        const uint8x16_t v0_1 = vld1q_u8(x1->qs);
-        const uint8x16_t v1_1 = vld1q_u8(y1->qs);
-
-        // 4-bit -> 8-bit
-        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
-        const int8x16_t v1_0l = vreinterpretq_s8_u8(vandq_u8(v1_0, m4b));
-        const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
-        const int8x16_t v1_0h = vreinterpretq_s8_u8(vshrq_n_u8(v1_0, 4));
-
-        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b));
-        const int8x16_t v1_1l = vreinterpretq_s8_u8(vandq_u8(v1_1, m4b));
-        const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
-        const int8x16_t v1_1h = vreinterpretq_s8_u8(vshrq_n_u8(v1_1, 4));
-
-        // sub 8
-        const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
-        const int8x16_t v1_0ls = vsubq_s8(v1_0l, s8b);
-        const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
-        const int8x16_t v1_0hs = vsubq_s8(v1_0h, s8b);
-
-        const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
-        const int8x16_t v1_1ls = vsubq_s8(v1_1l, s8b);
-        const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
-        const int8x16_t v1_1hs = vsubq_s8(v1_1h, s8b);
-
-#if defined(__ARM_FEATURE_DOTPROD)
-        // dot product into int32x4_t
-        int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls);
-        int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls);
-
-        p_0 = vdotq_s32(p_0, v0_0hs, v1_0hs);
-        p_1 = vdotq_s32(p_1, v0_1hs, v1_1hs);
-
-        sum0 += x0->d*y0->d*vaddvq_s32(p_0);
-        sum1 += x1->d*y1->d*vaddvq_s32(p_1);
-#else
-        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
-        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
-        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
-        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
-
-        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
-        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
-        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
-        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
-
-        const int16x8_t pl_0 = vaddq_s16(pl0l, pl0h);
-        const int16x8_t ph_0 = vaddq_s16(ph0l, ph0h);
-
-        const int16x8_t pl_1 = vaddq_s16(pl1l, pl1h);
-        const int16x8_t ph_1 = vaddq_s16(ph1l, ph1h);
-
-        const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
-        const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
-
-        sum0 += x0->d*y0->d*vaddvq_s16(p_0);
-        sum1 += x1->d*y1->d*vaddvq_s16(p_1);
-#endif
-    }
-
-    sumf = sum0 + sum1;
-#elif defined(__AVX512F__)
-    // Initialize accumulator with zeros
-    __m512 acc0 = _mm512_setzero_ps();
-    __m512 acc1 = _mm512_setzero_ps();
-
-    const int superblock_size = 16;
-
-    const int superblock_count = nb / superblock_size;
-
-    for (int superblock_ix = 0; superblock_ix < superblock_count; superblock_ix += 1) {
-        int i = superblock_ix * superblock_size;
-
-        acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+0 );
-        acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+2 );
-        acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+4 );
-        acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+6 );
-        acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+8 );
-        acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+10 );
-        acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i+12 );
-        acc1 = dot_q4_0_twoblocks_avx512( acc1, x, y, i+14 );
-    }
-
-    // Remainders
-    for (int i = superblock_count * superblock_size; i < nb; i += 2) {
-        acc0 = dot_q4_0_twoblocks_avx512( acc0, x, y, i );
-    }
-
-    // Horizontal sum of all lanes of the accumulator
-    sumf = _mm512_reduce_add_ps( acc0 ) + _mm512_reduce_add_ps( acc1 );
-#elif defined(__AVX2__)
-    // Initialize accumulator with zeros
-    __m256 acc = _mm256_setzero_ps();
-
-    /* Prepare the constants we will need during execution */
-    const __m256i lowMask = _mm256_set1_epi8( 0xF );
-    const __m256i offset_8 = _mm256_set1_epi16( 8 );
-
-#define UNROLL_COUNT 8
-    // make sure we only unroll multiples of the block count
-    assert(nb % UNROLL_COUNT == 0);
-
-    // Main loop
-    for (int i = 0; i < nb; i+=UNROLL_COUNT) {
-        // This loop will be unrolled by the compiler
-        for (int u=0;u<UNROLL_COUNT;u++)  {
-            /* Compute combined scale for the block */
-            const __m256 scale = _mm256_mul_ps(
-                    _mm256_broadcast_ss( &x[i+u].d ),
-                    _mm256_broadcast_ss( &y[i+u].d ) );
-
-            /* get input from x
-               Input: 32 Nibbles (16 bytes) at *x[i+u]
-               Output: 2 vectors with 16 values of type int16_t (x_high_q, x_low_q) */
-
-            /* Load 16 bytes from memory */
-            const __m128i tmp_x = _mm_loadu_si128( ( const __m128i* ) x[i+u].qs);
-            /* Expand bytes into uint16_t values */
-            const __m256i bytes_x = _mm256_cvtepu8_epi16(tmp_x);
-            /* Unpack values into individual bytes */
-            __m256i x_low_q = _mm256_and_si256( lowMask, bytes_x );
-            const __m256i pre_shift_x_high_q = _mm256_andnot_si256( lowMask, bytes_x );
-            __m256i x_high_q = _mm256_srli_epi16( pre_shift_x_high_q, 4 );
-            /* Now we have two vectors with bytes in [ 0 .. 15 ] interval.  Offset them into [ -8 .. +7 ] interval.  */
-            x_high_q = _mm256_sub_epi16( x_high_q, offset_8 );
-            x_low_q = _mm256_sub_epi16( x_low_q, offset_8 );
-
-            /* get input from y
-               Input: 32 Nibbles (16 bytes) at *y[i+u]
-               Output: 2 vectors with 16 values of type int16_t (y_high_q, y_low_q) */
-
-            /* Load 16 bytes from memory */
-            const __m128i tmp_y = _mm_loadu_si128( (const __m128i* ) y[i+u].qs);
-            /* Expand bytes into uint16_t values */
-            const __m256i bytes_y = _mm256_cvtepu8_epi16(tmp_y);
-            /* Unpack values into individual bytes */
-            const __m256i pre_shift_y_high_q = _mm256_andnot_si256( lowMask, bytes_y );
-            __m256i y_high_q = _mm256_srli_epi16( pre_shift_y_high_q, 4 );
-            __m256i y_low_q = _mm256_and_si256( lowMask, bytes_y );
-            /* Now we have two vectors with bytes in [ 0 .. 15 ] interval.  Offset them into [ -8 .. +7 ] interval.  */
-            y_high_q = _mm256_sub_epi16( y_high_q, offset_8 );
-            y_low_q = _mm256_sub_epi16( y_low_q, offset_8 );
-
-            /* Compute products of int16_t integers, add pairwise, store as int32_t */
-            __m256i xy_high_q = _mm256_madd_epi16( x_high_q, y_high_q );
-            __m256i xy_low_q = _mm256_madd_epi16( x_low_q, y_low_q );
-
-            /* Accumulate the products of int32_t integers -> we now have a vector of 8 int_32t */
-            __m256i xy_q = _mm256_add_epi32( xy_high_q, xy_low_q );
-
-            /* Convert to vectore of 8 int32_t to 8 floats */
-            __m256 q = _mm256_cvtepi32_ps( xy_q );
-
-            /* Multiply q with scale and accumulate */
-            acc = _mm256_fmadd_ps( scale, q, acc );
-        }
-    }
-
-    // Return horizontal sum of the acc vector
-    __m128 res = _mm256_extractf128_ps( acc, 1 );
-    res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
-    res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
-    res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
-
-    sumf = _mm_cvtss_f32( res );
-#elif defined(__AVX__)
-    // Initialize accumulator with zeros
-    __m256 acc = _mm256_setzero_ps();
-
-    // Main loop
-    for (int i = 0; i < nb; ++i) {
-        // Compute combined scale for the block
-        const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
-
-        __m128i i32[2];
-        for (int j = 0; j < 2; ++j) {
-            // Load 8 bytes, and unpack 4 bit fields into bytes, making 16 bytes
-            __m128i bx = bytesFromNibbles( x[i].qs + 8*j );
-            __m128i by = bytesFromNibbles( y[i].qs + 8*j );
-
-            // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
-            const __m128i off = _mm_set1_epi8( 8 );
-            bx = _mm_sub_epi8( bx, off );
-            by = _mm_sub_epi8( by, off );
-
-            // Get absolute values of x vectors
-            const __m128i ax = _mm_sign_epi8(bx, bx);
-
-            // Sign the values of the y vectors
-            const __m128i sy = _mm_sign_epi8(by, bx);
-
-            // Perform multiplication and create 16-bit values
-            const __m128i dot = _mm_maddubs_epi16(ax, sy);
-
-            const __m128i ones = _mm_set1_epi16(1);
-            i32[j] = _mm_madd_epi16(ones, dot);
-        }
-
-        // Convert int32_t to float
-        __m256 p = _mm256_cvtepi32_ps( _mm256_set_m128i( i32[0], i32[1] ));
-        // Apply the scale, and accumulate
-        acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
-    }
-
-    // Return horizontal sum of the acc vector
-    __m128 res = _mm256_extractf128_ps( acc, 1 );
-    res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
-    res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
-    res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
-
-    sumf = _mm_cvtss_f32( res );
-#elif defined(__wasm_simd128__)
-    // wasm simd
-    float sum0 = 0.0f;
-    float sum1 = 0.0f;
-
-    for (int i = 0; i < nb; i += 2) {
-        const block_q4_0 * restrict x0 = &x[i + 0];
-        const block_q4_0 * restrict y0 = &y[i + 0];
-        const block_q4_0 * restrict x1 = &x[i + 1];
-        const block_q4_0 * restrict y1 = &y[i + 1];
-
-        const v128_t m4b = wasm_u8x16_splat(0xf);
-        const v128_t s8b = wasm_i8x16_splat(0x8);
-
-        const v128_t v0_0 = wasm_v128_load(x0->qs);
-        const v128_t v0_1 = wasm_v128_load(y0->qs);
-        const v128_t v1_0 = wasm_v128_load(x1->qs);
-        const v128_t v1_1 = wasm_v128_load(y1->qs);
-
-        // 4-bit -> 8-bit
-        const v128_t v0_0l = wasm_v128_and(v0_0, m4b);
-        const v128_t v1_0l = wasm_v128_and(v1_0, m4b);
-
-        const v128_t v0_0h = wasm_u8x16_shr(v0_0, 4);
-        const v128_t v1_0h = wasm_u8x16_shr(v1_0, 4);
-
-        const v128_t v0_1l = wasm_v128_and(v0_1, m4b);
-        const v128_t v1_1l = wasm_v128_and(v1_1, m4b);
-
-        const v128_t v0_1h = wasm_u8x16_shr(v0_1, 4);
-        const v128_t v1_1h = wasm_u8x16_shr(v1_1, 4);
-
-        // sub 8
-        const v128_t v0_0ls = wasm_i8x16_sub(v0_0l, s8b);
-        const v128_t v1_0ls = wasm_i8x16_sub(v1_0l, s8b);
-
-        const v128_t v0_0hs = wasm_i8x16_sub(v0_0h, s8b);
-        const v128_t v1_0hs = wasm_i8x16_sub(v1_0h, s8b);
-
-        const v128_t v0_1ls = wasm_i8x16_sub(v0_1l, s8b);
-        const v128_t v1_1ls = wasm_i8x16_sub(v1_1l, s8b);
-
-        const v128_t v0_1hs = wasm_i8x16_sub(v0_1h, s8b);
-        const v128_t v1_1hs = wasm_i8x16_sub(v1_1h, s8b);
-
-        // dot product into int16x8_t
-        const v128_t pl0l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_0ls), wasm_i16x8_extend_low_i8x16(v1_0ls));
-        const v128_t pl0h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_0ls), wasm_i16x8_extend_high_i8x16(v1_0ls));
-
-        const v128_t ph0l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_0hs), wasm_i16x8_extend_low_i8x16(v1_0hs));
-        const v128_t ph0h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_0hs), wasm_i16x8_extend_high_i8x16(v1_0hs));
-
-        const v128_t pl1l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_1ls), wasm_i16x8_extend_low_i8x16(v1_1ls));
-        const v128_t pl1h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_1ls), wasm_i16x8_extend_high_i8x16(v1_1ls));
-
-        const v128_t ph1l = wasm_i16x8_mul(wasm_i16x8_extend_low_i8x16(v0_1hs), wasm_i16x8_extend_low_i8x16(v1_1hs));
-        const v128_t ph1h = wasm_i16x8_mul(wasm_i16x8_extend_high_i8x16(v0_1hs), wasm_i16x8_extend_high_i8x16(v1_1hs));
-
-        const v128_t pl_0 = wasm_i16x8_add(pl0l, pl0h);
-        const v128_t ph_0 = wasm_i16x8_add(ph0l, ph0h);
-
-        const v128_t pl_1 = wasm_i16x8_add(pl1l, pl1h);
-        const v128_t ph_1 = wasm_i16x8_add(ph1l, ph1h);
-
-        const v128_t p_0 = wasm_i16x8_add(pl_0, ph_0);
-        const v128_t p_1 = wasm_i16x8_add(pl_1, ph_1);
-
-        sum0 += x0->d * y0->d * (
-                wasm_i16x8_extract_lane(p_0, 0) + wasm_i16x8_extract_lane(p_0, 1) +
-                wasm_i16x8_extract_lane(p_0, 2) + wasm_i16x8_extract_lane(p_0, 3) +
-                wasm_i16x8_extract_lane(p_0, 4) + wasm_i16x8_extract_lane(p_0, 5) +
-                wasm_i16x8_extract_lane(p_0, 6) + wasm_i16x8_extract_lane(p_0, 7));
-        sum1 += x1->d * y1->d * (
-                wasm_i16x8_extract_lane(p_1, 0) + wasm_i16x8_extract_lane(p_1, 1) +
-                wasm_i16x8_extract_lane(p_1, 2) + wasm_i16x8_extract_lane(p_1, 3) +
-                wasm_i16x8_extract_lane(p_1, 4) + wasm_i16x8_extract_lane(p_1, 5) +
-                wasm_i16x8_extract_lane(p_1, 6) + wasm_i16x8_extract_lane(p_1, 7));
-    }
-
-    sumf = sum0 + sum1;
-#else
-    // scalar
-    for (int i = 0; i < nb; i++) {
-        const float d0 = x[i].d;
-        const float d1 = y[i].d;
-
-        const uint8_t * restrict p0 = x[i].qs;
-        const uint8_t * restrict p1 = y[i].qs;
-
-        int sumi = 0;
-        for (int j = 0; j < QK4_0/2; j++) {
-            const uint8_t v0 = p0[j];
-            const uint8_t v1 = p1[j];
-
-            const int i0 = (v0 & 0xf) - 8;
-            const int i1 = (v0 >> 4)  - 8;
-
-            const int i2 = (v1 & 0xf) - 8;
-            const int i3 = (v1 >> 4)  - 8;
-
-            sumi += i0*i2 + i1*i3;
-        }
-        sumf += d0 * d1 * sumi;
-    }
-#endif
-
-    *s = sumf;
-}
-
 static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
     const int nb = n / QK4_1;
 
@@ -11064,7 +10539,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
 #endif
                         } else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) {
                             cur = 0;
-                        } else if (quantize_fns[node->src0->type].vec_dot_q && node->src1->type == GGML_TYPE_F32) {
+                        } else if (ggml_is_quantized(node->src0->type) && node->src1->type == GGML_TYPE_F32) {
 #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
                             if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
                                 node->n_tasks = 1;