]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml-cpu: arm64: q6_K repack gemm and gemv (and generic) implementations (dotprod...
author: Alberto Cabrera Pérez <redacted>
Tue, 10 Feb 2026 10:47:45 +0000 (10:47 +0000)
committer: GitHub <redacted>
Tue, 10 Feb 2026 10:47:45 +0000 (10:47 +0000)
* First working version of GEMM and GEMV

* interleave loads and compute

* Clang-format

* Added missing fallback. Removed tested TODO.

* Swap M and N to be consistent with the repack template convention

ggml/src/ggml-cpu/arch-fallback.h
ggml/src/ggml-cpu/arch/arm/repack.cpp
ggml/src/ggml-cpu/repack.cpp
ggml/src/ggml-cpu/repack.h

index 427c1146e4664480c6e713813b748dafaf73cd45..c6eb75b23007a7b2ff4d759e5768bc95d48cfdf3 100644 (file)
@@ -43,6 +43,7 @@
 #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
 #define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
+#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
 #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
@@ -55,7 +56,8 @@
 #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
 #define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
-#    define ggml_gemm_q6_K_8x8_q8_K_generic   ggml_gemm_q6_K_8x8_q8_K
+#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
+#define ggml_gemm_q6_K_8x8_q8_K_generic   ggml_gemm_q6_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
@@ -76,6 +78,7 @@
 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
 #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
 #define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
+#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
 #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
@@ -84,6 +87,7 @@
 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
 #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
 #define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
+#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
 #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
 #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
 #define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
+#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
 #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
 #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
 #define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
+#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
 #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
 #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
 #define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
+#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
 #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
 #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
 #define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
+#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
 #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
 #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
 #define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
+#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
 #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
 #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
 #define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
+#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
 #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
 #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
 #define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
+#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
 #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
 #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
 #define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
+#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
 #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
 #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
 #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
 #define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
+#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
 #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
 #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
 #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
 #define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
+#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
 #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
index 99bb70274c56c92478cfdef2d987904009008ae0..fd05c609f7eab6ba558bb68826a5e53c6bd1cdaa 100644 (file)
@@ -1072,6 +1072,195 @@ void ggml_gemv_q5_K_8x8_q8_K(int                        n,
     ggml_gemv_q5_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
 }
 
+void ggml_gemv_q6_K_8x4_q8_K(int                        n,   // q6_K x q8_K GEMV over tiles of 8 interleaved columns; n is asserted to be a multiple of QK_K
+                             float * GGML_RESTRICT      s,   // output: 8 floats are stored per column tile at s + x * 8
+                             size_t                     bs,  // not read by the NEON path; forwarded to the generic fallback
+                             const void * GGML_RESTRICT vx,  // read as block_q6_Kx8, nb blocks per column tile
+                             const void * GGML_RESTRICT vy,  // read as block_q8_K, nb blocks
+                             int                        nr,  // not read by the NEON path; forwarded to the generic fallback
+                             int                        nc) {  // column count, asserted to be a multiple of 8
+    constexpr int qk = QK_K;
+    const int     nb = n / qk;  // number of QK_K-sized superblocks along n
+
+    constexpr int ncols_interleaved = 8;
+    constexpr int blocklen          = 4;
+
+    assert(n % qk == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(nb);  // NOTE(review): presumably silences unused warnings when the NEON block is compiled out - confirm
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    constexpr int    col_groups = ncols_interleaved / 4;  // two groups of 4 columns, one int32x4 accumulator each
+    const uint8x16_t m4b        = vdupq_n_u8(0x0f);       // selects the low nibble of a ql byte
+    const uint8x16_t mask_lo    = vdupq_n_u8(0x03);       // qh bits 0-1, combined with the low ql nibble
+    const uint8x16_t mask_hi    = vdupq_n_u8(0x30);       // qh bits 4-5, OR-ed above the high ql nibble
+
+    // 1x8 tile = 2 x 4
+    float32x4_t acc_f32[2];
+
+    const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;
+
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_q6_Kx8 * GGML_RESTRICT q6_ptr = (const block_q6_Kx8 *) vx + (x * nb);  // tile x starts at block x * nb
+
+        for (int i = 0; i < col_groups; i++) {
+            acc_f32[i] = vdupq_n_f32(0);
+        }
+
+        for (int b = 0; b < nb; b++) {
+            float32x4_t q6_d_0     = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d));      // d0 d1 d2 d3
+            float32x4_t q6_d_1     = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d + 4));  // d4 d5 d6 d7
+            float32x4_t q8_d       = vdupq_n_f32(q8_ptr[b].d);  // q8 block scale broadcast to all lanes
+            float32x4_t sb_scale_0 = vmulq_f32(q6_d_0, q8_d);   // combined scale for columns 0-3
+            float32x4_t sb_scale_1 = vmulq_f32(q6_d_1, q8_d);   // combined scale for columns 4-7
+
+            int32x4_t acc[col_groups];  // integer accumulators for this block, one per column group
+            for (int i = 0; i < col_groups; i++) {
+                acc[i] = vdupq_n_s32(0);
+            }
+
+            // Load all 16 scales once and widen to int16 (Q6_K has 16 scales per block)
+            // Reused for bias and dequantization later
+            int16_t q6_scales[16 * 8];  // indexed as [scale index * 8 + column]
+            for (int i = 0; i < 16; i++) {
+                int16x8_t scales = vmovl_s8(vld1_s8(q6_ptr[b].scales + i * 8));
+                vst1q_s16(q6_scales + i * 8, scales);
+            }
+
+            // Compute bias per column using q8 bsums and preloaded scales to skip the -32 shift
+            int32x4_t bias_lo = vdupq_n_s32(0);  // columns 0-3
+            int32x4_t bias_hi = vdupq_n_s32(0);  // columns 4-7
+
+            // Load bsums in chunks of 4 to process with vectorized operations
+            for (int i = 0; i < 16; i += 4) {
+                int16x4_t bsums_vec   = vld1_s16(q8_ptr[b].bsums + i);
+                int16x4_t scales_lo_0 = vld1_s16(q6_scales + (i + 0) * 8);
+                int16x4_t scales_hi_0 = vld1_s16(q6_scales + (i + 0) * 8 + 4);
+                int16x4_t scales_lo_1 = vld1_s16(q6_scales + (i + 1) * 8);
+                int16x4_t scales_hi_1 = vld1_s16(q6_scales + (i + 1) * 8 + 4);
+                int16x4_t scales_lo_2 = vld1_s16(q6_scales + (i + 2) * 8);
+                int16x4_t scales_hi_2 = vld1_s16(q6_scales + (i + 2) * 8 + 4);
+                int16x4_t scales_lo_3 = vld1_s16(q6_scales + (i + 3) * 8);
+                int16x4_t scales_hi_3 = vld1_s16(q6_scales + (i + 3) * 8 + 4);
+
+                bias_lo = vmlal_lane_s16(bias_lo, scales_lo_0, bsums_vec, 0);  // += scale[i + 0][col] * bsums[i + 0]
+                bias_hi = vmlal_lane_s16(bias_hi, scales_hi_0, bsums_vec, 0);
+                bias_lo = vmlal_lane_s16(bias_lo, scales_lo_1, bsums_vec, 1);
+                bias_hi = vmlal_lane_s16(bias_hi, scales_hi_1, bsums_vec, 1);
+                bias_lo = vmlal_lane_s16(bias_lo, scales_lo_2, bsums_vec, 2);
+                bias_hi = vmlal_lane_s16(bias_hi, scales_hi_2, bsums_vec, 2);
+                bias_lo = vmlal_lane_s16(bias_lo, scales_lo_3, bsums_vec, 3);
+                bias_hi = vmlal_lane_s16(bias_hi, scales_hi_3, bsums_vec, 3);
+            }
+            bias_lo = vshlq_n_s32(bias_lo, 5);  // x32: the offset that is left out of the q6 values below
+            bias_hi = vshlq_n_s32(bias_hi, 5);
+
+            // Process two 128-value halves per superblock
+            for (int half = 0; half < 2; half++) {
+                const uint8_t * ql_base = q6_ptr[b].ql + half * 512;  // ql advances 512 bytes per half
+                const uint8_t * qh_base = q6_ptr[b].qh + half * 256;  // qh advances 256 bytes per half
+
+                // A subblock (sb) is a set of weights that share the scale
+                // Since q6_K scales are per 16 elements
+                // num sbs -> 256 elements / (16 elements/scale * 2 elements/byte * 2 halves)
+                for (int sb = 0; sb < QK_K / 64; sb++) {
+                    const int8_t * q8_base_l = q8_ptr[b].qs + half * 128 + sb * 16;  // q8 values matched with the low ql nibbles
+                    const int8_t * q8_base_h = q8_base_l + 64;                       // q8 values matched with the high ql nibbles
+
+                    // Load and duplicate q8 values (each register covers four interleaved columns of q6)
+                    int8x16_t q8_l[4];
+                    int8x16_t q8_h[4];
+                    for (int i = 0; i < 4; i++) {
+                        q8_l[i] = (int8x16_t) vld1q_dup_s32((const int32_t *) (q8_base_l + i * 4));  // 4 q8 bytes copied to every 32-bit lane
+                        q8_h[i] = (int8x16_t) vld1q_dup_s32((const int32_t *) (q8_base_h + i * 4));
+                    }
+
+                    const int ql_off_base = sb * QK_K / 2;
+                    const int qh_off_base = ql_off_base & 255;  // wraps after 256 bytes
+
+                    // Load 4 vectors at once (64 bytes each for ql_0, ql_1, qh_0, qh_1)
+                    uint8x16x4_t q6_ql_0 = vld1q_u8_x4(ql_base + ql_off_base);
+                    uint8x16x4_t q6_ql_1 = vld1q_u8_x4(ql_base + ql_off_base + 64);
+                    uint8x16x4_t q6_qh_0 = vld1q_u8_x4(qh_base + qh_off_base);
+                    uint8x16x4_t q6_qh_1 = vld1q_u8_x4(qh_base + qh_off_base + 64);
+
+                    // Adjust qh for subblocks 2 and 3 (shift right by 2)
+                    if (sb > 1) {
+                        q6_qh_0.val[0] = vshrq_n_u8(q6_qh_0.val[0], 2);
+                        q6_qh_0.val[1] = vshrq_n_u8(q6_qh_0.val[1], 2);
+                        q6_qh_0.val[2] = vshrq_n_u8(q6_qh_0.val[2], 2);
+                        q6_qh_0.val[3] = vshrq_n_u8(q6_qh_0.val[3], 2);
+                        q6_qh_1.val[0] = vshrq_n_u8(q6_qh_1.val[0], 2);
+                        q6_qh_1.val[1] = vshrq_n_u8(q6_qh_1.val[1], 2);
+                        q6_qh_1.val[2] = vshrq_n_u8(q6_qh_1.val[2], 2);
+                        q6_qh_1.val[3] = vshrq_n_u8(q6_qh_1.val[3], 2);
+                    }
+
+                    const uint8x16_t q6_ql[8] = { q6_ql_0.val[0], q6_ql_0.val[1], q6_ql_0.val[2], q6_ql_0.val[3],
+                                                  q6_ql_1.val[0], q6_ql_1.val[1], q6_ql_1.val[2], q6_ql_1.val[3] };
+                    const uint8x16_t q6_qh[8] = { q6_qh_0.val[0], q6_qh_0.val[1], q6_qh_0.val[2], q6_qh_0.val[3],
+                                                  q6_qh_1.val[0], q6_qh_1.val[1], q6_qh_1.val[2], q6_qh_1.val[3] };
+
+                    // Process column groups (0-3, 4-7)
+                    for (int g = 0; g < col_groups; g++) {
+                        int32x4_t sb_acc_l = vdupq_n_s32(0);  // unscaled dot products for the low-nibble values
+                        int32x4_t sb_acc_h = vdupq_n_s32(0);  // unscaled dot products for the high-nibble values
+
+                        for (int chunk = 0; chunk < 4; chunk++) {
+                            const int idx = chunk * 2 + g;  // vectors alternate between groups: even = group 0, odd = group 1
+
+                            const uint8x16_t q6_qs_l = q6_ql[idx];
+                            const uint8x16_t q6_qs_h = q6_qh[idx];
+
+                            // Extract high 2 bits for upper nibble reconstruction
+                            const uint8x16_t q6_qs_hh = vandq_u8(q6_qs_h, mask_hi);
+
+                            // q6 = (low4 | high2<<4), without -32 bias (handled via bsums)
+                            const int8x16_t q6_l =
+                                vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_qs_l, m4b), vandq_u8(q6_qs_h, mask_lo), 4));
+                            const int8x16_t q6_h = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_l, 4), q6_qs_hh));
+
+                            sb_acc_l = vdotq_s32(sb_acc_l, q6_l, q8_l[chunk]);
+                            sb_acc_h = vdotq_s32(sb_acc_h, q6_h, q8_h[chunk]);
+                        }
+
+                        const int scale_idx_l = half * 8 + sb;
+                        const int scale_idx_h = half * 8 + sb + 4;  // high-nibble values use the scale 4 entries later
+
+                        const int32x4_t scale_vec_l = vmovl_s16(vld1_s16(q6_scales + scale_idx_l * 8 + g * 4));
+                        const int32x4_t scale_vec_h = vmovl_s16(vld1_s16(q6_scales + scale_idx_h * 8 + g * 4));
+
+                        acc[g] = vmlaq_s32(acc[g], sb_acc_l, scale_vec_l);
+                        acc[g] = vmlaq_s32(acc[g], sb_acc_h, scale_vec_h);
+                    }
+                }
+            }  // for half
+
+            // Bias correction
+            acc[0] = vsubq_s32(acc[0], bias_lo);
+            acc[1] = vsubq_s32(acc[1], bias_hi);
+
+            // Apply superblock scale (no mins for q6_K)
+            // acc[g] has [c0, c1, c2, c3]
+            float32x4_t w_0123 = vmulq_f32(vcvtq_f32_s32(acc[0]), sb_scale_0);
+            float32x4_t w_4567 = vmulq_f32(vcvtq_f32_s32(acc[1]), sb_scale_1);
+
+            acc_f32[0] = vaddq_f32(acc_f32[0], w_0123);
+            acc_f32[1] = vaddq_f32(acc_f32[1], w_4567);
+        }  // for b
+
+        int base = x * ncols_interleaved;  // first output index of this column tile
+        vst1q_f32(s + base, acc_f32[0]);
+        vst1q_f32(s + base + 4, acc_f32[1]);
+    }  // for x
+    return;  // NEON path finished; skip the generic fallback below
+#endif  // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    ggml_gemv_q6_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);  // reached only when the NEON/dotprod block is compiled out
+}
+
 void ggml_gemv_q6_K_8x8_q8_K(int                        n,
                              float * GGML_RESTRICT      s,
                              size_t                     bs,
@@ -1177,15 +1366,14 @@ void ggml_gemv_q6_K_8x8_q8_K(int                        n,
                         q8_h[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base_h + i * 8));
                     }
 
-                    // TODO: Test other qh repack patterns to reduce loads
                     const int ql_off_base = sb * QK_K / 2;
                     const int qh_off_base = ql_off_base & 255;  // wraps after 256 bytes
 
                     // Load 4 vectors at once (64 bytes each for ql_0, ql_1, qh_0, qh_1)
-                    ggml_uint8x16x4_t q6_ql_0 = ggml_vld1q_u8_x4(ql_base + ql_off_base);
-                    ggml_uint8x16x4_t q6_ql_1 = ggml_vld1q_u8_x4(ql_base + ql_off_base + 64);
-                    ggml_uint8x16x4_t q6_qh_0 = ggml_vld1q_u8_x4(qh_base + qh_off_base);
-                    ggml_uint8x16x4_t q6_qh_1 = ggml_vld1q_u8_x4(qh_base + qh_off_base + 64);
+                    uint8x16x4_t q6_ql_0 = vld1q_u8_x4(ql_base + ql_off_base);
+                    uint8x16x4_t q6_ql_1 = vld1q_u8_x4(ql_base + ql_off_base + 64);
+                    uint8x16x4_t q6_qh_0 = vld1q_u8_x4(qh_base + qh_off_base);
+                    uint8x16x4_t q6_qh_1 = vld1q_u8_x4(qh_base + qh_off_base + 64);
 
                     // Adjust qh for subblocks 2 and 3 (shift right by 2)
                     if (sb > 1) {
@@ -3474,6 +3662,208 @@ void ggml_gemm_q5_K_8x8_q8_K(int                        n,
     ggml_gemm_q5_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
 }
 
+void ggml_gemm_q6_K_8x4_q8_K(int                        n,   // q6_K x q8_K GEMM over 4-row x 8-column tiles; n is asserted to be a multiple of QK_K
+                             float * GGML_RESTRICT      s,   // output: 4 floats stored at s + row * bs + col per row and column group
+                             size_t                     bs,  // stride, in floats, between consecutive output rows
+                             const void * GGML_RESTRICT vx,  // read as block_q6_Kx8, nb blocks per column tile
+                             const void * GGML_RESTRICT vy,  // read as block_q8_Kx4, nb blocks per row tile
+                             int                        nr,  // row count, asserted to be a multiple of 4
+                             int                        nc) {  // column count, asserted to be a multiple of 8
+    constexpr int qk = QK_K;
+    const int     nb = n / qk;  // number of QK_K-sized superblocks along n
+
+    constexpr int ncols_interleaved = 8;
+    constexpr int blocklen          = 4;
+
+    assert(n % qk == 0);
+    assert(nr % 4 == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(nb);  // NOTE(review): presumably silences unused warnings when the NEON block is compiled out - confirm
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    constexpr int    q8_k_blocklen = 4;  // rows covered by one block_q8_Kx4 (step of the y loop)
+    constexpr int    col_groups    = ncols_interleaved / 4;
+    constexpr int    acc_size      = q8_k_blocklen * col_groups;  // 4 rows, 2 column groups
+    const uint8x16_t m4b           = vdupq_n_u8(0x0f);  // selects the low nibble of a ql byte
+    const uint8x16_t mask_lo       = vdupq_n_u8(0x03);  // qh bits 0-1, combined with the low ql nibble
+    const uint8x16_t mask_hi       = vdupq_n_u8(0x30);  // qh bits 4-5, OR-ed above the high ql nibble
+    const int8x16_t  m32s          = vdupq_n_s8(32);    // offset subtracted from every reconstructed q6 value
+
+    float32x4_t acc_f32[acc_size];  // indexed as [2 * row + column group]
+
+    for (int y = 0; y < nr / q8_k_blocklen; y++) {
+        const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);  // row tile y starts at block y * nb
+
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_q6_Kx8 * GGML_RESTRICT q6_ptr = (const block_q6_Kx8 *) vx + (x * nb);  // column tile x starts at block x * nb
+
+            for (int i = 0; i < acc_size; i++) {
+                acc_f32[i] = vdupq_n_f32(0);
+            }
+
+            for (int b = 0; b < nb; b++) {
+                float32x4_t q6_d_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d));      // scales of columns 0-3
+                float32x4_t q6_d_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d + 4));  // scales of columns 4-7
+                float32x4_t q8_d_0123 = vld1q_f32(q8_ptr[b].d);                                    // one q8 scale per row
+
+                float32x4_t sbd_scale_0123[q8_k_blocklen];  // [row]: q6 column scales x q8 scale of that row
+                float32x4_t sbd_scale_4567[q8_k_blocklen];
+
+                sbd_scale_0123[0] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 0);
+                sbd_scale_4567[0] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 0);
+                sbd_scale_0123[1] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 1);
+                sbd_scale_4567[1] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 1);
+                sbd_scale_0123[2] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 2);
+                sbd_scale_4567[2] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 2);
+                sbd_scale_0123[3] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 3);
+                sbd_scale_4567[3] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 3);
+
+                int32x4_t acc_s32[acc_size];  // integer accumulators for this block, [2 * row + column group]
+                for (int i = 0; i < acc_size; i++) {
+                    acc_s32[i] = vdupq_n_s32(0);
+                }
+
+                int16_t q6_scales[8 * 16];  // scales widened to int16, indexed as [scale index * 8 + column]
+                for (int i = 0; i < 16; i++) {
+                    int16x8_t scales = vmovl_s8(vld1_s8(q6_ptr[b].scales + i * 8));
+                    vst1q_s16(q6_scales + i * 8, scales);
+                }
+
+                for (int half = 0; half < 2; half++) {
+                    const uint8_t * ql_base = q6_ptr[b].ql + half * 512;  // ql advances 512 bytes per half
+                    const uint8_t * qh_base = q6_ptr[b].qh + half * 256;  // qh advances 256 bytes per half
+
+                    for (int sb = 0; sb < QK_K / 64; sb++) {
+                        int32x4_t acc_lo[acc_size];  // unscaled sums for low-nibble values, [column group * 4 + row]
+                        int32x4_t acc_hi[acc_size];  // unscaled sums for high-nibble values, same layout
+                        for (int i = 0; i < acc_size; i++) {
+                            acc_lo[i] = vdupq_n_s32(0);
+                            acc_hi[i] = vdupq_n_s32(0);
+                        }
+
+                        const int8_t * q8_base_l = q8_ptr[b].qs + half * 512 + sb * 64;        // q8 bytes matched with the low ql nibbles
+                        const int8_t * q8_base_h = q8_ptr[b].qs + half * 512 + 256 + sb * 64;  // q8 bytes matched with the high ql nibbles
+
+                        // 4 rows * 16 elements per scale
+                        // 4 reads of 16 bytes each
+                        constexpr int reads_per_sb = 4;
+                        int8x16_t     q8_l[reads_per_sb];
+                        int8x16_t     q8_h[reads_per_sb];
+                        for (int k = 0; k < reads_per_sb; k++) {
+                            q8_l[k] = vld1q_s8(q8_base_l + 16 * k);  // each 32-bit lane is later used as one row (lanes 0-3 = r0-r3)
+                            q8_h[k] = vld1q_s8(q8_base_h + 16 * k);
+                        }
+
+                        const int ql_off_base = sb * QK_K / 2;
+                        const int qh_off_base = ql_off_base & 255;  // qh offset wraps after 256 bytes
+
+                        uint8x16_t q6_ql_0123[reads_per_sb];
+                        uint8x16_t q6_ql_4567[reads_per_sb];
+                        uint8x16_t q6_qh_0123[reads_per_sb];
+                        uint8x16_t q6_qh_4567[reads_per_sb];
+
+                        for (int k = 0; k < reads_per_sb; k++) {  // each 32-byte step holds columns 0-3 followed by columns 4-7
+                            q6_ql_0123[k] = vld1q_u8(ql_base + ql_off_base + k * 32);
+                            q6_ql_4567[k] = vld1q_u8(ql_base + ql_off_base + k * 32 + 16);
+                            q6_qh_0123[k] = vld1q_u8(qh_base + qh_off_base + k * 32);
+                            q6_qh_4567[k] = vld1q_u8(qh_base + qh_off_base + k * 32 + 16);
+                        }
+
+                        if (sb > 1) {  // subblocks 2 and 3 reuse the same qh bytes, reading bits two positions higher
+                            for (int k = 0; k < reads_per_sb; k++) {
+                                q6_qh_0123[k] = vshrq_n_u8(q6_qh_0123[k], 2);
+                                q6_qh_4567[k] = vshrq_n_u8(q6_qh_4567[k], 2);
+                            }
+                        }
+
+                        for (int k = 0; k < reads_per_sb; k++) {
+                            // q = (ql | qh) - 32
+                            const uint8x16_t hbit_lo_0123 = vandq_u8(q6_qh_0123[k], mask_lo);
+                            const uint8x16_t hbit_hi_0123 = vandq_u8(q6_qh_0123[k], mask_hi);
+                            const uint8x16_t hbit_lo_4567 = vandq_u8(q6_qh_4567[k], mask_lo);
+                            const uint8x16_t hbit_hi_4567 = vandq_u8(q6_qh_4567[k], mask_hi);
+
+                            const int8x16_t q6_0123_lo = vsubq_s8(
+                                vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_ql_0123[k], m4b), hbit_lo_0123, 4)), m32s);
+                            const int8x16_t q6_0123_hi = vsubq_s8(
+                                vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_ql_0123[k], 4), hbit_hi_0123)), m32s);
+
+                            acc_lo[0] = vdotq_laneq_s32(acc_lo[0], q6_0123_lo, q8_l[k], 0);  //  0..3  r0 c0123
+                            acc_lo[1] = vdotq_laneq_s32(acc_lo[1], q6_0123_lo, q8_l[k], 1);  //  0..3  r1 c0123
+                            acc_lo[2] = vdotq_laneq_s32(acc_lo[2], q6_0123_lo, q8_l[k], 2);  //  0..3  r2 c0123
+                            acc_lo[3] = vdotq_laneq_s32(acc_lo[3], q6_0123_lo, q8_l[k], 3);  //  0..3  r3 c0123
+
+                            acc_hi[0] = vdotq_laneq_s32(acc_hi[0], q6_0123_hi, q8_h[k], 0);  // 64..67 r0 c0123
+                            acc_hi[1] = vdotq_laneq_s32(acc_hi[1], q6_0123_hi, q8_h[k], 1);  // 64..67 r1 c0123
+                            acc_hi[2] = vdotq_laneq_s32(acc_hi[2], q6_0123_hi, q8_h[k], 2);  // 64..67 r2 c0123
+                            acc_hi[3] = vdotq_laneq_s32(acc_hi[3], q6_0123_hi, q8_h[k], 3);  // 64..67 r3 c0123
+
+                            const int8x16_t q6_4567_lo = vsubq_s8(
+                                vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_ql_4567[k], m4b), hbit_lo_4567, 4)), m32s);
+                            const int8x16_t q6_4567_hi = vsubq_s8(
+                                vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_ql_4567[k], 4), hbit_hi_4567)), m32s);
+
+                            acc_lo[4] = vdotq_laneq_s32(acc_lo[4], q6_4567_lo, q8_l[k], 0);  //  0..3  r0 c4567
+                            acc_lo[5] = vdotq_laneq_s32(acc_lo[5], q6_4567_lo, q8_l[k], 1);  //  0..3  r1 c4567
+                            acc_lo[6] = vdotq_laneq_s32(acc_lo[6], q6_4567_lo, q8_l[k], 2);  //  0..3  r2 c4567
+                            acc_lo[7] = vdotq_laneq_s32(acc_lo[7], q6_4567_lo, q8_l[k], 3);  //  0..3  r3 c4567
+
+                            acc_hi[4] = vdotq_laneq_s32(acc_hi[4], q6_4567_hi, q8_h[k], 0);  // 64..67 r0 c4567
+                            acc_hi[5] = vdotq_laneq_s32(acc_hi[5], q6_4567_hi, q8_h[k], 1);  // 64..67 r1 c4567
+                            acc_hi[6] = vdotq_laneq_s32(acc_hi[6], q6_4567_hi, q8_h[k], 2);  // 64..67 r2 c4567
+                            acc_hi[7] = vdotq_laneq_s32(acc_hi[7], q6_4567_hi, q8_h[k], 3);  // 64..67 r3 c4567
+                        }
+
+                        // Scale and bias
+                        const int scale_idx_l = half * 8 + sb;
+                        const int scale_idx_h = half * 8 + sb + 4;  // high-nibble values use the scale 4 entries later
+
+                        for (int g = 0; g < col_groups; g++) {
+                            const int16x4_t scales_l16  = vld1_s16(q6_scales + scale_idx_l * 8 + g * 4);
+                            const int16x4_t scales_h16  = vld1_s16(q6_scales + scale_idx_h * 8 + g * 4);
+                            const int32x4_t scale_vec_l = vmovl_s16(scales_l16);
+                            const int32x4_t scale_vec_h = vmovl_s16(scales_h16);
+                            const int       acc_offset  = g * q8_k_blocklen;  // start of group g inside acc_lo / acc_hi
+
+                            for (int row = 0; row < q8_k_blocklen; row++) {
+                                const int idx = row * 2 + g;  // converts [group * 4 + row] into the [2 * row + group] layout
+                                acc_s32[idx]  = vmlaq_s32(acc_s32[idx], acc_lo[acc_offset + row], scale_vec_l);
+                                acc_s32[idx]  = vmlaq_s32(acc_s32[idx], acc_hi[acc_offset + row], scale_vec_h);
+                            }
+                        }
+                    }
+                }
+
+                // Finally we apply the superblock scales
+                for (int row = 0; row < q8_k_blocklen; row++) {
+                    const int       idx0     = 2 * row;      // columns 0-3 of this row
+                    const int       idx1     = 2 * row + 1;  // columns 4-7 of this row
+                    const int32x4_t acc_0123 = acc_s32[idx0];
+                    const int32x4_t acc_4567 = acc_s32[idx1];
+
+                    acc_f32[idx0] = vmlaq_f32(acc_f32[idx0], vcvtq_f32_s32(acc_0123), sbd_scale_0123[row]);
+                    acc_f32[idx1] = vmlaq_f32(acc_f32[idx1], vcvtq_f32_s32(acc_4567), sbd_scale_4567[row]);
+                }
+            }  // for b
+
+            for (int i = 0; i < q8_k_blocklen; i++) {
+                int row = y * q8_k_blocklen + i;
+                for (int j = 0; j < 2; j++) {  // j selects the column group (0-3 or 4-7)
+                    int col    = x * ncols_interleaved + j * 4;
+                    int offset = row * bs + col;  // bs acts as the row stride of s, counted in floats
+                    vst1q_f32(s + offset, acc_f32[2 * i + j]);
+                }
+            }
+        }  // for x
+    }  // for y
+    return;  // NEON path finished; skip the generic fallback below
+#endif  // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    ggml_gemm_q6_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);  // reached only when the NEON/dotprod block is compiled out
+}
+
 void ggml_gemm_q6_K_8x8_q8_K(int                        n,
                              float * GGML_RESTRICT      s,
                              size_t                     bs,
index 24e8ab4618258deb49b2736e1b36c74dcf057704..4cb7cdeb07ba308b043b1d2460239f5a8463f961 100644 (file)
@@ -256,6 +256,200 @@ template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_K>(const float * GGML_RESTR
     ggml_quantize_mat_q8_K_4x8(x, vy, n_per_row);
 }
 
+template <int M, int N>  // Portable scalar GEMV: interleaved Q6_K weights times one Q8_K activation row. M = bytes per interleave chunk, N = interleaved columns.
+static void ggml_gemv_q6_K_NxM_q8_K_generic_impl(int                        n,    // length of each dot product; must be a multiple of QK_K
+                                                 float * GGML_RESTRICT      s,    // output: nc floats, one per weight column
+                                                 size_t                     bs,   // unused here
+                                                 const void * GGML_RESTRICT vx,   // weights, read as block_q6_Kx8[(nc / N) * nb]
+                                                 const void * GGML_RESTRICT vy,   // activations, read as block_q8_K[nb]
+                                                 int                        nr,   // unused here
+                                                 int                        nc) {  // number of weight columns; must be a multiple of N
+    constexpr int blocklen          = M;
+    constexpr int ncols_interleaved = N;
+    const int     qk                = QK_K;
+    const int     nb                = n / qk;  // number of QK_K-sized blocks along n
+    const int     blocks_per_half   = 64 / blocklen;  // ql chunks needed to cover 64 values
+
+    assert(n % qk == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(bs);
+    UNUSED(nr);
+
+    float sumf[8];  // one accumulator per interleaved column; fixed size, so N must not exceed 8
+
+    const block_q8_K * a_ptr = (const block_q8_K *) vy;
+    for (int x = 0; x < nc / ncols_interleaved; x++) {  // one iteration per group of N interleaved columns
+        const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb);
+
+        for (int j = 0; j < ncols_interleaved; j++) {
+            sumf[j] = 0.0f;
+        }
+
+        for (int l = 0; l < nb; l++) {
+            for (int k = 0; k < (qk / (2 * blocklen)); k++) {  // each ql byte yields two values (low and high nibble), hence 2 * blocklen values per chunk
+                const int base_l = (k / blocks_per_half) * 128 + (k % blocks_per_half) * blocklen;  // first blocks_per_half chunks: values 0-63 (low nibble); the rest: values 128-191
+                const int base_h = base_l + 64;  // the high nibble of the same byte is the value 64 positions later
+
+                const int scale_idx_l = base_l / 16;  // one scale per 16 values
+                const int scale_idx_h = base_h / 16;
+
+                const int qh_shift_l = ((base_l % 128) / 32) * 2;  // bit shift 0, 2, 4 or 6: one 2-bit field per 32-value group within a 128-value half
+                const int qh_shift_h = ((base_h % 128) / 32) * 2;
+
+                const int qh_half_l = (base_l / 128) * 32;  // offset (0 or 32) of the qh bytes serving this 128-value half
+                const int qh_half_h = (base_h / 128) * 32;
+
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    const int8_t scale_l = b_ptr[l].scales[scale_idx_l * ncols_interleaved + j];  // scales are interleaved per column
+                    const int8_t scale_h = b_ptr[l].scales[scale_idx_h * ncols_interleaved + j];
+
+                    int sumi_l = 0;
+                    int sumi_h = 0;
+
+                    for (int i = 0; i < blocklen; i++) {
+                        const int ql_pos = k * ncols_interleaved * blocklen + j * blocklen + i;  // chunk k, column j, byte i
+                        const int l_4    = b_ptr[l].ql[ql_pos] & 0xF;
+                        const int hi_4   = (b_ptr[l].ql[ql_pos] >> 4) & 0xF;
+
+                        const int qh_idx_l    = qh_half_l + ((base_l + i) % 32);  // qh byte index before interleaving
+                        const int qh_chunk_l  = qh_idx_l / blocklen;
+                        const int qh_pos_l    = qh_idx_l % blocklen;
+                        const int qh_offset_l = qh_chunk_l * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_l;  // qh uses the same chunk/column interleaving as ql
+                        const int hi_2_l      = (b_ptr[l].qh[qh_offset_l] >> qh_shift_l) & 0x3;
+
+                        const int qh_idx_h    = qh_half_h + ((base_h + i) % 32);
+                        const int qh_chunk_h  = qh_idx_h / blocklen;
+                        const int qh_pos_h    = qh_idx_h % blocklen;
+                        const int qh_offset_h = qh_chunk_h * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_h;
+                        const int hi_2_h      = (b_ptr[l].qh[qh_offset_h] >> qh_shift_h) & 0x3;
+
+                        const int q_l = ((hi_2_l << 4) | l_4) - 32;  // rebuild the 6-bit value from 2 high + 4 low bits, then recenter by 32
+                        const int q_h = ((hi_2_h << 4) | hi_4) - 32;
+
+                        const int8_t a_l = a_ptr[l].qs[base_l + i];
+                        const int8_t a_h = a_ptr[l].qs[base_h + i];
+
+                        sumi_l += q_l * a_l;
+                        sumi_h += q_h * a_h;
+                    }
+
+                    sumf[j] +=  // apply the two 16-value scales, then the per-column weight scale and the activation scale
+                        (sumi_l * scale_l + sumi_h * scale_h) * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
+                }
+            }
+        }
+
+        for (int j = 0; j < ncols_interleaved; j++) {
+            s[x * ncols_interleaved + j] = sumf[j];
+        }
+    }
+}
+
+template <int M, int N>  // Portable scalar GEMM: interleaved Q6_K weights times groups of 4 interleaved Q8_K rows. M = bytes per interleave chunk, N = interleaved columns.
+static void ggml_gemm_q6_K_NxM_q8_K_generic_impl(int                        n,    // length of each dot product; must be a multiple of QK_K
+                                                 float * GGML_RESTRICT      s,    // output matrix, written as s[row * bs + col]
+                                                 size_t                     bs,   // row stride of s, in floats
+                                                 const void * GGML_RESTRICT vx,   // weights, read as block_q6_Kx8[(nc / N) * nb]
+                                                 const void * GGML_RESTRICT vy,   // activations, read as block_q8_Kx4[(nr / 4) * nb]
+                                                 int                        nr,   // number of activation rows; must be a multiple of 4
+                                                 int                        nc) {  // number of weight columns; must be a multiple of N
+    constexpr int blocklen          = M;
+    constexpr int ncols_interleaved = N;
+    const int     qk                = QK_K;
+    const int     nb                = n / qk;  // number of QK_K-sized blocks along n
+    const int     blocks_per_half   = 64 / blocklen;  // ql chunks needed to cover 64 values
+    const int     q8_half_stride    = 512;  // offset in a_ptr[l].qs between the two 128-value halves (128 values x 4 rows)
+    const int     q8_low_high_step  = 256;  // offset from a low-nibble value to its partner 64 values later (64 values x 4 rows)
+
+    assert(n % qk == 0);
+    assert(nr % 4 == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(bs);  // NOTE(review): bs is nevertheless read when the results are stored at the end of the x loop
+
+    float sumf[4][8];  // [activation row][interleaved column]; fixed size, so N must not exceed 8
+
+    for (int y = 0; y < nr / 4; y++) {  // one iteration per group of 4 activation rows
+        const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
+        for (int x = 0; x < nc / ncols_interleaved; x++) {  // one iteration per group of N interleaved columns
+            const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb);
+
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    sumf[m][j] = 0.0f;
+                }
+            }
+
+            for (int l = 0; l < nb; l++) {
+                for (int k = 0; k < (qk / (2 * blocklen)); k++) {  // each ql byte yields two values (low and high nibble), hence 2 * blocklen values per chunk
+                    const int base_l = (k / blocks_per_half) * 128 + (k % blocks_per_half) * blocklen;  // first blocks_per_half chunks: values 0-63 (low nibble); the rest: values 128-191
+                    const int base_h = base_l + 64;  // the high nibble of the same byte is the value 64 positions later
+
+                    const int scale_idx_l = base_l / 16;  // one scale per 16 values
+                    const int scale_idx_h = base_h / 16;
+
+                    const int qh_shift_l = ((base_l % 128) / 32) * 2;  // bit shift 0, 2, 4 or 6: one 2-bit field per 32-value group within a 128-value half
+                    const int qh_shift_h = ((base_h % 128) / 32) * 2;
+
+                    const int qh_half_l = (base_l / 128) * 32;  // offset (0 or 32) of the qh bytes serving this 128-value half
+                    const int qh_half_h = (base_h / 128) * 32;
+
+                    const int q8_base = (k / blocks_per_half) * q8_half_stride + (k % blocks_per_half) * (blocklen * 4);  // activation offset: 4 rows interleaved in blocklen-byte chunks
+
+                    for (int m = 0; m < 4; m++) {
+                        for (int j = 0; j < ncols_interleaved; j++) {
+                            const int8_t scale_l = b_ptr[l].scales[scale_idx_l * ncols_interleaved + j];  // scales are interleaved per column
+                            const int8_t scale_h = b_ptr[l].scales[scale_idx_h * ncols_interleaved + j];
+
+                            int sumi_l = 0;
+                            int sumi_h = 0;
+
+                            for (int i = 0; i < blocklen; i++) {
+                                const int ql_pos = k * ncols_interleaved * blocklen + j * blocklen + i;  // chunk k, column j, byte i
+                                const int l_4    = b_ptr[l].ql[ql_pos] & 0xF;
+                                const int hi_4   = (b_ptr[l].ql[ql_pos] >> 4) & 0xF;
+
+                                const int qh_idx_l   = qh_half_l + ((base_l + i) % 32);  // qh byte index before interleaving
+                                const int qh_chunk_l = qh_idx_l / blocklen;
+                                const int qh_pos_l   = qh_idx_l % blocklen;
+                                const int qh_offset_l =  // qh uses the same chunk/column interleaving as ql
+                                    qh_chunk_l * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_l;
+                                const int hi_2_l = (b_ptr[l].qh[qh_offset_l] >> qh_shift_l) & 0x3;
+
+                                const int qh_idx_h   = qh_half_h + ((base_h + i) % 32);
+                                const int qh_chunk_h = qh_idx_h / blocklen;
+                                const int qh_pos_h   = qh_idx_h % blocklen;
+                                const int qh_offset_h =
+                                    qh_chunk_h * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_h;
+                                const int hi_2_h = (b_ptr[l].qh[qh_offset_h] >> qh_shift_h) & 0x3;
+
+                                const int q_l = ((hi_2_l << 4) | l_4) - 32;  // rebuild the 6-bit value from 2 high + 4 low bits, then recenter by 32
+                                const int q_h = ((hi_2_h << 4) | hi_4) - 32;
+
+                                const int8_t q8_l = a_ptr[l].qs[q8_base + m * blocklen + i];  // row m of the current chunk
+                                const int8_t q8_h = a_ptr[l].qs[q8_base + m * blocklen + i + q8_low_high_step];
+
+                                sumi_l += q_l * q8_l;
+                                sumi_h += q_h * q8_h;
+                            }
+
+                            sumf[m][j] += (sumi_l * scale_l + sumi_h * scale_h) * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) *  // per-column weight scale times per-row activation scale
+                                          a_ptr[l].d[m];
+                        }
+                    }
+                }
+            }
+
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];  // output row y * 4 + m, column x * N + j
+                }
+            }
+        }
+    }
+}
+
 extern "C" {
 
 void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
@@ -704,94 +898,12 @@ void ggml_gemv_q5_K_8x8_q8_K_generic(int                        n,
 }
 
 
-void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
-    constexpr int qk = QK_K;
-    const int nb = n / qk;
-    const int ncols_interleaved = 8;
-    const int blocklen = 8;
-
-    assert(n % qk == 0);
-    assert(nc % ncols_interleaved == 0);
-
-    UNUSED(bs);
-    UNUSED(nr);
-
-    float sumf[8];
-
-    const block_q8_K * a_ptr = (const block_q8_K *) vy;
-    for (int x = 0; x < nc / ncols_interleaved; x++) {
-        const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb);
-
-        for (int j = 0; j < ncols_interleaved; j++) {
-            sumf[j] = 0.0f;
-        }
-
-        for (int l = 0; l < nb; l++) {
-
-
-            for (int k = 0; k < 16; k++) {
-                // k = 0.. 7 weights 0-63 low, 64-127 high
-                // k = 8..15 weights 128-191 low, 192-255 high
-                const int base_l = (k / 8) * 128 + (k % 8) * 8;
-                const int base_h = base_l + 64;
-
-                const int scale_idx_l = base_l / 16;
-                const int scale_idx_h = base_h / 16;
-
-                // Bit shift cycles 0,2,4,6 for each 32-value group within a 128-value half
-                const int qh_shift_l = ((base_l % 128) / 32) * 2;
-                const int qh_shift_h = ((base_h % 128) / 32) * 2;
-
-                // qh_half: offset to the correct 32-byte half (0 or 32)
-                const int qh_half_l = (base_l / 128) * 32;
-                const int qh_half_h = (base_h / 128) * 32;
-
-                for (int j = 0; j < ncols_interleaved; j++) {
-                    // Interleaved scales
-                    const int8_t scale_l = b_ptr[l].scales[scale_idx_l * 8 + j];
-                    const int8_t scale_h = b_ptr[l].scales[scale_idx_h * 8 + j];
-
-                    int sumi_l = 0;
-                    int sumi_h = 0;
-
-                    for (int i = 0; i < blocklen; i++) {
-                        const int ql_pos = k * 64 + j * 8 + i;
-                        const int l_4    = b_ptr[l].ql[ql_pos] & 0xF;
-                        const int hi_4   = (b_ptr[l].ql[ql_pos] >> 4) & 0xF;
-
-                        // qh indexing with 8-byte interleaving (like q5_K)
-                        const int qh_byte_l   = qh_half_l + ((base_l + i) % 32);
-                        const int qh_chunk_l  = qh_byte_l / 8;
-                        const int qh_pos_l    = qh_byte_l % 8;
-                        const int qh_offset_l = qh_chunk_l * 64 + j * 8 + qh_pos_l;
-                        const int hi_2_l      = (b_ptr[l].qh[qh_offset_l] >> qh_shift_l) & 0x3;
-
-                        const int qh_byte_h   = qh_half_h + ((base_h + i) % 32);
-                        const int qh_chunk_h  = qh_byte_h / 8;
-                        const int qh_pos_h    = qh_byte_h % 8;
-                        const int qh_offset_h = qh_chunk_h * 64 + j * 8 + qh_pos_h;
-                        const int hi_2_h      = (b_ptr[l].qh[qh_offset_h] >> qh_shift_h) & 0x3;
-
-                        const int q_l = ((hi_2_l << 4) | l_4) - 32;
-                        const int q_h = ((hi_2_h << 4) | hi_4) - 32;
-
-                        const int8_t a_l = a_ptr[l].qs[base_l + i];
-                        const int8_t a_h = a_ptr[l].qs[base_h + i];
-
-                        sumi_l += q_l * a_l;
-                        sumi_h += q_h * a_h;
-                    }
-
-                    sumf[j] +=
-                        (sumi_l * scale_l + sumi_h * scale_h) * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
-                }
-            }
-        }
+void ggml_gemv_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {  // generic Q6_K GEMV, 8 interleaved columns with 4-byte chunks
+    ggml_gemv_q6_K_NxM_q8_K_generic_impl<4, 8>(n, s, bs, vx, vy, nr, nc);  // template arguments are <blocklen, ncols_interleaved>
+}
 
-        for (int j = 0; j < ncols_interleaved; j++) {
-            s[x * ncols_interleaved + j] = sumf[j];
-        }
-    }
+void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {  // generic Q6_K GEMV, 8 interleaved columns with 8-byte chunks
+    ggml_gemv_q6_K_NxM_q8_K_generic_impl<8, 8>(n, s, bs, vx, vy, nr, nc);  // template arguments are <blocklen, ncols_interleaved>
 }
 
 void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
@@ -1485,109 +1597,12 @@ void ggml_gemm_q5_K_8x8_q8_K_generic(int                        n,
     }
 }
 
-void ggml_gemm_q6_K_8x8_q8_K_generic(int                        n,
-                                     float * GGML_RESTRICT      s,
-                                     size_t                     bs,
-                                     const void * GGML_RESTRICT vx,
-                                     const void * GGML_RESTRICT vy,
-                                     int                        nr,
-                                     int                        nc) {
-    const int qk                = QK_K;
-    const int nb                = n / qk;
-    const int ncols_interleaved = 8;
-    const int blocklen          = 8;
-
-    assert(n % qk == 0);
-    assert(nr % 4 == 0);
-    assert(nc % ncols_interleaved == 0);
-
-    UNUSED(bs);
-
-    float sumf[4][8];
-
-    for (int y = 0; y < nr / 4; y++) {
-        const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
-        for (int x = 0; x < nc / ncols_interleaved; x++) {
-            const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb);
-
-            for (int m = 0; m < 4; m++) {
-                for (int j = 0; j < ncols_interleaved; j++) {
-                    sumf[m][j] = 0.0f;
-                }
-            }
-
-            for (int l = 0; l < nb; l++) {
-                for (int k = 0; k < 16; k++) {
-                    // k = 0.. 7 weights 0-63 low, 64-127 high
-                    // k = 8..15 weights 128-191 low, 192-255 high
-                    const int base_l = (k / 8) * 128 + (k % 8) * 8;
-                    const int base_h = base_l + 64;
-
-                    const int scale_idx_l = base_l / 16;
-                    const int scale_idx_h = base_h / 16;
-
-                    // Bit shift cycles 0,2,4,6 for each 32-value group within a 128-value half
-                    const int qh_shift_l = ((base_l % 128) / 32) * 2;
-                    const int qh_shift_h = ((base_h % 128) / 32) * 2;
-
-                    // qh_half: offset to the correct 32-byte half (0 or 32)
-                    const int qh_half_l = (base_l / 128) * 32;
-                    const int qh_half_h = (base_h / 128) * 32;
-
-                    // Activation base indices for q8_Kx4 interleaved format
-                    // Layout: 128-value halves (k/8), then 8-value sub-blocks (k%8) with stride 32
-                    const int q8_base = (k / 8) * 512 + (k % 8) * 32;
-
-                    for (int m = 0; m < 4; m++) {
-                        for (int j = 0; j < ncols_interleaved; j++) {
-                            // Interleaved scales
-                            const int8_t scale_l = b_ptr[l].scales[scale_idx_l * 8 + j];
-                            const int8_t scale_h = b_ptr[l].scales[scale_idx_h * 8 + j];
-
-                            int sumi_l = 0;
-                            int sumi_h = 0;
-
-                            for (int i = 0; i < blocklen; i++) {
-                                const int ql_pos = k * 64 + j * 8 + i;
-                                const int l_4    = b_ptr[l].ql[ql_pos] & 0xF;
-                                const int hi_4   = (b_ptr[l].ql[ql_pos] >> 4) & 0xF;
-
-                                const int qh_idx_l    = qh_half_l + ((base_l + i) % 32);
-                                const int qh_chunk_l  = qh_idx_l / 8;
-                                const int qh_pos_l    = qh_idx_l % 8;
-                                const int qh_offset_l = qh_chunk_l * 64 + j * 8 + qh_pos_l;
-                                const int hi_2_l      = (b_ptr[l].qh[qh_offset_l] >> qh_shift_l) & 0x3;
-
-                                const int qh_idx_h    = qh_half_h + ((base_h + i) % 32);
-                                const int qh_chunk_h  = qh_idx_h / 8;
-                                const int qh_pos_h    = qh_idx_h % 8;
-                                const int qh_offset_h = qh_chunk_h * 64 + j * 8 + qh_pos_h;
-                                const int hi_2_h      = (b_ptr[l].qh[qh_offset_h] >> qh_shift_h) & 0x3;
-
-                                const int q_l = ((hi_2_l << 4) | l_4) - 32;
-                                const int q_h = ((hi_2_h << 4) | hi_4) - 32;
-
-                                const int8_t q8_l = a_ptr[l].qs[q8_base + m * 8 + i];
-                                const int8_t q8_h = a_ptr[l].qs[q8_base + m * 8 + i + 256];
-
-                                sumi_l += q_l * q8_l;
-                                sumi_h += q_h * q8_h;
-                            }
-
-                            sumf[m][j] += (sumi_l * scale_l + sumi_h * scale_h) * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) *
-                                          a_ptr[l].d[m];
-                        }
-                    }
-                }
-            }
+void ggml_gemm_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {  // generic Q6_K GEMM, 8 interleaved columns with 4-byte chunks
+    ggml_gemm_q6_K_NxM_q8_K_generic_impl<4, 8>(n, s, bs, vx, vy, nr, nc);  // template arguments are <blocklen, ncols_interleaved>
+}
 
-            for (int m = 0; m < 4; m++) {
-                for (int j = 0; j < ncols_interleaved; j++) {
-                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
-                }
-            }
-        }
-    }
+void ggml_gemm_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {  // generic Q6_K GEMM, 8 interleaved columns with 8-byte chunks
+   ggml_gemm_q6_K_NxM_q8_K_generic_impl<8, 8>(n, s, bs, vx, vy, nr, nc);  // template arguments are <blocklen, ncols_interleaved>
 }
 
 void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
@@ -2097,18 +2112,18 @@ static block_q6_Kx8 make_block_q6_Kx8(block_q6_K * in, unsigned int blck_size_in
     }
 
     const int end_ls = QK_K * 4 / blck_size_interleave;
-    // Interleave Q6_K quants by taking 8 bytes at a time
+    // Interleave Q6_K quants by taking blck_size_interleave bytes at a time
     for (int i = 0; i < end_ls; ++i) {
         int src_id     = i % n_blocks;
         int src_offset = (i / n_blocks) * blck_size_interleave;
         int dst_offset = i * blck_size_interleave;
 
         uint64_t elem_ls;
-        memcpy(&elem_ls, &in[src_id].ql[src_offset], sizeof(uint64_t));
-        memcpy(&out.ql[dst_offset], &elem_ls, sizeof(uint64_t));
+        memcpy(&elem_ls, &in[src_id].ql[src_offset], blck_size_interleave);
+        memcpy(&out.ql[dst_offset], &elem_ls, blck_size_interleave);
     }
 
-    // Interleave high bits using same 8-byte pattern as low bits
+    // Interleave high bits using same chunk size as low bits
     const int end_hs = end_ls / 2;
     for (int i = 0; i < end_hs; ++i) {
         int src_id     = i % n_blocks;
@@ -2116,8 +2131,8 @@ static block_q6_Kx8 make_block_q6_Kx8(block_q6_K * in, unsigned int blck_size_in
         int dst_offset = i * blck_size_interleave;
 
         uint64_t elem_hs;
-        memcpy(&elem_hs, &in[src_id].qh[src_offset], sizeof(uint64_t));
-        memcpy(&out.qh[dst_offset], &elem_hs, sizeof(uint64_t));
+        memcpy(&elem_hs, &in[src_id].qh[src_offset], blck_size_interleave);
+        memcpy(&out.qh[dst_offset], &elem_hs, blck_size_interleave);
     }
 
     // The below logic is designed so as to unpack and rearrange scales in Q6_K
@@ -2262,7 +2277,7 @@ static int repack_q5_K_to_q5_K_8_bl(struct ggml_tensor *       t,
 
 static int repack_q6_K_to_q6_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
     GGML_ASSERT(t->type == GGML_TYPE_Q6_K);
-    GGML_ASSERT(interleave_block == 8);
+    GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
     constexpr int nrows_interleaved = 8;
 
     block_q6_Kx8 * dst = (block_q6_Kx8 *)t->data;
@@ -2511,6 +2526,10 @@ template <> int repack<block_q5_K, 8, 8>(struct ggml_tensor * t, const void * da
     return repack_q5_K_to_q5_K_8_bl(t, 8, data, data_size);
 }
 
+template <> int repack<block_q6_K, 4, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {  // Q6_K repack specialization for the 4-byte interleave variant
+    return repack_q6_K_to_q6_K_8_bl(t, 4, data, data_size);  // second argument is interleave_block; the helper asserts it is 4 or 8
+}
+
 template <> int repack<block_q6_K, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
     return repack_q6_K_to_q6_K_8_bl(t, 8, data, data_size);
 }
@@ -2575,6 +2594,10 @@ template <> void gemv<block_q5_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t
     ggml_gemv_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
 }
 
+template <> void gemv<block_q6_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {  // forwards all arguments unchanged to the 8x4 Q6_K GEMV entry point
+    ggml_gemv_q6_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
+}
+
 template <> void gemv<block_q6_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
     ggml_gemv_q6_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
 }
@@ -2634,6 +2657,10 @@ template <> void gemm<block_q5_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t
     ggml_gemm_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
 }
 
+template <> void gemm<block_q6_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {  // forwards all arguments unchanged to the 8x4 Q6_K GEMM entry point
+    ggml_gemm_q6_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
+}
+
 template <> void gemm<block_q6_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
     ggml_gemm_q6_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
 }
@@ -3043,6 +3070,7 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
     static const ggml::cpu::repack::tensor_traits<block_q5_K, 8, 8, GGML_TYPE_Q8_K> q5_K_8x8_q8_K;
 
     // instance for Q6_K
+    static const ggml::cpu::repack::tensor_traits<block_q6_K, 4, 8, GGML_TYPE_Q8_K> q6_K_8x4_q8_K;
     static const ggml::cpu::repack::tensor_traits<block_q6_K, 8, 8, GGML_TYPE_Q8_K> q6_K_8x8_q8_K;
 
     // instance for Q2
@@ -3107,6 +3135,11 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
                 return &q6_K_8x8_q8_K;
             }
         }
+        if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
+            if (cur->ne[1] % 8 == 0) {
+                return &q6_K_8x4_q8_K;
+            }
+        }
     } else if (cur->type == GGML_TYPE_IQ4_NL) {
         if (ggml_cpu_has_avx2()) {
             if (cur->ne[1] % 8 == 0) {
index 855320eeeb6fe2428b4f4dccf6566ff3c66b6719..39b6b482388ce8dc7981e4bb8c4141b271c97b79 100644 (file)
@@ -112,6 +112,7 @@ void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
 void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q6_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
@@ -122,6 +123,7 @@ void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
 void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q6_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
@@ -142,6 +144,7 @@ void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
 void ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
@@ -152,6 +155,7 @@ void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
 void ggml_gemm_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);