git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
metal : small-batch mat-mul kernels (llama/10581)
author    Georgi Gerganov <redacted>
          Tue, 3 Dec 2024 09:52:33 +0000 (11:52 +0200)
committer Georgi Gerganov <redacted>
          Sun, 8 Dec 2024 18:14:35 +0000 (20:14 +0200)
* metal : small-batch mat-mul kernels

ggml-ci

* metal : add rest of types

ggml-ci

* metal : final adjustments

ggml-ci

* metal : add comments

ggml-ci

ggml/src/ggml-metal/ggml-metal-impl.h
ggml/src/ggml-metal/ggml-metal.m
ggml/src/ggml-metal/ggml-metal.metal

diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h
index 53c13549650c80f3db6e7538161bbe738abb8d28..e195867192712b36b2c20b16a97453a0ad542339 100644
@@ -192,6 +192,30 @@ typedef struct {
     int16_t  r3;
 } ggml_metal_kargs_mul_mv;
 
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne01;
+    int32_t  ne02;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne10;
+    int32_t  ne11;
+    int32_t  ne12;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    int32_t  ne0;
+    int32_t  ne1;
+    int16_t  r2;
+    int16_t  r3;
+    int16_t  nsg;
+    int16_t  nxpsg;
+    int16_t  r1ptg;
+} ggml_metal_kargs_mul_mv_ext;
+
 typedef struct {
     int32_t  nei0;
     int32_t  nei1;
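
Relative to ggml_metal_kargs_mul_mv, the only new fields are the trailing nsg, nxpsg and r1ptg, which carry the threadgroup geometry chosen on the host. A minimal sketch of the derived quantities (mirroring the host-side code in ggml-metal.m below; the sketch itself is not part of the commit):

    const int nypsg = 32/nxpsg;   // src0 rows covered per simdgroup
    const int r0ptg = nypsg*nsg;  // src0 rows covered per threadgroup
    // => each threadgroup computes an r0ptg x r1ptg tile of the output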
diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
index 043a455e272b91ef21e63c727dc57f6f2155919c..b699f6fd094e35f3eb69488310c9351cbece3c4e 100644
@@ -175,6 +175,46 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5,
     GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32,
@@ -702,6 +742,46 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,               mul_mv_q5_0_f32,                has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,               mul_mv_q5_1_f32,                has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,               mul_mv_q8_0_f32,                has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2,       mul_mv_ext_f16_f32_r1_2,        has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3,       mul_mv_ext_f16_f32_r1_3,        has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4,       mul_mv_ext_f16_f32_r1_4,        has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5,       mul_mv_ext_f16_f32_r1_5,        has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2,      mul_mv_ext_q4_0_f32_r1_2,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3,      mul_mv_ext_q4_0_f32_r1_3,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4,      mul_mv_ext_q4_0_f32_r1_4,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5,      mul_mv_ext_q4_0_f32_r1_5,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2,      mul_mv_ext_q4_1_f32_r1_2,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3,      mul_mv_ext_q4_1_f32_r1_3,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4,      mul_mv_ext_q4_1_f32_r1_4,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5,      mul_mv_ext_q4_1_f32_r1_5,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2,      mul_mv_ext_q5_0_f32_r1_2,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3,      mul_mv_ext_q5_0_f32_r1_3,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4,      mul_mv_ext_q5_0_f32_r1_4,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5,      mul_mv_ext_q5_0_f32_r1_5,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2,      mul_mv_ext_q5_1_f32_r1_2,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3,      mul_mv_ext_q5_1_f32_r1_3,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4,      mul_mv_ext_q5_1_f32_r1_4,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5,      mul_mv_ext_q5_1_f32_r1_5,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2,      mul_mv_ext_q8_0_f32_r1_2,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3,      mul_mv_ext_q8_0_f32_r1_3,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4,      mul_mv_ext_q8_0_f32_r1_4,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5,      mul_mv_ext_q8_0_f32_r1_5,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2,      mul_mv_ext_q4_K_f32_r1_2,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3,      mul_mv_ext_q4_K_f32_r1_3,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4,      mul_mv_ext_q4_K_f32_r1_4,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5,      mul_mv_ext_q4_K_f32_r1_5,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2,      mul_mv_ext_q5_K_f32_r1_2,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3,      mul_mv_ext_q5_K_f32_r1_3,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4,      mul_mv_ext_q5_K_f32_r1_4,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5,      mul_mv_ext_q5_K_f32_r1_5,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2,      mul_mv_ext_q6_K_f32_r1_2,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3,      mul_mv_ext_q6_K_f32_r1_3,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4,      mul_mv_ext_q6_K_f32_r1_4,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5,      mul_mv_ext_q6_K_f32_r1_5,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2,    mul_mv_ext_iq4_nl_f32_r1_2,     has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3,    mul_mv_ext_iq4_nl_f32_r1_3,     has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4,    mul_mv_ext_iq4_nl_f32_r1_4,     has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5,    mul_mv_ext_iq4_nl_f32_r1_5,     has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32,               mul_mv_q2_K_f32,                has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32,               mul_mv_q3_K_f32,                has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32,               mul_mv_q4_K_f32,                has_simdgroup_reduction);
@@ -1936,30 +2016,180 @@ static void ggml_metal_encode_node(
 
                 // find the break-even point where the matrix-matrix kernel becomes more efficient compared
                 // to the matrix-vector kernel
-                int ne11_mm_min = 4;
+                const int ne11_mm_min = 4;
+
+                // first try to use small-batch mat-mv kernels
+                // these should be efficient for BS [2, ~8]
+                if (src1t == GGML_TYPE_F32 && (ne00%256 == 0) &&
+                    (
+                     (
+                      (
+                       src0t == GGML_TYPE_F16  || // TODO: helper function
+                       src0t == GGML_TYPE_Q4_0 ||
+                       src0t == GGML_TYPE_Q4_1 ||
+                       src0t == GGML_TYPE_Q5_0 ||
+                       src0t == GGML_TYPE_Q5_1 ||
+                       src0t == GGML_TYPE_Q8_0 ||
+                       src0t == GGML_TYPE_IQ4_NL ||
+                       false) && (ne11 >= 2 && ne11 <= 8)
+                     ) ||
+                     (
+                      (
+                       src0t == GGML_TYPE_Q4_K ||
+                       src0t == GGML_TYPE_Q5_K ||
+                       src0t == GGML_TYPE_Q6_K ||
+                       false) && (ne11 >= 4 && ne11 <= 8)
+                     )
+                    )
+                   ) {
+                    // TODO: determine the optimal parameters based on grid utilization
+                    //       I still don't know why we should not always use the maximum available threads:
+                    //
+                    //       nsg = pipeline.maxTotalThreadsPerThreadgroup / 32
+                    //
+                    //       my current hypothesis is that the work grid is not evenly divisible for different nsg
+                    //       values and there can be some tail effects when nsg is high. need to confirm this
+                    //
+                    const int nsg    = 2;                 // num simdgroups per threadgroup
+                    const int nxpsg  = ne11 < 3 ? 16 : 8; // num threads along row per simdgroup
+                    const int nypsg  = 32/nxpsg;          // num threads along col per simdgroup (i.e. a simdgroup processes that many src0 rows at a time)
+                    const int r0ptg  = nypsg*nsg;         // num src0 rows per threadgroup
+                          int r1ptg  = 4;                 // num src1 rows per threadgroup
+
+                    // note: not sure how optimal these are across all different hardware. there might be something cleverer
+                    switch (ne11) {
+                        case 2:
+                            r1ptg = 2; break;
+                        case 3:
+                        case 6:
+                            r1ptg = 3; break;
+                        case 4:
+                        case 7:
+                        case 8:
+                            r1ptg = 4; break;
+                        case 5:
+                            r1ptg = 5; break;
+                    };
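
                    // Worked example of the geometry above (an illustrative sketch, not part
                    // of the patch; ne01 = 4096 is a made-up value): ne11 = 5 gives
                    // nxpsg = 8, nypsg = 32/8 = 4, r0ptg = 4*2 = 8 and r1ptg = 5, so the
                    // dispatch further below launches
                    //   ((4096 + 8 - 1)/8, (5 + 5 - 1)/5, ne12*ne13) = (512, 1, ne12*ne13)
                    // threadgroups of 32x2 threads, each computing an 8x5 tile of dst.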
 
-#if 0
-                // the numbers below are measured on M2 Ultra for 7B and 13B models
-                // these numbers do not translate to other devices or model sizes
-                // TODO: need to find a better approach
-                        if ([device.name isEqualToString:@"Apple M2 Ultra"]) {
-                            switch (src0t) {
-                                case GGML_TYPE_F16:  ne11_mm_min = 2;  break;
-                                case GGML_TYPE_Q8_0: ne11_mm_min = 7;  break;
-                                case GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
-                                case GGML_TYPE_Q3_K: ne11_mm_min = 7;  break;
-                                case GGML_TYPE_Q4_0:
-                                case GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
-                                case GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
-                                case GGML_TYPE_Q5_0:                          // not tested yet
-                                case GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
-                                case GGML_TYPE_Q5_K: ne11_mm_min = 7;  break;
-                                case GGML_TYPE_Q6_K: ne11_mm_min = 7;  break;
-                                default:             ne11_mm_min = 1;  break;
-                            }
-                        }
-#endif
+                    id<MTLComputePipelineState> pipeline = nil;
+
+                    switch (src0->type) {
+                        case GGML_TYPE_F16:
+                            switch (r1ptg) {
+                                case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2].pipeline; break;
+                                case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3].pipeline; break;
+                                case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4].pipeline; break;
+                                case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5].pipeline; break;
+                                default: GGML_ABORT("not implemented");
+                            } break;
+                        case GGML_TYPE_Q4_0:
+                            switch (r1ptg) {
+                                case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2].pipeline; break;
+                                case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3].pipeline; break;
+                                case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4].pipeline; break;
+                                case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5].pipeline; break;
+                                default: GGML_ABORT("not implemented");
+                            } break;
+                        case GGML_TYPE_Q4_1:
+                            switch (r1ptg) {
+                                case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2].pipeline; break;
+                                case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3].pipeline; break;
+                                case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4].pipeline; break;
+                                case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5].pipeline; break;
+                                default: GGML_ABORT("not implemented");
+                            } break;
+                        case GGML_TYPE_Q5_0:
+                            switch (r1ptg) {
+                                case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2].pipeline; break;
+                                case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3].pipeline; break;
+                                case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4].pipeline; break;
+                                case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5].pipeline; break;
+                                default: GGML_ABORT("not implemented");
+                            } break;
+                        case GGML_TYPE_Q5_1:
+                            switch (r1ptg) {
+                                case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2].pipeline; break;
+                                case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3].pipeline; break;
+                                case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4].pipeline; break;
+                                case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5].pipeline; break;
+                                default: GGML_ABORT("not implemented");
+                            } break;
+                        case GGML_TYPE_Q8_0:
+                            switch (r1ptg) {
+                                case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2].pipeline; break;
+                                case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3].pipeline; break;
+                                case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4].pipeline; break;
+                                case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5].pipeline; break;
+                                default: GGML_ABORT("not implemented");
+                            } break;
+                        case GGML_TYPE_Q4_K:
+                            switch (r1ptg) {
+                                case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2].pipeline; break;
+                                case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3].pipeline; break;
+                                case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4].pipeline; break;
+                                case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_5].pipeline; break;
+                                default: GGML_ABORT("not implemented");
+                            } break;
+                        case GGML_TYPE_Q5_K:
+                            switch (r1ptg) {
+                                case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_2].pipeline; break;
+                                case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_3].pipeline; break;
+                                case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_4].pipeline; break;
+                                case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_K_F32_R1_5].pipeline; break;
+                                default: GGML_ABORT("not implemented");
+                            } break;
+                        case GGML_TYPE_Q6_K:
+                            switch (r1ptg) {
+                                case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_2].pipeline; break;
+                                case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_3].pipeline; break;
+                                case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_4].pipeline; break;
+                                case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q6_K_F32_R1_5].pipeline; break;
+                                default: GGML_ABORT("not implemented");
+                            } break;
+                        case GGML_TYPE_IQ4_NL:
+                            switch (r1ptg) {
+                                case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_2].pipeline; break;
+                                case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_3].pipeline; break;
+                                case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_4].pipeline; break;
+                                case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_IQ4_NL_F32_R1_5].pipeline; break;
+                                default: GGML_ABORT("not implemented");
+                            } break;
+                        default: GGML_ABORT("not implemented");
+                    }
+
+                    ggml_metal_kargs_mul_mv_ext args = {
+                        /*.ne00  =*/ ne00,
+                        /*.ne01  =*/ ne01,
+                        /*.ne02  =*/ ne02,
+                        /*.nb00  =*/ nb00,
+                        /*.nb01  =*/ nb01,
+                        /*.nb02  =*/ nb02,
+                        /*.nb03  =*/ nb03,
+                        /*.ne10  =*/ ne10,
+                        /*.ne11  =*/ ne11,
+                        /*.ne12  =*/ ne12,
+                        /*.nb10  =*/ nb10,
+                        /*.nb11  =*/ nb11,
+                        /*.nb12  =*/ nb12,
+                        /*.nb13  =*/ nb13,
+                        /*.ne0   =*/ ne0,
+                        /*.ne1   =*/ ne1,
+                        /*.r2    =*/ r2,
+                        /*.r3    =*/ r3,
+                        /*.nsg   =*/ nsg,
+                        /*.nxpsg =*/ nxpsg,
+                        /*.r1ptg =*/ r1ptg,
+                    };
+
+                    [encoder setComputePipelineState:pipeline];
+                    [encoder setBytes:&args length:sizeof(args) atIndex:0];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:3];
 
+                    //printf("ne01 = %lld nr0ptg = %d\n", ne01, nr0ptg);
+                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + r0ptg - 1)/r0ptg, (ne11 + r1ptg - 1)/r1ptg, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+                } else
                 // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
                 // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
                 if ([device supportsFamily:MTLGPUFamilyApple7] &&
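
With this change, the kernel selection in ggml_metal_encode_node proceeds roughly as follows (a condensed sketch using the names from the hunk above; the full mat-mat condition continues past the shown context):

    // 1) small-batch mat-vec ext kernels: src1 is F32, ne00 % 256 == 0,
    //    src0 is one of the supported types, and ne11 is in [2, 8]
    //    ([4, 8] for Q4_K/Q5_K/Q6_K)
    // 2) else mat-mat kernel: device supports MTLGPUFamilyApple7 and
    //    ne11 >= ne11_mm_min (== 4), among other conditions
    // 3) else the existing mat-vec kernels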
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index 560f0caf3944c38283f96e0da6fb19ad32546cf3..5caa0846a8b537dfce529da4b1a9e87a8674e036 100644
@@ -47,6 +47,11 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg)
     reg = (type4x4)(*src);
 }
 
+template <typename type4>
+void dequantize_f16_t4(device const half4 * src, short il, thread type4 & reg) {
+    reg = (type4)(*(src + il));
+}
+
 #if defined(GGML_METAL_USE_BF16)
 template <typename type4x4>
 void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & reg) {
@@ -55,7 +60,7 @@ void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & re
 #endif
 
 template <typename type4x4>
-void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
+void dequantize_q4_0(device const block_q4_0 * xb, short il, thread type4x4 & reg) {
     device const uint16_t * qs = ((device const uint16_t *)xb + 1);
     const float d1 = il ? (xb->d / 16.h) : xb->d;
     const float d2 = d1 / 256.f;
@@ -73,8 +78,23 @@ void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg
     reg = (type4x4) reg_f;
 }
 
+template <typename type4>
+void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & reg) {
+    device const uint16_t * qs = ((device const uint16_t *)xb + 1);
+    const float d1 = (il/4) ? (xb->d / 16.h) : xb->d;
+    const float d2 = d1 / 256.f;
+    const float md = -8.h * xb->d;
+    const ushort mask0 = (il/4) ? 0x00F0 : 0x000F;
+    const ushort mask1 = mask0 << 8;
+
+    for (int i = 0; i < 2; i++) {
+        reg[2*i + 0] = d1 * (qs[2*(il%4) + i] & mask0) + md;
+        reg[2*i + 1] = d2 * (qs[2*(il%4) + i] & mask1) + md;
+    }
+}
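
// Reading guide for the _t4 chunk indexing (an illustrative note, not part
// of the patch): a block_q4_0 packs 32 weights into 16 bytes of nibbles,
// i.e. 8 float4 chunks per block, matching chpb = epb/4 = 32/4 = 8 in the
// dispatchers further below. For il in 0..7, il/4 selects the nibble half
// (low: mask0 = 0x000F, d1 = d; high: mask0 = 0x00F0, d1 = d/16) and il%4
// selects which pair of uint16 words supplies the chunk's 4 values. The
// q4_1/q5_0/q5_1 variants below use the same il decomposition.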
+
 template <typename type4x4>
-void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
+void dequantize_q4_1(device const block_q4_1 * xb, short il, thread type4x4 & reg) {
     device const uint16_t * qs = ((device const uint16_t *)xb + 2);
     const float d1 = il ? (xb->d / 16.h) : xb->d;
     const float d2 = d1 / 256.f;
@@ -92,8 +112,23 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg
     reg = (type4x4) reg_f;
 }
 
+template <typename type4>
+void dequantize_q4_1_t4(device const block_q4_1 * xb, short il, thread type4 & reg) {
+    device const uint16_t * qs = ((device const uint16_t *)xb + 2);
+    const float d1 = (il/4) ? (xb->d / 16.h) : xb->d;
+    const float d2 = d1 / 256.f;
+    const float  m = xb->m;
+    const ushort mask0 = (il/4) ? 0x00F0 : 0x000F;
+    const ushort mask1 = mask0 << 8;
+
+    for (int i = 0; i < 2; i++) {
+        reg[2*i + 0] = d1 * (qs[2*(il%4) + i] & mask0) + m;
+        reg[2*i + 1] = d2 * (qs[2*(il%4) + i] & mask1) + m;
+    }
+}
+
 template <typename type4x4>
-void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) {
+void dequantize_q5_0(device const block_q5_0 * xb, short il, thread type4x4 & reg) {
     device const uint16_t * qs = ((device const uint16_t *)xb + 3);
     const float d = xb->d;
     const float md = -16.h * xb->d;
@@ -124,8 +159,38 @@ void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg
     reg = (type4x4) reg_f;
 }
 
+template <typename type4>
+void dequantize_q5_0_t4(device const block_q5_0 * xb, short il, thread type4 & reg) {
+    device const uint16_t * qs = ((device const uint16_t *)xb + 3);
+    const float d = xb->d;
+    const float md = -16.h * xb->d;
+    const ushort mask = (il/4) ? 0x00F0 : 0x000F;
+
+    const uint32_t qh = *((device const uint32_t *)xb->qh);
+
+    const int x_mv = (il/4) ? 4 : 0;
+
+    const int gh_mv = (il/4) ? 12 : 0;
+    const int gh_bk = (il/4) ?  0 : 4;
+
+    for (int ii = 0; ii < 2; ii++) {
+        int i = 2*(il%4) + ii;
+
+        // extract the 5th bits for x0 and x1
+        const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10;
+        const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
+
+        // combine the 4 bits from qs with the 5th bit
+        const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0);
+        const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
+
+        reg[2*ii + 0] = d * x0 + md;
+        reg[2*ii + 1] = d * x1 + md;
+    }
+}
+
 template <typename type4x4>
-void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) {
+void dequantize_q5_1(device const block_q5_1 * xb, short il, thread type4x4 & reg) {
     device const uint16_t * qs = ((device const uint16_t *)xb + 4);
     const float d = xb->d;
     const float m = xb->m;
@@ -156,10 +221,40 @@ void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg
     reg = (type4x4) reg_f;
 }
 
+template <typename type4>
+void dequantize_q5_1_t4(device const block_q5_1 * xb, short il, thread type4 & reg) {
+    device const uint16_t * qs = ((device const uint16_t *)xb + 4);
+    const float d = xb->d;
+    const float m = xb->m;
+    const ushort mask = (il/4) ? 0x00F0 : 0x000F;
+
+    const uint32_t qh = *((device const uint32_t *)xb->qh);
+
+    const int x_mv = (il/4) ? 4 : 0;
+
+    const int gh_mv = (il/4) ? 12 : 0;
+    const int gh_bk = (il/4) ?  0 : 4;
+
+    for (int ii = 0; ii < 2; ii++) {
+        int i = 2*(il%4) + ii;
+
+        // extract the 5th bits for x0 and x1
+        const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10;
+        const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
+
+        // combine the 4 bits from qs with the 5th bit
+        const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0);
+        const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
+
+        reg[2*ii + 0] = d * x0 + m;
+        reg[2*ii + 1] = d * x1 + m;
+    }
+}
+
 template <typename type4x4>
 void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
     device const int8_t * qs = ((device const int8_t *)xb->qs);
-    const half d = xb->d;
+    const float d = xb->d;
 
     float4x4 reg_f;
 
@@ -170,6 +265,16 @@ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg
     reg = (type4x4) reg_f;
 }
 
+template <typename type4>
+void dequantize_q8_0_t4(device const block_q8_0 *xb, short il, thread type4 & reg) {
+    device const int8_t * qs = ((device const int8_t *)xb->qs);
+    const float d = xb->d;
+
+    for (int i = 0; i < 4; i++) {
+        reg[i] = (qs[4*(il%4) + i + 16*(il/4)] * d);
+    }
+}
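
// Note (not part of the patch): for il in 0..7, 4*(il%4) + i + 16*(il/4)
// equals 4*il + i, so each chunk reads 4 consecutive int8 values; the split
// form mirrors the low/high-half indexing of the 4-bit _t4 variants above.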
+
 template <typename type4x4>
 void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
     const float d = xb->d;
@@ -224,7 +329,7 @@ static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q
 }
 
 template <typename type4x4>
-void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
+void dequantize_q4_K(device const block_q4_K * xb, short il, thread type4x4 & reg) {
     device const uchar * q = xb->qs;
 
     short is = (il/4) * 2;
@@ -236,7 +341,7 @@ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg
     const float dl = d * sc[0];
     const float ml = min * sc[1];
 
-    const ushort mask = il<2 ? 0x0F : 0xF0;
+    const ushort mask = il < 2 ? 0x0F : 0xF0;
     for (int i = 0; i < 16; ++i) {
         reg[i/4][i%4] = dl * (q[i] & mask) - ml;
     }
@@ -469,6 +574,19 @@ void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4
     }
 }
 
+template <typename type4>
+void dequantize_iq4_nl_t4(device const block_iq4_nl * xb, short il, thread type4 & reg) {
+    device const uint16_t * q4 = (device const uint16_t *)xb->qs;
+    const float d = xb->d;
+    uint32_t aux32;
+    thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
+    aux32 = ((q4[2*(il%4)] | (q4[2*(il%4)+1] << 16)) >> 4*(il/4)) & 0x0f0f0f0f;
+    reg[0] = d * kvalues_iq4nl_f[q8[0]];
+    reg[1] = d * kvalues_iq4nl_f[q8[1]];
+    reg[2] = d * kvalues_iq4nl_f[q8[2]];
+    reg[3] = d * kvalues_iq4nl_f[q8[3]];
+}
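
// Reading guide (an illustrative note, not part of the patch): q4 views the
// 16 quant bytes as 8 uint16 words. Each chunk packs one pair of words into
// aux32, shifts right by 4*(il/4) to select the low (il < 4) or high
// (il >= 4) nibbles, masks one nibble per byte, and maps each byte through
// the kvalues_iq4nl_f table, again giving 8 float4 chunks per 32-element
// block.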
+
 template <typename type4x4>
 void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) {
     // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
@@ -1809,6 +1927,301 @@ kernel void kernel_mul_mv_q8_0_f32(
     kernel_mul_mv_q8_0_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
+// mat-vec kernel processing in chunks of float4
+// chpb - chunks per quantization block
+template<short nxpsg, short r1ptg, typename q_t, short chpb, void (*deq_t4)(device const q_t *, short, thread float4 &) >
+void kernel_mul_mv_ext_q4_f32_impl(
+        constant ggml_metal_kargs_mul_mv_ext & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort  tiisg[[thread_index_in_simdgroup]],
+        ushort  sgitg[[simdgroup_index_in_threadgroup]]) {
+    const short chpt = 4; // chunks per thread
+
+  //const short nxpsg = (32);
+    const short nypsg = (32/nxpsg);
+
+    const short tx = tiisg%nxpsg;
+    const short ty = tiisg/nxpsg;
+
+    const int i01 = tgpig.x*(nypsg*args.nsg) + nypsg*sgitg + ty;
+    const int i11 = tgpig.y*r1ptg;
+    const int i1m = tgpig.z;
+
+    const int i12 = i1m%args.ne12;
+    const int i13 = i1m/args.ne12;
+
+    const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 = i11*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
+
+    device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + tx/chpb : (device const q_t *) src0;
+
+    device const float4 * y4[r1ptg];
+
+    for (int ir1 = 0; ir1 < r1ptg; ++ir1) {
+        y4[ir1] = (i11 + ir1 < args.ne11) ? (device const float4 *) (src1 + offset1 + ir1*args.nb11) + tx : (device const float4 *) src1;
+    }
+
+    float sumf[r1ptg] = { [ 0 ... r1ptg - 1 ] = 0.0f };
+
+    short cch = tx%chpb; // current chunk index
+
+    for (int ich = tx; 4*ich < args.ne00; ich += chpt*nxpsg) {
+        float4 lx[chpt];
+
+#pragma unroll(chpt)
+        for (short ch = 0; ch < chpt; ++ch) {
+            deq_t4(xq, cch, lx[ch]);
+
+            cch += nxpsg;
+            if (cch >= chpb) {
+                xq  += cch/chpb;
+                cch %= chpb;
+            }
+        }
+
+#pragma unroll(chpt)
+        for (short ch = 0; ch < chpt; ++ch) {
+#pragma unroll(r1ptg)
+            for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
+                sumf[ir1] += dot(lx[ch], y4[ir1][ch*nxpsg]);
+
+            }
+        }
+
+#pragma unroll(r1ptg)
+        for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
+            y4[ir1] += chpt*nxpsg;
+        }
+    }
+
+    // reduce only the threads in each row
+    for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
+        if (nxpsg >= 32) {
+            sumf[ir1] += simd_shuffle_down(sumf[ir1], 16);
+        }
+        if (nxpsg >= 16) {
+            sumf[ir1] += simd_shuffle_down(sumf[ir1],  8);
+        }
+        if (nxpsg >= 8) {
+            sumf[ir1] += simd_shuffle_down(sumf[ir1],  4);
+        }
+        if (nxpsg >= 4) {
+            sumf[ir1] += simd_shuffle_down(sumf[ir1],  2);
+        }
+        if (nxpsg >= 2) {
+            sumf[ir1] += simd_shuffle_down(sumf[ir1],  1);
+        }
+
+        //sumf[ir1] = simd_sum(sumf[ir1]);
+    }
+
+    if (tx == 0) {
+        for (short ir1 = 0; ir1 < r1ptg && i11 + ir1 < args.ne11; ++ir1) {
+            device float * dst_f32 = (device float *) dst + (uint64_t)i1m*args.ne0*args.ne1 + (uint64_t)(i11 + ir1)*args.ne0;
+
+            if (i01 < args.ne01) {
+                dst_f32[i01] = sumf[ir1];
+            }
+        }
+    }
+}
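
// Summary of the loop structure above (an illustrative note, not part of
// the patch): each simdgroup splits into nypsg rows of nxpsg lanes; lane tx
// of a row starts at chunk tx and strides by chpt*nxpsg = 4*nxpsg float4
// chunks, so neighboring lanes read neighboring chunks of src0 and src1.
// Because a row occupies nxpsg contiguous lanes of the simdgroup, the
// simd_shuffle_down ladder with offsets nxpsg/2, ..., 1 accumulates the
// full row sum into lane tx == 0, which alone writes the result.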
+
+// mat-vec kernel processing in chunks of float4x4
+template<short nxpsg, short r1ptg, typename q_t, short chpb, void (*deq_t4x4)(device const q_t *, short, thread float4x4 &) >
+void kernel_mul_mv_ext_q4x4_f32_impl(
+        constant ggml_metal_kargs_mul_mv_ext & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort  tiisg[[thread_index_in_simdgroup]],
+        ushort  sgitg[[simdgroup_index_in_threadgroup]]) {
+    const short chpt = 1;
+
+  //const short nxpsg = (32);
+    const short nypsg = (32/nxpsg);
+
+    const short tx = tiisg%nxpsg;
+    const short ty = tiisg/nxpsg;
+
+    const int i01 = tgpig.x*(nypsg*args.nsg) + nypsg*sgitg + ty;
+    const int i11 = tgpig.y*r1ptg;
+    const int i1m = tgpig.z;
+
+    const int i12 = i1m%args.ne12;
+    const int i13 = i1m/args.ne12;
+
+    const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 = i11*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
+
+    device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + tx/chpb : (device const q_t *) src0;
+
+    device const float4x4 * y4x4[r1ptg];
+
+    for (int ir1 = 0; ir1 < r1ptg; ++ir1) {
+        y4x4[ir1] = (i11 + ir1 < args.ne11) ? (device const float4x4 *) (src1 + offset1 + ir1*args.nb11) + tx : (device const float4x4 *) src1;
+    }
+
+    float sumf[r1ptg] = { [ 0 ... r1ptg - 1 ] = 0.0f };
+
+    short cch = tx%chpb;
+
+    for (int ich = tx; 16*ich < args.ne00; ich += chpt*nxpsg) {
+        float4x4 lx[chpt];
+
+#pragma unroll(chpt)
+        for (short ch = 0; ch < chpt; ++ch) {
+            deq_t4x4(xq, cch, lx[ch]);
+
+            cch += nxpsg;
+            if (cch >= chpb) {
+                xq  += cch/chpb;
+                cch %= chpb;
+            }
+        }
+
+#pragma unroll(chpt)
+        for (short ch = 0; ch < chpt; ++ch) {
+#pragma unroll(r1ptg)
+            for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
+                sumf[ir1] +=
+                    dot(lx[ch][0], y4x4[ir1][ch*nxpsg][0]) +
+                    dot(lx[ch][1], y4x4[ir1][ch*nxpsg][1]) +
+                    dot(lx[ch][2], y4x4[ir1][ch*nxpsg][2]) +
+                    dot(lx[ch][3], y4x4[ir1][ch*nxpsg][3]);
+
+            }
+        }
+
+#pragma unroll(r1ptg)
+        for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
+            y4x4[ir1] += chpt*nxpsg;
+        }
+    }
+
+    for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
+        if (nxpsg >= 32) {
+            sumf[ir1] += simd_shuffle_down(sumf[ir1], 16);
+        }
+        if (nxpsg >= 16) {
+            sumf[ir1] += simd_shuffle_down(sumf[ir1],  8);
+        }
+        if (nxpsg >= 8) {
+            sumf[ir1] += simd_shuffle_down(sumf[ir1],  4);
+        }
+        if (nxpsg >= 4) {
+            sumf[ir1] += simd_shuffle_down(sumf[ir1],  2);
+        }
+        if (nxpsg >= 2) {
+            sumf[ir1] += simd_shuffle_down(sumf[ir1],  1);
+        }
+
+        //sumf[ir1] = simd_sum(sumf[ir1]);
+    }
+
+    if (tx == 0) {
+        for (short ir1 = 0; ir1 < r1ptg && i11 + ir1 < args.ne11; ++ir1) {
+            device float * dst_f32 = (device float *) dst + (uint64_t)i1m*args.ne0*args.ne1 + (uint64_t)(i11 + ir1)*args.ne0;
+
+            if (i01 < args.ne01) {
+                dst_f32[i01] = sumf[ir1];
+            }
+        }
+    }
+}
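
// Note (not part of the patch): this float4x4 variant mirrors
// kernel_mul_mv_ext_q4_f32_impl above, except that chunks span 16 elements
// (hence chpt = 1 and the 16*ich loop bound) and it reuses the existing
// full-block dequantizers (dequantize_q4_K and friends) rather than the new
// _t4 helpers.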
+
+// dispatchers needed for compile-time nxpsg
+// epb - elements per quantization block
+template<short r1ptg, typename q_t, short epb, void (*deq_t4)(device const q_t *, short, thread float4 &)>
+kernel void kernel_mul_mv_ext_q4_f32_disp(
+        constant ggml_metal_kargs_mul_mv_ext & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort  tiisg[[thread_index_in_simdgroup]],
+        ushort  sgitg[[simdgroup_index_in_threadgroup]]) {
+    switch (args.nxpsg) {
+        case 4:  kernel_mul_mv_ext_q4_f32_impl<4,  r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
+        case 8:  kernel_mul_mv_ext_q4_f32_impl<8,  r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
+        case 16: kernel_mul_mv_ext_q4_f32_impl<16, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
+        case 32: kernel_mul_mv_ext_q4_f32_impl<32, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
+    }
+}
+
+template<short r1ptg, typename q_t, short epb, void (*deq_t4x4)(device const q_t *, short, thread float4x4 &)>
+kernel void kernel_mul_mv_ext_q4x4_f32_disp(
+        constant ggml_metal_kargs_mul_mv_ext & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort  tiisg[[thread_index_in_simdgroup]],
+        ushort  sgitg[[simdgroup_index_in_threadgroup]]) {
+    switch (args.nxpsg) {
+        case 4:  kernel_mul_mv_ext_q4x4_f32_impl<4,  r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
+        case 8:  kernel_mul_mv_ext_q4x4_f32_impl<8,  r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
+        case 16: kernel_mul_mv_ext_q4x4_f32_impl<16, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
+        case 32: kernel_mul_mv_ext_q4x4_f32_impl<32, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
+    }
+}
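
// Why the dispatchers exist (an illustrative note, not part of the patch):
// nxpsg arrives at run time in the kargs, but the kernels want it as a
// compile-time constant, so each host-visible kernel switches over
// args.nxpsg into one of four template instantiations (4, 8, 16, 32). chpb
// is derived from the elements per block as epb/4 on the float4 path and
// epb/16 on the float4x4 path, e.g. block_q8_0: 32/4 = 8 chunks,
// block_q4_K: 256/16 = 16 chunks.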
+
+typedef decltype(kernel_mul_mv_ext_q4_f32_disp  <2, block_q8_0, 32,  dequantize_q8_0_t4>) mul_mv_ext_q4_f32_t;
+typedef decltype(kernel_mul_mv_ext_q4x4_f32_disp<2, block_q4_K, 256, dequantize_q4_K>)    mul_mv_ext_q4x4_f32_t;
+
+template [[host_name("kernel_mul_mv_ext_f16_f32_r1_2")]]    kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, half4,        4,  dequantize_f16_t4>;
+template [[host_name("kernel_mul_mv_ext_f16_f32_r1_3")]]    kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, half4,        4,  dequantize_f16_t4>;
+template [[host_name("kernel_mul_mv_ext_f16_f32_r1_4")]]    kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, half4,        4,  dequantize_f16_t4>;
+template [[host_name("kernel_mul_mv_ext_f16_f32_r1_5")]]    kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, half4,        4,  dequantize_f16_t4>;
+
+template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_2")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_0,   32, dequantize_q4_0_t4>;
+template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_3")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_0,   32, dequantize_q4_0_t4>;
+template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_4")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_0,   32, dequantize_q4_0_t4>;
+template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_5")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q4_0,   32, dequantize_q4_0_t4>;
+
+template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_2")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q4_1,   32, dequantize_q4_1_t4>;
+template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_3")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q4_1,   32, dequantize_q4_1_t4>;
+template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_4")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q4_1,   32, dequantize_q4_1_t4>;
+template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_5")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q4_1,   32, dequantize_q4_1_t4>;
+
+template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_2")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q5_0,   32, dequantize_q5_0_t4>;
+template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_3")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q5_0,   32, dequantize_q5_0_t4>;
+template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_4")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q5_0,   32, dequantize_q5_0_t4>;
+template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_5")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q5_0,   32, dequantize_q5_0_t4>;
+
+template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_2")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q5_1,   32, dequantize_q5_1_t4>;
+template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_3")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q5_1,   32, dequantize_q5_1_t4>;
+template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_4")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q5_1,   32, dequantize_q5_1_t4>;
+template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_5")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q5_1,   32, dequantize_q5_1_t4>;
+
+template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_2")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_q8_0,   32, dequantize_q8_0_t4>;
+template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_3")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_q8_0,   32, dequantize_q8_0_t4>;
+template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_4")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q8_0,   32, dequantize_q8_0_t4>;
+template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_5")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q8_0,   32, dequantize_q8_0_t4>;
+
+template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
+template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
+template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
+template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
+
+template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q4_K, 256, dequantize_q4_K>;
+template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q4_K, 256, dequantize_q4_K>;
+template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q4_K, 256, dequantize_q4_K>;
+template [[host_name("kernel_mul_mv_ext_q4_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q4_K, 256, dequantize_q4_K>;
+
+template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q5_K, 256, dequantize_q5_K>;
+template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q5_K, 256, dequantize_q5_K>;
+template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q5_K, 256, dequantize_q5_K>;
+template [[host_name("kernel_mul_mv_ext_q5_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q5_K, 256, dequantize_q5_K>;
+
+template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_2")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<2, block_q6_K, 256, dequantize_q6_K>;
+template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_3")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<3, block_q6_K, 256, dequantize_q6_K>;
+template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q6_K, 256, dequantize_q6_K>;
+template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q6_K, 256, dequantize_q6_K>;
+
 #define N_MV_T_T 4
 
 template<typename T0, typename T04, typename T1, typename T14, typename args_t>