]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml : use 8-bit precision for Q4_1 intermediate results (#1047)
authorGeorgi Gerganov <redacted>
Wed, 19 Apr 2023 17:10:08 +0000 (20:10 +0300)
committerGitHub <redacted>
Wed, 19 Apr 2023 17:10:08 +0000 (20:10 +0300)
* ggml : use 8-bit precision for Q4_1 intermediate results (ARM)

* ggml : optimize ggml_vec_dot_q4_1_q8_0() via vmalq_n_f32

56 ms/token with Q4_1 !

* ggml : AVX2 implementation of ggml_vec_dot_q4_1_q8_0 (#1051)

* gitignore : ignore ppl-*.txt files

---------

Co-authored-by: slaren <redacted>
.gitignore
ggml.c

index 631f2360ba1c31d5e4c6f1e1e39f22e4b223ed75..e52d479eeafa8e3c29892341dff70fb12df077f6 100644 (file)
@@ -1,11 +1,15 @@
 *.o
 *.a
+.DS_Store
+.build/
 .cache/
+.direnv/
+.envrc
+.swiftpm
+.venv
 .vs/
 .vscode/
-.DS_Store
 
-.build/
 build/
 build-em/
 build-debug/
@@ -30,12 +34,9 @@ models/*
 arm_neon.h
 compile_commands.json
 
-.envrc
-.direnv/
-
-.venv
 __pycache__
-.swiftpm
 
 zig-out/
 zig-cache/
+
+ppl-*.txt
diff --git a/ggml.c b/ggml.c
index 7728794743c711acbc0043746541fb3f76dd81a0..3b38eaad3673612fe62cbcd11c4b9588d944062c 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -550,6 +550,18 @@ inline static uint16_t vaddvq_u8(uint8x16_t v) {
         (uint16_t)vgetq_lane_u8(v, 14) + (uint16_t)vgetq_lane_u8(v, 15);
 }
 
+inline static int16_t vaddvq_s8(int8x16_t v) {
+    return
+        (int16_t)vgetq_lane_s8(v, 0)  + (int16_t)vgetq_lane_s8(v, 1)  +
+        (int16_t)vgetq_lane_s8(v, 2)  + (int16_t)vgetq_lane_s8(v, 3)  +
+        (int16_t)vgetq_lane_s8(v, 4)  + (int16_t)vgetq_lane_s8(v, 5)  +
+        (int16_t)vgetq_lane_s8(v, 6)  + (int16_t)vgetq_lane_s8(v, 7)  +
+        (int16_t)vgetq_lane_s8(v, 8)  + (int16_t)vgetq_lane_s8(v, 9)  +
+        (int16_t)vgetq_lane_s8(v, 10) + (int16_t)vgetq_lane_s8(v, 11) +
+        (int16_t)vgetq_lane_s8(v, 12) + (int16_t)vgetq_lane_s8(v, 13) +
+        (int16_t)vgetq_lane_s8(v, 14) + (int16_t)vgetq_lane_s8(v, 15);
+}
+
 inline static int32_t vaddvq_s16(int16x8_t v) {
     return
         (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
@@ -1535,9 +1547,8 @@ static void dequantize_row_q4_2(const void * restrict vx, float * restrict y, in
     }
 }
 
-static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
 static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
-//static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
+static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
 static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
 
 static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
@@ -1552,8 +1563,8 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
         .dequantize_row_q         = dequantize_row_q4_1,
         .quantize_row_q           = quantize_row_q4_1,
         .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
-        .quantize_row_q_dot       = quantize_row_q4_1,
-        .vec_dot_q                = ggml_vec_dot_q4_1,
+        .quantize_row_q_dot       = quantize_row_q8_0,
+        .vec_dot_q                = ggml_vec_dot_q4_1_q8_0,
     },
     [GGML_TYPE_Q4_2] = {
         .dequantize_row_q         = dequantize_row_q4_2,
@@ -2170,189 +2181,6 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
     *s = sumf;
 }
 
-static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
-    const int nb = n / QK4_1;
-
-    const block_q4_1 * restrict x = vx;
-    const block_q4_1 * restrict y = vy;
-
-    float sumf = 0.0;
-
-#if defined(__AVX2__)
-    // Initialize accumulator with zeros
-    __m256 acc = _mm256_setzero_ps();
-    // Accumulator for constant offsets
-    float acc_offset = 0.0f;
-
-    // Main loop
-    for (int i = 0; i < nb; ++i) {
-        const float * d0 = &x[i].d;
-        const float * d1 = &y[i].d;
-
-        const float * m0 = &x[i].m;
-        const float * m1 = &y[i].m;
-
-        const __m256 d0v = _mm256_broadcast_ss( d0 );
-        const __m256 d1v = _mm256_broadcast_ss( d1 );
-        const __m256 m0v = _mm256_broadcast_ss( m0 );
-        const __m256 m1v = _mm256_broadcast_ss( m1 );
-
-        // Compute combined scale for the block
-        const __m256 scale_01 = _mm256_mul_ps( d0v, d1v );
-
-        // Compute cross scales for the block
-        const __m256 scale_0 = _mm256_mul_ps( d0v, m1v );
-        const __m256 scale_1 = _mm256_mul_ps( m0v, d1v );
-        const __m256 cross_scales = _mm256_blend_ps( scale_0, scale_1, 0xAA /* 0b10101010 */ );
-
-        // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
-        __m256i bx = bytesFromNibbles( x[i].qs );
-        __m256i by = bytesFromNibbles( y[i].qs );
-
-        // Now we have a vector with bytes in [ 0 .. 15 ] interval.
-
-        // Sign-extend first 16 signed bytes into int16_t
-        __m256i x16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( bx ) );
-        __m256i y16 = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
-        // Compute products of int16_t integers, add pairwise
-        __m256i i32 = _mm256_madd_epi16( x16, y16 );
-
-        // Sign-extend last 16 signed bytes into int16_t vectors
-        __m256i x16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( bx, 1 ) );
-        __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
-        // Accumulate products of int16_t integers
-        i32 = _mm256_add_epi32( i32, _mm256_madd_epi16( x16_h, y16_h ) );
-
-        // compute sums of unsigned bytes in bx, by in blocks of 8.
-        // This results in a layout like X100 0000 X200 0000 X300 0000 X400 0000,
-        // which we then interleave as X100 Y100 X200 Y200 X300 Y300 X400 Y400.
-        // so if we then cast to 8 singles, we get 8 floats like [ x0_7, y0_7, x8_15, y8_15, x16_23, y16_23, x24_31, y24_31 ]
-        __m256i xsumi = _mm256_sad_epu8( bx, _mm256_setzero_si256() );
-        __m256i ysumi = _mm256_sad_epu8( by, _mm256_setzero_si256() );
-        __m256i sumsi = _mm256_or_si256( xsumi, _mm256_slli_si256( ysumi, 4 ) );
-        __m256  sums  = _mm256_cvtepi32_ps( sumsi );
-
-        // Convert int32_t to float
-        __m256 p = _mm256_cvtepi32_ps( i32 );
-        // Apply the scale, and accumulate
-        // acc += d0*d1*x*y + d0*m1*x + d1*m0*y
-        acc = _mm256_fmadd_ps( scale_01, p, acc );
-        acc = _mm256_fmadd_ps( cross_scales, sums, acc );
-        // acc_offset += m0*m1 (for each entry in the block)
-        acc_offset += (*m0)*(*m1);
-    }
-
-    // Return horizontal sum of the acc vector
-    __m128 res = _mm256_extractf128_ps( acc, 1 );
-    res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
-    res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
-    res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
-
-    sumf = _mm_cvtss_f32( res ) + acc_offset * QK4_1;
-#elif defined(__ARM_NEON)
-    float sum00 = 0.0f;
-    float sum01 = 0.0f;
-    float sum10 = 0.0f;
-    float sum11 = 0.0f;
-
-    for (int i = 0; i < nb; i += 2) {
-        const block_q4_1 * restrict x0 = &x[i + 0];
-        const block_q4_1 * restrict y0 = &y[i + 0];
-        const block_q4_1 * restrict x1 = &x[i + 1];
-        const block_q4_1 * restrict y1 = &y[i + 1];
-
-        const uint8x16_t m4b = vdupq_n_u8(0xf);
-
-        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
-        const uint8x16_t v1_0 = vld1q_u8(y0->qs);
-        const uint8x16_t v0_1 = vld1q_u8(x1->qs);
-        const uint8x16_t v1_1 = vld1q_u8(y1->qs);
-
-        // 4-bit -> 8-bit
-        const uint8x16_t v0_0l = vandq_u8(v0_0, m4b);
-        const uint8x16_t v1_0l = vandq_u8(v1_0, m4b);
-        const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4);
-        const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4);
-
-        const uint8x16_t v0_1l = vandq_u8(v0_1, m4b);
-        const uint8x16_t v1_1l = vandq_u8(v1_1, m4b);
-        const uint8x16_t v0_1h = vshrq_n_u8(v0_1, 4);
-        const uint8x16_t v1_1h = vshrq_n_u8(v1_1, 4);
-
-        sum00 += x0->m*y0->m;
-        sum01 += y0->m*x0->d*((uint16_t)vaddvq_u8(v0_0l) + (uint16_t)vaddvq_u8(v0_0h));
-        sum10 += x0->m*y0->d*((uint16_t)vaddvq_u8(v1_0l) + (uint16_t)vaddvq_u8(v1_0h));
-
-        sum00 += x1->m*y1->m;
-        sum01 += y1->m*x1->d*((uint16_t)vaddvq_u8(v0_1l) + (uint16_t)vaddvq_u8(v0_1h));
-        sum10 += x1->m*y1->d*((uint16_t)vaddvq_u8(v1_1l) + (uint16_t)vaddvq_u8(v1_1h));
-
-#if defined(__ARM_FEATURE_DOTPROD)
-        // dot product into int32x4_t
-        uint32x4_t p_0 = vdotq_u32(vdupq_n_u32(0), v0_0l, v1_0l);
-        uint32x4_t p_1 = vdotq_u32(vdupq_n_u32(0), v0_1l, v1_1l);
-
-        p_0 = vdotq_u32(p_0, v0_0h, v1_0h);
-        p_1 = vdotq_u32(p_1, v0_1h, v1_1h);
-
-        sum11 += x0->d*y0->d*vaddvq_u32(p_0);
-        sum11 += x1->d*y1->d*vaddvq_u32(p_1);
-#else
-        const uint16x8_t pl0l = vmull_u8(vget_low_u8 (v0_0l), vget_low_u8 (v1_0l));
-        const uint16x8_t pl0h = vmull_u8(vget_high_u8(v0_0l), vget_high_u8(v1_0l));
-        const uint16x8_t ph0l = vmull_u8(vget_low_u8 (v0_0h), vget_low_u8 (v1_0h));
-        const uint16x8_t ph0h = vmull_u8(vget_high_u8(v0_0h), vget_high_u8(v1_0h));
-
-        const uint16x8_t pl1l = vmull_u8(vget_low_u8 (v0_1l), vget_low_u8 (v1_1l));
-        const uint16x8_t pl1h = vmull_u8(vget_high_u8(v0_1l), vget_high_u8(v1_1l));
-        const uint16x8_t ph1l = vmull_u8(vget_low_u8 (v0_1h), vget_low_u8 (v1_1h));
-        const uint16x8_t ph1h = vmull_u8(vget_high_u8(v0_1h), vget_high_u8(v1_1h));
-
-        const uint16x8_t pl_0 = vaddq_u16(pl0l, pl0h);
-        const uint16x8_t ph_0 = vaddq_u16(ph0l, ph0h);
-
-        const uint16x8_t pl_1 = vaddq_u16(pl1l, pl1h);
-        const uint16x8_t ph_1 = vaddq_u16(ph1l, ph1h);
-
-        const uint16x8_t p_0 = vaddq_u16(pl_0, ph_0);
-        const uint16x8_t p_1 = vaddq_u16(pl_1, ph_1);
-
-        sum11 += x0->d*y0->d*vaddvq_u16(p_0);
-        sum11 += x1->d*y1->d*vaddvq_u16(p_1);
-#endif
-    }
-
-    sumf = QK4_1*sum00 + sum01 + sum10 + sum11;
-#else
-    // scalar
-    for (int i = 0; i < nb; i++) {
-        const float d0 = x[i].d;
-        const float d1 = y[i].d;
-
-        const float m0 = x[i].m;
-        const float m1 = y[i].m;
-
-        const uint8_t * restrict p0 = x[i].qs;
-        const uint8_t * restrict p1 = y[i].qs;
-
-        for (int j = 0; j < QK4_1/2; j++) {
-            const uint8_t v0 = p0[j];
-            const uint8_t v1 = p1[j];
-
-            const float f0 = d0*(v0 & 0xf) + m0;
-            const float f1 = d0*(v0 >> 4)  + m0;
-
-            const float f2 = d1*(v1 & 0xf) + m1;
-            const float f3 = d1*(v1 >> 4)  + m1;
-
-            sumf += f0*f2 + f1*f3;
-        }
-    }
-#endif
-
-    *s = sumf;
-}
-
 static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
     const int nb = n / QK8_0;
 
@@ -2549,6 +2377,175 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
     *s = sumf;
 }
 
+static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
+    const int nb = n / QK8_0;
+
+    assert(n % QK8_0 == 0);
+    assert(nb % 2 == 0);
+
+    const block_q4_1 * restrict x = vx;
+    const block_q8_0 * restrict y = vy;
+
+    float sumf = 0.0;
+
+    // TODO: add AVX / WASM SIMD / etc
+#if defined(__ARM_NEON)
+    float32x4_t sumv0 = vdupq_n_f32(0.0f);
+    float32x4_t sumv1 = vdupq_n_f32(0.0f);
+
+    for (int i = 0; i < nb; i += 2) {
+        const block_q4_1 * restrict x0 = &x[i + 0];
+        const block_q4_1 * restrict x1 = &x[i + 1];
+        const block_q8_0 * restrict y0 = &y[i + 0];
+        const block_q8_0 * restrict y1 = &y[i + 1];
+
+        const uint8x16_t m4b = vdupq_n_u8(0xf);
+
+        const uint8x16_t v0_0 = vld1q_u8(x0->qs);
+        const uint8x16_t v0_1 = vld1q_u8(x1->qs);
+
+        // 4-bit -> 8-bit
+        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, m4b));
+        const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+        const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
+        const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
+
+        // load y
+        const int8x16_t v1_0l = vld1q_s8(y0->qs);
+        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
+        const int8x16_t v1_1l = vld1q_s8(y1->qs);
+        const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
+
+        // interleave
+        const int8x16_t v1_0ls = vuzp1q_s8(v1_0l, v1_0h);
+        const int8x16_t v1_0hs = vuzp2q_s8(v1_0l, v1_0h);
+        const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h);
+        const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h);
+
+        const int16x8_t s0i = vaddq_s16(
+                        vaddq_s16(vmovl_s8(vget_low_s8(v1_0ls)), vmovl_s8(vget_high_s8(v1_0ls))),
+                        vaddq_s16(vmovl_s8(vget_low_s8(v1_0hs)), vmovl_s8(vget_high_s8(v1_0hs))));
+
+        const int16x8_t s1i = vaddq_s16(
+                        vaddq_s16(vmovl_s8(vget_low_s8(v1_1ls)), vmovl_s8(vget_high_s8(v1_1ls))),
+                        vaddq_s16(vmovl_s8(vget_low_s8(v1_1hs)), vmovl_s8(vget_high_s8(v1_1hs))));
+
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s0i), vget_high_s16(s0i))), x0->m*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s1i), vget_high_s16(s1i))), x1->m*y1->d);
+
+#if defined(__ARM_FEATURE_DOTPROD)
+        // dot product into int32x4_t
+        const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs);
+        const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1ls), v0_1h, v1_1hs);
+
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
+#else
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0ls));
+        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0ls));
+        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0hs));
+        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0hs));
+
+        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1ls));
+        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1ls));
+        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1hs));
+        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1hs));
+
+        const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
+        const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
+        const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
+        const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
+
+        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0->d*y0->d);
+        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1->d*y1->d);
+#endif
+    }
+
+    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+#elif defined(__AVX2__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+
+    // Main loop
+    for (int i = 0; i < nb; ++i) {
+        const float * d0 = &x[i].d;
+        const float * d1 = &y[i].d;
+        const float * m0 = &x[i].m;
+
+        const __m256 d0v = _mm256_broadcast_ss( d0 );
+        const __m256 d1v = _mm256_broadcast_ss( d1 );
+        const __m256 m0v = _mm256_broadcast_ss( m0 );
+
+        // Compute combined scales
+        const __m256 d0d1 = _mm256_mul_ps( d0v, d1v );
+        const __m256 d1m0 = _mm256_mul_ps( d1v, m0v );
+
+        // Load 16 bytes, and unpack 4 bit fields into bytes, making 32 bytes
+        const __m256i bx = bytesFromNibbles( x[i].qs );
+        const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs );
+
+        // Get absolute values of x vectors
+        const __m256i ax = _mm256_sign_epi8( bx, bx );
+
+        // Sign the values of the y vectors
+        const __m256i sy = _mm256_sign_epi8( by, bx );
+
+        // Perform multiplication and create 16-bit values
+        const __m256i dot = _mm256_maddubs_epi16( ax, sy );
+        const __m256i ones = _mm256_set1_epi16( 1 );
+        const __m256i xy_q = _mm256_madd_epi16( ones, dot );
+
+        // Convert to vector of 8 int32_t to 8 floats
+        const __m256 xy = _mm256_cvtepi32_ps( xy_q );
+
+        // Accumulate d0*d1*x*y
+        acc = _mm256_fmadd_ps( d0d1, xy, acc );
+
+        // Compute sum of y values
+        const __m256i y16_l = _mm256_cvtepi8_epi16( _mm256_castsi256_si128( by ) );
+        const __m256i y16_h = _mm256_cvtepi8_epi16( _mm256_extracti128_si256( by, 1 ) );
+        const __m256i ysumi = _mm256_madd_epi16( _mm256_add_epi16(y16_l, y16_h), ones );
+        const __m256 ysum = _mm256_cvtepi32_ps( ysumi );
+
+        // Accumulate d1*m0*y
+        acc = _mm256_fmadd_ps( d1m0, ysum, acc );
+    }
+
+    // Return horizontal sum of the acc vector
+    __m128 res = _mm256_extractf128_ps( acc, 1 );
+    res = _mm_add_ps( res, _mm256_castps256_ps128( acc ) );
+    res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
+    res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
+
+    sumf = _mm_cvtss_f32( res );
+#else
+    // scalar
+    for (int i = 0; i < nb; i++) {
+        const float d0 = x[i].d;
+        const float m0 = x[i].m;
+        const float d1 = y[i].d;
+
+        const uint8_t * restrict p0 = x[i].qs;
+        const  int8_t * restrict p1 = y[i].qs;
+
+        // TODO: this is very slow ..
+        for (int j = 0; j < QK8_0/2; j++) {
+            const uint8_t v0 = p0[j];
+
+            const float f0 = d0*(v0 & 0xf) + m0;
+            const float f1 = d0*(v0 >> 4)  + m0;
+
+            const float f2 = d1*p1[2*j + 0];
+            const float f3 = d1*p1[2*j + 1];
+
+            sumf += f0*f2 + f1*f3;
+        }
+    }
+#endif
+
+    *s = sumf;
+}
+
 static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
     const int nb = n / QK8_0;