]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
metal : refactor mat-vec code (#12569)
authorGeorgi Gerganov <redacted>
Wed, 26 Mar 2025 19:38:38 +0000 (21:38 +0200)
committerGitHub <redacted>
Wed, 26 Mar 2025 19:38:38 +0000 (21:38 +0200)
* metal : refactor mat-vec code

ggml-ci

* metal : rename all_sum -> sum_all

ggml-ci

* metal : fix comments [no ci]

* metal : fix nr constant [no ci]

* metal : mv q6_K support nr0 > 1

ggml-ci

* metal : reduce register pressure

ggml-ci

* metal : fix typo [no ci]

* metal : reduce register pressure

ggml-ci

ggml/src/ggml-metal/ggml-metal-impl.h
ggml/src/ggml-metal/ggml-metal.m
ggml/src/ggml-metal/ggml-metal.metal

index 1e954b4ceabd7c6685caf3a025678e31f2492f93..ca5a00b0322e1d1cbf64f5d137b349f85972b136 100644 (file)
@@ -1,6 +1,70 @@
 #ifndef GGML_METAL_IMPL
 #define GGML_METAL_IMPL
 
+// kernel parameters for mat-vec threadgroups
+//
+// N_R0: number of src0 rows to process per simdgroup
+// N_SG: number of simdgroups per threadgroup
+//
+// TODO: for optimal performance, become function of the device and work size
+
+#define N_R0_Q4_0 4
+#define N_SG_Q4_0 2
+
+#define N_R0_Q4_1 4
+#define N_SG_Q4_1 2
+
+#define N_R0_Q5_0 4
+#define N_SG_Q5_0 2
+
+#define N_R0_Q5_1 4
+#define N_SG_Q5_1 2
+
+#define N_R0_Q8_0 4
+#define N_SG_Q8_0 2
+
+#define N_R0_Q2_K 4
+#define N_SG_Q2_K 2
+
+#define N_R0_Q3_K 2
+#define N_SG_Q3_K 2
+
+#define N_R0_Q4_K 4
+#define N_SG_Q4_K 2
+
+#define N_R0_Q5_K 2
+#define N_SG_Q5_K 2
+
+#define N_R0_Q6_K 1
+#define N_SG_Q6_K 2
+
+#define N_R0_IQ1_S 4
+#define N_SG_IQ1_S 2
+
+#define N_R0_IQ1_M 4
+#define N_SG_IQ1_M 2
+
+#define N_R0_IQ2_XXS 4
+#define N_SG_IQ2_XXS 2
+
+#define N_R0_IQ2_XS 4
+#define N_SG_IQ2_XS 2
+
+#define N_R0_IQ2_S 4
+#define N_SG_IQ2_S 2
+
+#define N_R0_IQ3_XXS 4
+#define N_SG_IQ3_XXS 2
+
+#define N_R0_IQ3_S 4
+#define N_SG_IQ3_S 2
+
+#define N_R0_IQ4_NL 2
+#define N_SG_IQ4_NL 2
+
+#define N_R0_IQ4_XS 2
+#define N_SG_IQ4_XS 2
+
 // kernel argument structs
 //
 // - element counters (e.g. ne00) typically use int32_t to reduce register usage
index af65e7d9f53d43a9ad2b2d9b970144efabec8e2f..195d9678275c9af963064ea52a654ee55ba45dc3 100644 (file)
@@ -2561,171 +2561,180 @@ static void ggml_metal_encode_node(
                     [encoder setThreadgroupMemoryLength:8192 atIndex:0];
                     [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
                 } else {
-                    int nth0 = 32;
-                    int nth1 = 1;
-                    int nrows = 1;
-                    //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
-
                     id<MTLComputePipelineState> pipeline = nil;
 
+                    int nsg = 0; // number of simdgroups
+                    int nr0 = 0; // number of src0 rows per simdgroup
+                    int nr1 = 1; // number of src1 rows per threadgroup
+
+                    size_t smem = 0; // shared memory
+
                     // use custom matrix x vector kernel
                     switch (src0t) {
                         case GGML_TYPE_F32:
                             {
                                 GGML_ASSERT(src1t == GGML_TYPE_F32);
+                                nsg = 1;
+                                nr0 = 1;
+                                nr1 = 4;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
-                                nrows = 4;
                             } break;
                         case GGML_TYPE_F16:
                             {
-                                nth0 = 32;
-                                nth1 = 1;
+                                nsg = 1;
+                                nr0 = 1;
                                 if (src1t == GGML_TYPE_F32) {
                                     if (ne11 * ne12 < 4) {
                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
                                     } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
-                                        nrows = ne11;
+                                        nr1 = ne11;
                                     } else {
                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline;
-                                        nrows = 4;
+                                        nr1 = 4;
                                     }
                                 } else {
                                     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline;
-                                    nrows = 4;
+                                    nr1 = 4;
                                 }
                             } break;
                         case GGML_TYPE_BF16:
                             {
-                                nth0 = 32;
-                                nth1 = 1;
+                                nsg = 1;
+                                nr0 = 1;
                                 if (src1t == GGML_TYPE_F32) {
                                     if (ne11 * ne12 < 4) {
                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline;
                                     } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline;
-                                        nrows = ne11;
+                                        nr1 = ne11;
                                     } else {
                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32].pipeline;
-                                        nrows = 4;
+                                        nr1 = 4;
                                     }
                                 } else {
                                     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16].pipeline;
-                                    nrows = 4;
+                                    nr1 = 4;
                                 }
                             } break;
                         case GGML_TYPE_Q4_0:
                             {
-                                nth0 = 8;
-                                nth1 = 8;
+                                nsg = N_SG_Q4_0;
+                                nr0 = N_R0_Q4_0;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline;
                             } break;
                         case GGML_TYPE_Q4_1:
                             {
-                                nth0 = 8;
-                                nth1 = 8;
+                                nsg = N_SG_Q4_1;
+                                nr0 = N_R0_Q4_1;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline;
                             } break;
                         case GGML_TYPE_Q5_0:
                             {
-                                nth0 = 8;
-                                nth1 = 8;
+                                nsg = N_SG_Q5_0;
+                                nr0 = N_R0_Q5_0;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline;
                             } break;
                         case GGML_TYPE_Q5_1:
                             {
-                                nth0 = 8;
-                                nth1 = 8;
+                                nsg = N_SG_Q5_1;
+                                nr0 = N_R0_Q5_1;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline;
                             } break;
                         case GGML_TYPE_Q8_0:
                             {
-                                nth0 = 8;
-                                nth1 = 8;
+                                nsg = N_SG_Q8_0;
+                                nr0 = N_R0_Q8_0;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline;
                             } break;
                         case GGML_TYPE_Q2_K:
                             {
-                                nth0 = 2;
-                                nth1 = 32;
+                                nsg = N_SG_Q2_K;
+                                nr0 = N_R0_Q2_K;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline;
                             } break;
                         case GGML_TYPE_Q3_K:
                             {
-                                nth0 = 2;
-                                nth1 = 32;
+                                nsg = N_SG_Q3_K;
+                                nr0 = N_R0_Q3_K;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline;
                             } break;
                         case GGML_TYPE_Q4_K:
                             {
-                                nth0 = 4; //1;
-                                nth1 = 8; //32;
+                                nsg = N_SG_Q4_K;
+                                nr0 = N_R0_Q4_K;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline;
                             } break;
                         case GGML_TYPE_Q5_K:
                             {
-                                nth0 = 2;
-                                nth1 = 32;
+                                nsg = N_SG_Q5_K;
+                                nr0 = N_R0_Q5_K;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline;
                             } break;
                         case GGML_TYPE_Q6_K:
                             {
-                                nth0 = 2;
-                                nth1 = 32;
+                                nsg = N_SG_Q6_K;
+                                nr0 = N_R0_Q6_K;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline;
                             } break;
                         case GGML_TYPE_IQ2_XXS:
                             {
-                                nth0 = 4;
-                                nth1 = 16;
+                                nsg = N_SG_IQ2_XXS;
+                                nr0 = N_R0_IQ2_XXS;
+                                smem = 256*8+128;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32].pipeline;
                             } break;
                         case GGML_TYPE_IQ2_XS:
                             {
-                                nth0 = 4;
-                                nth1 = 16;
+                                nsg = N_SG_IQ2_XS;
+                                nr0 = N_R0_IQ2_XS;
+                                smem = 512*8+128;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline;
                             } break;
                         case GGML_TYPE_IQ3_XXS:
                             {
-                                nth0 = 4;
-                                nth1 = 16;
+                                nsg = N_SG_IQ3_XXS;
+                                nr0 = N_R0_IQ3_XXS;
+                                smem = 256*4+128;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline;
                             } break;
                         case GGML_TYPE_IQ3_S:
                             {
-                                nth0 = 4;
-                                nth1 = 16;
+                                nsg = N_SG_IQ3_S;
+                                nr0 = N_R0_IQ3_S;
+                                smem = 512*4;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32].pipeline;
                             } break;
                         case GGML_TYPE_IQ2_S:
                             {
-                                nth0 = 4;
-                                nth1 = 16;
+                                nsg = N_SG_IQ2_S;
+                                nr0 = N_R0_IQ2_S;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32].pipeline;
                             } break;
                         case GGML_TYPE_IQ1_S:
                             {
-                                nth0 = 4;
-                                nth1 = 16;
+                                nsg = N_SG_IQ1_S;
+                                nr0 = N_R0_IQ1_S;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline;
                             } break;
                         case GGML_TYPE_IQ1_M:
                             {
-                                nth0 = 4;
-                                nth1 = 16;
+                                nsg = N_SG_IQ1_M;
+                                nr0 = N_R0_IQ1_M;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32].pipeline;
                             } break;
                         case GGML_TYPE_IQ4_NL:
                             {
-                                nth0 = 4;
-                                nth1 = 16;
+                                nsg = N_SG_IQ4_NL;
+                                nr0 = N_R0_IQ4_NL;
+                                smem = 32*sizeof(float);
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline;
                             } break;
                         case GGML_TYPE_IQ4_XS:
                             {
-                                nth0 = 4;
-                                nth1 = 16;
+                                nsg = N_SG_IQ4_XS;
+                                nr0 = N_R0_IQ4_XS;
+                                smem = 32*sizeof(float);
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline;
                             } break;
                         default:
@@ -2762,41 +2771,10 @@ static void ggml_metal_encode_node(
                     [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
                     [encoder setBuffer:id_dst  offset:offs_dst  atIndex:3];
 
-                    if (src0t == GGML_TYPE_Q4_0  || src0t == GGML_TYPE_Q4_1  || src0t == GGML_TYPE_Q5_0 ||
-                        src0t == GGML_TYPE_Q5_1  || src0t == GGML_TYPE_Q8_0  || src0t == GGML_TYPE_Q2_K ||
-                        src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                    }
-                    else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
-                        const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
-                        [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                    }
-                    else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
-                        const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
-                        [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                    }
-                    else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
-                        const int mem_size = 32*sizeof(float);
-                        [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                    }
-                    else if (src0t == GGML_TYPE_Q4_K) {
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                    }
-                    else if (src0t == GGML_TYPE_Q3_K) {
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                    }
-                    else if (src0t == GGML_TYPE_Q5_K) {
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                    }
-                    else if (src0t == GGML_TYPE_Q6_K) {
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                    } else {
-                        const int64_t ny = (ne11 + nrows - 1)/nrows;
-                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                    if (smem > 0) {
+                        [encoder setThreadgroupMemoryLength:smem atIndex:0];
                     }
+                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0*nsg - 1)/(nr0*nsg), (ne11 + nr1 - 1)/nr1, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
                 }
             } break;
         case GGML_OP_MUL_MAT_ID:
@@ -2902,146 +2880,155 @@ static void ggml_metal_encode_node(
 
                     [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, n_as) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
                 } else {
-                    int nth0 = 32;
-                    int nth1 = 1;
-                    int nrows = 1;
-                    //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
-
                     id<MTLComputePipelineState> pipeline = nil;
 
+                    int nsg = 0; // number of simdgroups
+                    int nr0 = 0; // number of src0 rows per simdgroup
+                    int nr1 = 1; // number of src1 rows per threadgroup
+
+                    size_t smem = 0; // shared memory
+
                     // use custom matrix x vector kernel
                     switch (src0t) {
                         case GGML_TYPE_F32:
                             {
                                 GGML_ASSERT(src1t == GGML_TYPE_F32);
+                                nsg = 1;
+                                nr0 = 1;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32].pipeline;
                             } break;
                         case GGML_TYPE_F16:
                             {
                                 GGML_ASSERT(src1t == GGML_TYPE_F32);
-                                nth0 = 32;
-                                nth1 = 1;
+                                nsg = 1;
+                                nr0 = 1;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline;
                             } break;
                         case GGML_TYPE_BF16:
                             {
                                 GGML_ASSERT(src1t == GGML_TYPE_F32);
-                                nth0 = 32;
-                                nth1 = 1;
+                                nsg = 1;
+                                nr0 = 1;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32].pipeline;
                             } break;
                         case GGML_TYPE_Q4_0:
                             {
-                                nth0 = 8;
-                                nth1 = 8;
+                                nsg = N_SG_Q4_0;
+                                nr0 = N_R0_Q4_0;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32].pipeline;
                             } break;
                         case GGML_TYPE_Q4_1:
                             {
-                                nth0 = 8;
-                                nth1 = 8;
+                                nsg = N_SG_Q4_1;
+                                nr0 = N_R0_Q4_1;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32].pipeline;
                             } break;
                         case GGML_TYPE_Q5_0:
                             {
-                                nth0 = 8;
-                                nth1 = 8;
+                                nsg = N_SG_Q5_0;
+                                nr0 = N_R0_Q5_0;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32].pipeline;
                             } break;
                         case GGML_TYPE_Q5_1:
                             {
-                                nth0 = 8;
-                                nth1 = 8;
+                                nsg = N_SG_Q5_1;
+                                nr0 = N_R0_Q5_1;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline;
                             } break;
                         case GGML_TYPE_Q8_0:
                             {
-                                nth0 = 8;
-                                nth1 = 8;
+                                nsg = N_SG_Q8_0;
+                                nr0 = N_R0_Q8_0;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline;
                             } break;
                         case GGML_TYPE_Q2_K:
                             {
-                                nth0 = 2;
-                                nth1 = 32;
+                                nsg = N_SG_Q2_K;
+                                nr0 = N_R0_Q2_K;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32].pipeline;
                             } break;
                         case GGML_TYPE_Q3_K:
                             {
-                                nth0 = 2;
-                                nth1 = 32;
+                                nsg = N_SG_Q3_K;
+                                nr0 = N_R0_Q3_K;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32].pipeline;
                             } break;
                         case GGML_TYPE_Q4_K:
                             {
-                                nth0 = 4; //1;
-                                nth1 = 8; //32;
+                                nsg = N_SG_Q4_K;
+                                nr0 = N_R0_Q4_K;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32].pipeline;
                             } break;
                         case GGML_TYPE_Q5_K:
                             {
-                                nth0 = 2;
-                                nth1 = 32;
+                                nsg = N_SG_Q5_K;
+                                nr0 = N_R0_Q5_K;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32].pipeline;
                             } break;
                         case GGML_TYPE_Q6_K:
                             {
-                                nth0 = 2;
-                                nth1 = 32;
+                                nsg = N_SG_Q6_K;
+                                nr0 = N_R0_Q6_K;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline;
                             } break;
                         case GGML_TYPE_IQ2_XXS:
                             {
-                                nth0 = 4;
-                                nth1 = 16;
+                                nsg = N_SG_IQ2_XXS;
+                                nr0 = N_R0_IQ2_XXS;
+                                smem = 256*8+128;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32].pipeline;
                             } break;
                         case GGML_TYPE_IQ2_XS:
                             {
-                                nth0 = 4;
-                                nth1 = 16;
+                                nsg = N_SG_IQ2_XS;
+                                nr0 = N_R0_IQ2_XS;
+                                smem = 512*8+128;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline;
                             } break;
                         case GGML_TYPE_IQ3_XXS:
                             {
-                                nth0 = 4;
-                                nth1 = 16;
+                                nsg = N_SG_IQ3_XXS;
+                                nr0 = N_R0_IQ3_XXS;
+                                smem = 256*4+128;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline;
                             } break;
                         case GGML_TYPE_IQ3_S:
                             {
-                                nth0 = 4;
-                                nth1 = 16;
+                                nsg = N_SG_IQ3_S;
+                                nr0 = N_R0_IQ3_S;
+                                smem = 512*4;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32].pipeline;
                             } break;
                         case GGML_TYPE_IQ2_S:
                             {
-                                nth0 = 4;
-                                nth1 = 16;
+                                nsg = N_SG_IQ2_S;
+                                nr0 = N_R0_IQ2_S;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32].pipeline;
                             } break;
                         case GGML_TYPE_IQ1_S:
                             {
-                                nth0 = 4;
-                                nth1 = 16;
+                                nsg = N_SG_IQ1_S;
+                                nr0 = N_R0_IQ1_S;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline;
                             } break;
                         case GGML_TYPE_IQ1_M:
                             {
-                                nth0 = 4;
-                                nth1 = 16;
+                                nsg = N_SG_IQ1_M;
+                                nr0 = N_R0_IQ1_M;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32].pipeline;
                             } break;
                         case GGML_TYPE_IQ4_NL:
                             {
-                                nth0 = 4;
-                                nth1 = 16;
+                                nsg = N_SG_IQ4_NL;
+                                nr0 = N_R0_IQ4_NL;
+                                smem = 32*sizeof(float);
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline;
                             } break;
                         case GGML_TYPE_IQ4_XS:
                             {
-                                nth0 = 4;
-                                nth1 = 16;
+                                nsg = N_SG_IQ4_XS;
+                                nr0 = N_R0_IQ4_XS;
+                                smem = 32*sizeof(float);
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline;
                             } break;
                         default:
@@ -3052,7 +3039,7 @@ static void ggml_metal_encode_node(
                     };
 
                     if (ggml_is_quantized(src0t)) {
-                        GGML_ASSERT(ne00 >= nth0*nth1);
+                        GGML_ASSERT(ne00 >= nsg*nr0);
                     }
 
                     ggml_metal_kargs_mul_mv_id args = {
@@ -3085,43 +3072,12 @@ static void ggml_metal_encode_node(
                     [encoder setBuffer:id_src2 offset:offs_src2 atIndex:4];
 
                     const int64_t _ne1 = 1;
-                    const int tgz = dst_rows;
+                    const int64_t ne123 = dst_rows;
 
-                    if (src0t == GGML_TYPE_Q4_0  || src0t == GGML_TYPE_Q4_1  || src0t == GGML_TYPE_Q5_0 ||
-                            src0t == GGML_TYPE_Q5_1  || src0t == GGML_TYPE_Q8_0  || src0t == GGML_TYPE_Q2_K ||
-                            src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                    }
-                    else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
-                        const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
-                        [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                    }
-                    else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
-                        const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
-                        [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                    }
-                    else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
-                        const int mem_size = 32*sizeof(float);
-                        [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                    }
-                    else if (src0t == GGML_TYPE_Q4_K) {
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                    }
-                    else if (src0t == GGML_TYPE_Q3_K) {
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                    }
-                    else if (src0t == GGML_TYPE_Q5_K) {
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                    }
-                    else if (src0t == GGML_TYPE_Q6_K) {
-                        [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
-                    } else {
-                        const int64_t ny = (_ne1 + nrows - 1)/nrows; // = _ne1
-                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                    if (smem > 0) {
+                        [encoder setThreadgroupMemoryLength:smem atIndex:0];
                     }
+                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0*nsg - 1)/(nr0*nsg), (_ne1 + nr1 - 1)/nr1, ne123) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
                 }
             } break;
         case GGML_OP_GET_ROWS:
index 3cef81b797197a2160c43175654c813d9ac3288c..38f03efbae27355249edf474314c1e3044edcff8 100644 (file)
@@ -1439,7 +1439,7 @@ kernel void kernel_rwkv_wkv7_f32(
 
         float4 sa_vec(0.0);
 
-        for (int j = 0; j < head_size; j += 4) {
+        for (uint j = 0; j < head_size; j += 4) {
             float4 a_vec = float4(_a[j], _a[j+1], _a[j+2], _a[j+3]);
             float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]);
             sa_vec += a_vec * s_vec;
@@ -1853,14 +1853,7 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
     return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m;
 }
 
-// putting them in the kernel cause a significant performance penalty
-#define N_DST 4        // each SIMD group works on 4 rows
-#define N_SIMDGROUP 2  // number of SIMD groups in a thread group
-//Note: This is a template, but strictly speaking it only applies to
-//      quantizations where the block size is 32. It also does not
-//      guard against the number of rows not being divisible by
-//      N_DST, so this is another explicit assumption of the implementation.
-template<typename block_q_type, int nr, int nsg, int nw, typename args_t>
+template<typename block_q_type, int nr0, int nsg, int nw, typename args_t>
 void mul_vec_q_n_f32_impl(
         args_t args,
         device const char * src0,
@@ -1876,7 +1869,7 @@ void mul_vec_q_n_f32_impl(
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int first_row = (r0 * nsg + sgitg) * nr;
+    const int first_row = (r0 * nsg + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -1888,15 +1881,15 @@ void mul_vec_q_n_f32_impl(
     device const float        * y = (device const float        *) (src1 + offset1);
 
     // pointers to src0 rows
-    device const block_q_type * ax[nr];
-    for (int row = 0; row < nr; ++row) {
+    device const block_q_type * ax[nr0];
+    for (int row = 0; row < nr0; ++row) {
         const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
 
         ax[row] = (device const block_q_type *) ((device char *) src0 + offset0);
     }
 
     float yl[16]; // src1 vector cache
-    float sumf[nr] = {0.f};
+    float sumf[nr0] = {0.f};
 
     const short ix = (tiisg/2);
     const short il = (tiisg%2)*8;
@@ -1908,7 +1901,7 @@ void mul_vec_q_n_f32_impl(
         float sumy[2] = { 0.f, 0.f };
 
 #pragma unroll
-        for (int i = 0; i < 8; i += 2) {
+        for (short i = 0; i < 8; i += 2) {
             sumy[0]  += yb[i +  0] + yb[i +  1];
             yl[i + 0] = yb[i +  0];
             yl[i + 1] = yb[i +  1]/256.f;
@@ -1919,7 +1912,7 @@ void mul_vec_q_n_f32_impl(
         }
 
 #pragma unroll
-        for (int row = 0; row < nr; row++) {
+        for (short row = 0; row < nr0; row++) {
             sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy[0] + sumy[1], yl, il);
         }
 
@@ -1928,7 +1921,7 @@ void mul_vec_q_n_f32_impl(
 
     device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
 
-    for (int row = 0; row < nr; ++row) {
+    for (int row = 0; row < nr0; ++row) {
         const float tot = simd_sum(sumf[row]);
 
         if (tiisg == 0 && first_row + row < args.ne01) {
@@ -1945,7 +1938,7 @@ kernel void kernel_mul_mv_q4_0_f32(
         uint3  tgpig[[threadgroup_position_in_grid]],
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0, N_SG_Q4_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
 kernel void kernel_mul_mv_q4_1_f32(
@@ -1956,7 +1949,7 @@ kernel void kernel_mul_mv_q4_1_f32(
         uint3  tgpig[[threadgroup_position_in_grid]],
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-     mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+     mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1, N_SG_Q4_1, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
 kernel void kernel_mul_mv_q5_0_f32(
@@ -1967,7 +1960,7 @@ kernel void kernel_mul_mv_q5_0_f32(
         uint3  tgpig[[threadgroup_position_in_grid]],
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0, N_SG_Q5_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
 kernel void kernel_mul_mv_q5_1_f32(
@@ -1978,12 +1971,12 @@ kernel void kernel_mul_mv_q5_1_f32(
         uint3  tgpig[[threadgroup_position_in_grid]],
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1, N_SG_Q5_1, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
 #define NB_Q8_0 8
 
-template<typename args_t>
+template<int nr0, int nsg, int nw, typename args_t>
 void kernel_mul_mv_q8_0_f32_impl(
         args_t args,
         device const char * src0,
@@ -1993,16 +1986,13 @@ void kernel_mul_mv_q8_0_f32_impl(
         uint3  tgpig,
         ushort tiisg,
         ushort sgitg) {
-    const int nr  = N_DST;
-    const int nsg = N_SIMDGROUP;
-    const int nw  = N_SIMDWIDTH;
-
     const int nb = args.ne00/QK8_0;
+
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int first_row = (r0*nsg + sgitg)*nr;
+    const int first_row = (r0 * nsg + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -2014,15 +2004,15 @@ void kernel_mul_mv_q8_0_f32_impl(
     device const float      * y = (device const float      *) (src1 + offset1);
 
     // pointers to src0 rows
-    device const block_q8_0 * ax[nr];
-    for (int row = 0; row < nr; ++row) {
+    device const block_q8_0 * ax[nr0];
+    for (int row = 0; row < nr0; ++row) {
         const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
 
         ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0);
     }
 
     float yl[NB_Q8_0];
-    float sumf[nr] = { 0.f };
+    float sumf[nr0] = { 0.f };
 
     const short ix = tiisg/4;
     const short il = tiisg%4;
@@ -2035,7 +2025,7 @@ void kernel_mul_mv_q8_0_f32_impl(
             yl[i] = yb[i];
         }
 
-        for (int row = 0; row < nr; row++) {
+        for (short row = 0; row < nr0; row++) {
             device const int8_t * qs = ax[row][ib].qs + il*NB_Q8_0;
             float sumq = 0.f;
             for (short iq = 0; iq < NB_Q8_0; ++iq) {
@@ -2049,7 +2039,7 @@ void kernel_mul_mv_q8_0_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < nr; ++row) {
+    for (int row = 0; row < nr0; ++row) {
         const float tot = simd_sum(sumf[row]);
 
         if (tiisg == 0 && first_row + row < args.ne01) {
@@ -2067,7 +2057,7 @@ kernel void kernel_mul_mv_q8_0_f32(
         uint3  tgpig[[threadgroup_position_in_grid]],
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-    kernel_mul_mv_q8_0_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0, N_SG_Q8_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
 // mat-vec kernel processing in chunks of float4
@@ -2404,9 +2394,9 @@ void kernel_mul_mv_impl(
                 sumf += (T0) x[i] * (T1) y[i];
             }
 
-            float all_sum = simd_sum(sumf);
+            float sum_all = simd_sum(sumf);
             if (tiisg == 0) {
-                dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum;
+                dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all;
             }
         }
     } else {
@@ -2427,10 +2417,10 @@ void kernel_mul_mv_impl(
                 sumf += dot((float4) x4[i], (float4) y4[i]);
             }
 
-            float all_sum = simd_sum(sumf);
+            float sum_all = simd_sum(sumf);
             if (tiisg == 0) {
-                for (int i = 4*(args.ne00/4); i < args.ne00; ++i) all_sum += (float) (x[i] * y[i]);
-                dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum;
+                for (int i = 4*(args.ne00/4); i < args.ne00; ++i) sum_all += (float) (x[i] * y[i]);
+                dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all;
             }
         }
     }
@@ -2492,9 +2482,9 @@ kernel void kernel_mul_mv_1row(
         for (int i = tiisg; i < args.ne00; i += 32) {
             sumf += (float) x[i] * (float) y[i];
         }
-        float all_sum = simd_sum(sumf);
+        float sum_all = simd_sum(sumf);
         if (tiisg == 0) {
-            dst_f32[r0] = all_sum;
+            dst_f32[r0] = sum_all;
         }
     } else {
         device const T4     * x4 = (device const T4     *) x;
@@ -2504,11 +2494,11 @@ kernel void kernel_mul_mv_1row(
             sumf += dot((float4) x4[i], y4[i]);
         }
 
-        float all_sum = simd_sum(sumf);
+        float sum_all = simd_sum(sumf);
 
         if (tiisg == 0) {
-            for (int i = 4*(args.ne00/4); i < args.ne00; ++i) all_sum += (float) (x[i] * y[i]);
-            dst_f32[r0] = all_sum;
+            for (int i = 4*(args.ne00/4); i < args.ne00; ++i) sum_all += (float) (x[i] * y[i]);
+            dst_f32[r0] = sum_all;
         }
     }
 }
@@ -2553,9 +2543,9 @@ kernel void kernel_mul_mv_l4(
             sumf += dot((float4) x4[i], y4[i]);
         }
 
-        float all_sum = simd_sum(sumf);
+        float sum_all = simd_sum(sumf);
         if (tiisg == 0) {
-            dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum;
+            dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all;
         }
     }
 }
@@ -4321,7 +4311,7 @@ kernel void kernel_cpy_f32_iq4_nl(
         float amax = 0.0f; // absolute max
         float max  = 0.0f;
 
-        for (int j = 0; j < QK4_0; j++) {
+        for (int j = 0; j < QK4_NL; j++) {
             const float v = src[j];
             if (amax < fabs(v)) {
                 amax = fabs(v);
@@ -4429,7 +4419,7 @@ kernel void kernel_concat(
     }
 }
 
-template<typename args_t>
+template<int nr0, int nsg, int nw, typename args_t>
 void kernel_mul_mv_q2_K_f32_impl(
         args_t args,
         device const char * src0,
@@ -4445,7 +4435,7 @@ void kernel_mul_mv_q2_K_f32_impl(
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int first_row = (r0 * nsg + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -4457,20 +4447,19 @@ void kernel_mul_mv_q2_K_f32_impl(
     device const float      * y = (device const float      *) (src1 + offset1);
 
     float yl[32];
-    float sumf[N_DST]={0.f}, all_sum;
+    float sumf[nr0]={0.f};
 
-    const int ix = tiisg/8;  // 0...3
-    const int it = tiisg%8;  // 0...7
-    const int iq = it/4;     // 0 or 1
-    const int ir = it%4;     // 0...3
-    const int is = (8*ir)/16;// 0 or 1
+    const short ix = tiisg/8;  // 0...3
+    const short it = tiisg%8;  // 0...7
+    const short iq = it/4;     // 0 or 1
+    const short ir = it%4;     // 0...3
+    const short is = (8*ir)/16;// 0 or 1
 
     device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir;
 
     for (int ib = ix; ib < nb; ib += 4) {
-
         float4 sumy = {0.f, 0.f, 0.f, 0.f};
-        for (int i = 0; i < 8; ++i) {
+        for (short i = 0; i < 8; ++i) {
             yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
             yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8];
             yl[i+16] = y4[i+64]; sumy[2] += yl[i+16];
@@ -4481,7 +4470,7 @@ void kernel_mul_mv_q2_K_f32_impl(
         device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
         device const half     * dh = &x[ib].d;
 
-        for (int row = 0; row < N_DST; row++) {
+        for (short row = 0; row < nr0; row++) {
             float4 acc1 = {0.f, 0.f, 0.f, 0.f};
             float4 acc2 = {0.f, 0.f, 0.f, 0.f};
             for (int i = 0; i < 8; i += 2) {
@@ -4512,10 +4501,10 @@ void kernel_mul_mv_q2_K_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
-        all_sum = simd_sum(sumf[row]);
+    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+        float sum_all = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst_f32[first_row + row] = all_sum;
+            dst_f32[first_row + row] = sum_all;
         }
     }
 }
@@ -4530,10 +4519,10 @@ kernel void kernel_mul_mv_q2_K_f32(
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_q2_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_q2_K_f32_impl<N_R0_Q2_K, N_SG_Q2_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
-template<typename args_t>
+template<int nr0, int nsg, int nw, typename args_t>
 void kernel_mul_mv_q3_K_f32_impl(
         args_t args,
         device const char * src0,
@@ -4550,7 +4539,7 @@ void kernel_mul_mv_q3_K_f32_impl(
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
+    const int first_row = (r0 * nsg + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -4566,13 +4555,12 @@ void kernel_mul_mv_q3_K_f32_impl(
     //const uint16_t kmask1 = 0x3030;
     //const uint16_t kmask2 = 0x0f0f;
 
-    const int tid = tiisg/4;
-    const int ix  = tiisg%4;
-    const int ip  = tid/4;          // 0 or 1
-    const int il  = 2*((tid%4)/2);  // 0 or 2
-    const int ir  = tid%2;
-    const int n   = 8;
-    const int l0  = n*ir;
+    const short tid = tiisg/4;
+    const short ix  = tiisg%4;
+    const short ip  = tid/4;          // 0 or 1
+    const short il  = 2*((tid%4)/2);  // 0 or 2
+    const short ir  = tid%2;
+    const short l0  = 8*ir;
 
     // One would think that the Metal compiler would figure out that ip and il can only have
     // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it
@@ -4597,8 +4585,8 @@ void kernel_mul_mv_q3_K_f32_impl(
     const uint16_t s_shift1 = 4*ip;
     const uint16_t s_shift2 = s_shift1 + il;
 
-    const int q_offset = 32*ip + l0;
-    const int y_offset = 128*ip + 32*il + l0;
+    const short q_offset = 32*ip + l0;
+    const short y_offset = 128*ip + 32*il + l0;
 
     device const float * y1 = yy + ix*QK_K + y_offset;
 
@@ -4606,10 +4594,11 @@ void kernel_mul_mv_q3_K_f32_impl(
     thread uint16_t * scales16 = (thread uint16_t *)&scales32;
     thread const int8_t * scales = (thread const int8_t *)&scales32;
 
-    float sumf1[2] = {0.f};
-    float sumf2[2] = {0.f};
+    float sumf1[nr0] = {0.f};
+    float sumf2[nr0] = {0.f};
+
     for (int i = ix; i < nb; i += 4) {
-        for (int l = 0; l < 8; ++l) {
+        for (short l = 0; l < 8; ++l) {
             yl[l+ 0] = y1[l+ 0];
             yl[l+ 8] = y1[l+16];
             yl[l+16] = y1[l+32];
@@ -4621,7 +4610,7 @@ void kernel_mul_mv_q3_K_f32_impl(
         device const uint16_t * a = (device const uint16_t *)(x[i].scales);
         device const half * dh = &x[i].d;
 
-        for (int row = 0; row < 2; ++row) {
+        for (short row = 0; row < nr0; ++row) {
             const float d_all = (float)dh[0];
 
             scales16[0] = a[4];
@@ -4632,7 +4621,7 @@ void kernel_mul_mv_q3_K_f32_impl(
             scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32;
 
             float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0;
-            for (int l = 0; l < n; l += 2) {
+            for (short l = 0; l < 8; l += 2) {
                 const int32_t qs = q[l/2];
                 s1 += yl[l+0] * (qs & qm[il/2][0]);
                 s2 += yl[l+1] * (qs & qm[il/2][1]);
@@ -4647,7 +4636,7 @@ void kernel_mul_mv_q3_K_f32_impl(
             sumf2[row] += d2 * (scales[2] - 32);
 
             s1 = s2 = s3 = s4 = s5 = s6 = 0;
-            for (int l = 0; l < n; l += 2) {
+            for (short l = 0; l < 8; l += 2) {
                 const int32_t qs = q[l/2+8];
                 s1 += yl[l+8] * (qs & qm[il/2][0]);
                 s2 += yl[l+9] * (qs & qm[il/2][1]);
@@ -4670,7 +4659,7 @@ void kernel_mul_mv_q3_K_f32_impl(
         y1 += 4 * QK_K;
     }
 
-    for (int row = 0; row < 2; ++row) {
+    for (int row = 0; row < nr0; ++row) {
         const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift);
         sumf1[row] = simd_sum(sumf);
     }
@@ -4678,7 +4667,7 @@ void kernel_mul_mv_q3_K_f32_impl(
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
     if (tiisg == 0) {
-        for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
+        for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
             dst_f32[first_row + row] = sumf1[row];
         }
     }
@@ -4694,10 +4683,10 @@ kernel void kernel_mul_mv_q3_K_f32(
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_q3_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_q3_K_f32_impl<N_R0_Q3_K, N_SG_Q3_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
-template<typename args_t>
+template<int nr0, int nsg, int nw, typename args_t>
 void kernel_mul_mv_q4_K_f32_impl(
         args_t args,
         device const char * src0,
@@ -4707,22 +4696,22 @@ void kernel_mul_mv_q4_K_f32_impl(
         uint3  tgpig,
         ushort tiisg,
         ushort sgitg) {
-
     const uint16_t kmask1 = 0x3f3f;
     const uint16_t kmask2 = 0x0f0f;
     const uint16_t kmask3 = 0xc0c0;
 
-    const int ix = tiisg/8;  // 0...3
-    const int it = tiisg%8;  // 0...7
-    const int iq = it/4;     // 0 or 1
-    const int ir = it%4;     // 0...3
+    const short ix = tiisg/8;  // 0...3
+    const short it = tiisg%8;  // 0...7
+    const short iq = it/4;     // 0 or 1
+    const short ir = it%4;     // 0...3
 
     const int nb = args.ne00/QK_K;
+
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int im = tgpig.z;
-    //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
-    const int first_row = r0 * N_DST;
+
+    const int first_row = (r0 * nsg + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -4735,7 +4724,8 @@ void kernel_mul_mv_q4_K_f32_impl(
 
     float yl[16];
     float yh[16];
-    float sumf[N_DST]={0.f}, all_sum;
+
+    float sumf[nr0]={0.f};
 
     device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;
 
@@ -4744,7 +4734,8 @@ void kernel_mul_mv_q4_K_f32_impl(
 
     for (int ib = ix; ib < nb; ib += 4) {
         float4 sumy = {0.f, 0.f, 0.f, 0.f};
-        for (int i = 0; i < 8; ++i) {
+
+        for (short i = 0; i < 8; ++i) {
             yl[i+0] = y4[i+  0]; sumy[0] += yl[i+0];
             yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8];
             yh[i+0] = y4[i+128]; sumy[2] += yh[i+0];
@@ -4755,7 +4746,7 @@ void kernel_mul_mv_q4_K_f32_impl(
         device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
         device const half     * dh = &x[ib].d;
 
-        for (int row = 0; row < N_DST; row++) {
+        for (short row = 0; row < nr0; row++) {
             sc16[0] = sc[0] & kmask1;
             sc16[1] = sc[2] & kmask1;
             sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
@@ -4765,19 +4756,21 @@ void kernel_mul_mv_q4_K_f32_impl(
 
             float4 acc1 = {0.f, 0.f, 0.f, 0.f};
             float4 acc2 = {0.f, 0.f, 0.f, 0.f};
-            for (int i = 0; i < 8; i += 2) {
-                acc1[0] += yl[i+0] * (q1[i/2] & 0x000F);
-                acc1[1] += yl[i+1] * (q1[i/2] & 0x0F00);
-                acc1[2] += yl[i+8] * (q1[i/2] & 0x00F0);
-                acc1[3] += yl[i+9] * (q1[i/2] & 0xF000);
-                acc2[0] += yh[i+0] * (q2[i/2] & 0x000F);
-                acc2[1] += yh[i+1] * (q2[i/2] & 0x0F00);
-                acc2[2] += yh[i+8] * (q2[i/2] & 0x00F0);
-                acc2[3] += yh[i+9] * (q2[i/2] & 0xF000);
+
+            for (short i = 0; i < 4; ++i) {
+                acc1[0] += yl[2*i + 0] * (q1[i] & 0x000F);
+                acc1[1] += yl[2*i + 1] * (q1[i] & 0x0F00);
+                acc1[2] += yl[2*i + 8] * (q1[i] & 0x00F0);
+                acc1[3] += yl[2*i + 9] * (q1[i] & 0xF000);
+                acc2[0] += yh[2*i + 0] * (q2[i] & 0x000F);
+                acc2[1] += yh[2*i + 1] * (q2[i] & 0x0F00);
+                acc2[2] += yh[2*i + 8] * (q2[i] & 0x00F0);
+                acc2[3] += yh[2*i + 9] * (q2[i] & 0xF000);
             }
 
             float dall = dh[0];
             float dmin = dh[1];
+
             sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] +
                                  (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f +
                                  (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] +
@@ -4794,10 +4787,10 @@ void kernel_mul_mv_q4_K_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1 + (int64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
-        all_sum = simd_sum(sumf[row]);
+    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+        float sum_all = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst_f32[first_row + row] = all_sum;
+            dst_f32[first_row + row] = sum_all;
         }
     }
 }
@@ -4812,10 +4805,10 @@ kernel void kernel_mul_mv_q4_K_f32(
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_q4_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_q4_K_f32_impl<N_R0_Q4_K, N_SG_Q4_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
-template<typename args_t>
+template<int nr0, int nsg, int nw, typename args_t>
 void kernel_mul_mv_q5_K_f32_impl(
         args_t args,
         device const char * src0,
@@ -4832,7 +4825,7 @@ void kernel_mul_mv_q5_K_f32_impl(
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
+    const int first_row = (r0 * nsg + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -4843,7 +4836,7 @@ void kernel_mul_mv_q5_K_f32_impl(
     device const block_q5_K * x = (device const block_q5_K *) (src0 + offset0);
     device const float     * yy = (device const float      *) (src1 + offset1);
 
-    float sumf[2]={0.f};
+    float sumf[nr0]={0.f};
 
     float yl[16], yh[16];
 
@@ -4851,15 +4844,14 @@ void kernel_mul_mv_q5_K_f32_impl(
     const uint16_t kmask2 = 0x0f0f;
     const uint16_t kmask3 = 0xc0c0;
 
-    const int tid = tiisg/4;
-    const int ix  = tiisg%4;
-    const int iq  = tid/4;
-    const int ir  = tid%4;
-    const int n   = 8;
+    const short tid = tiisg/4;
+    const short ix  = tiisg%4;
+    const short iq  = tid/4;
+    const short ir  = tid%4;
 
-    const int l0 = n*ir;
-    const int q_offset = 32*iq + l0;
-    const int y_offset = 64*iq + l0;
+    const short l0 = 8*ir;
+    const short q_offset = 32*iq + l0;
+    const short y_offset = 64*iq + l0;
 
     const uint8_t hm1 = 1u << (2*iq);
     const uint8_t hm2 = hm1 << 1;
@@ -4879,14 +4871,14 @@ void kernel_mul_mv_q5_K_f32_impl(
 
         device const float * y2 = y1 + 128;
         float4 sumy = {0.f, 0.f, 0.f, 0.f};
-        for (int l = 0; l < 8; ++l) {
+        for (short l = 0; l < 8; ++l) {
             yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0];
             yl[l+8] = y1[l+32]; sumy[1] += yl[l+8];
             yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0];
             yh[l+8] = y2[l+32]; sumy[3] += yh[l+8];
         }
 
-        for (int row = 0; row < 2; ++row) {
+        for (short row = 0; row < nr0; ++row) {
             device const uint8_t * q2 = q1 + 64;
 
             sc16[0] = a[0] & kmask1;
@@ -4896,7 +4888,7 @@ void kernel_mul_mv_q5_K_f32_impl(
 
             float4 acc1 = {0.f};
             float4 acc2 = {0.f};
-            for (int l = 0; l < n; ++l) {
+            for (short l = 0; l < 8; ++l) {
                 uint8_t h = qh[l];
                 acc1[0] += yl[l+0] * (q1[l] & 0x0F);
                 acc1[1] += yl[l+8] * (q1[l] & 0xF0);
@@ -4926,7 +4918,7 @@ void kernel_mul_mv_q5_K_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
+    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
         const float tot = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = tot;
@@ -4944,10 +4936,10 @@ kernel void kernel_mul_mv_q5_K_f32(
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_q5_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_q5_K_f32_impl<N_R0_Q5_K, N_SG_Q5_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
-template <typename args_t>
+template<int nr0, int nsg, int nw, typename args_t>
 void kernel_mul_mv_q6_K_f32_impl(
         args_t args,
         device const char * src0,
@@ -4969,62 +4961,77 @@ void kernel_mul_mv_q6_K_f32_impl(
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int row = 2*r0 + sgitg;
-
-    if (row >= args.ne0) {
-        return;
-    }
+    const int first_row = (r0 * nsg + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
 
-    const uint64_t offset0 = row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
-    const uint64_t offset1 =  r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
+    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
 
     device const block_q6_K * x = (device const block_q6_K *) (src0 + offset0);
     device const float     * yy = (device const float      *) (src1 + offset1);
 
-    float sumf = 0;
+    float sumf[nr0] = { 0.f };
+
+    float yl[16];
 
-    const int tid  = tiisg/2;
-    const int ix   = tiisg%2;
-    const int ip   = tid/8;         // 0 or 1
-    const int il   = tid%8;
-    const int n    = 4;
-    const int l0   = n*il;
-    const int is   = 8*ip + l0/16;
+    const short tid = tiisg/2;
+    const short ix  = tiisg%2;
+    const short ip  = tid/8;         // 0 or 1
+    const short il  = tid%8;
+    const short l0  = 4*il;
+    const short is  = 8*ip + l0/16;
 
-    const int y_offset = 128*ip + l0;
-    const int q_offset_l = 64*ip + l0;
-    const int q_offset_h = 32*ip + l0;
+    const short y_offset   = 128*ip + l0;
+    const short q_offset_l =  64*ip + l0;
+    const short q_offset_h =  32*ip + l0;
 
     for (int i = ix; i < nb; i += 2) {
         device const uint8_t * q1 = x[i].ql + q_offset_l;
         device const uint8_t * q2 = q1 + 32;
         device const uint8_t * qh = x[i].qh + q_offset_h;
         device const int8_t  * sc = x[i].scales + is;
+        device const half    * dh = &x[i].d;
 
         device const float * y = yy + i * QK_K + y_offset;
 
-        const float dall = x[i].d;
-
-        float4 sums = {0.f, 0.f, 0.f, 0.f};
-        for (int l = 0; l < n; ++l) {
-            sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
-            sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
-            sums[2] += y[l+64] * ((int8_t)((q1[l]  >> 4) | ((qh[l] & kmask3) << 0)) - 32);
-            sums[3] += y[l+96] * ((int8_t)((q2[l]  >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
+        for (short l = 0; l < 4; ++l) {
+            yl[4*l + 0] = y[l +  0];
+            yl[4*l + 1] = y[l + 32];
+            yl[4*l + 2] = y[l + 64];
+            yl[4*l + 3] = y[l + 96];
         }
 
-        sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
+        for (short row = 0; row < nr0; ++row) {
+            const float dall = dh[0];
+
+            float4 sums = {0.f, 0.f, 0.f, 0.f};
 
+            for (short l = 0; l < 4; ++l) {
+                sums[0] += yl[4*l + 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
+                sums[1] += yl[4*l + 1] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
+                sums[2] += yl[4*l + 2] * ((int8_t)((q1[l]  >> 4) | ((qh[l] & kmask3) << 0)) - 32);
+                sums[3] += yl[4*l + 3] * ((int8_t)((q2[l]  >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
+            }
+
+            sumf[row] += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
+
+            q1 += args.nb01;
+            q2 += args.nb01;
+            qh += args.nb01;
+            sc += args.nb01;
+            dh += args.nb01/2;
+        }
     }
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    const float tot = simd_sum(sumf);
-    if (tiisg == 0) {
-        dst_f32[row] = tot;
+    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+        float sum_all = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst_f32[first_row + row] = sum_all;
+        }
     }
 }
 
@@ -5038,12 +5045,12 @@ kernel void kernel_mul_mv_q6_K_f32(
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_q6_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_q6_K_f32_impl<N_R0_Q6_K, N_SG_Q6_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
 // ======================= "True" 2-bit
 
-template<typename args_t>
+template<int nr0, int nsg, int nw, typename args_t>
 void kernel_mul_mv_iq2_xxs_f32_impl(
         args_t args,
         device const char * src0,
@@ -5059,7 +5066,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int first_row = (r0 * nsg + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -5071,7 +5078,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
     device const float         * y = (device const float         *) (src1 + offset1);
 
     float yl[32];
-    float sumf[N_DST]={0.f}, all_sum;
+    float sumf[nr0]={0.f};
 
     const int nb32 = nb * (QK_K / 32);
 
@@ -5092,8 +5099,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
     device const float * y4 = y + 32 * ix;
 
     for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
-
-        for (int i = 0; i < 32; ++i) {
+        for (short i = 0; i < 32; ++i) {
             yl[i] = y4[i];
         }
 
@@ -5104,18 +5110,17 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
         device const uint16_t * q2 = xr->qs + 4 * ib;
         device const half * dh = &xr->d;
 
-        for (int row = 0; row < N_DST; row++) {
-
+        for (short row = 0; row < nr0; row++) {
             const float db = dh[0];
             device const uint8_t * aux8 = (device const uint8_t *)q2;
             const uint32_t aux32 = q2[2] | (q2[3] << 16);
             const float d = db * (0.5f + (aux32 >> 28));
 
             float sum = 0;
-            for (int l = 0; l < 4; ++l) {
+            for (short l = 0; l < 4; ++l) {
                 const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + aux8[l]);
                 const uint8_t signs = ssigns[(aux32 >> 7*l) & 127];
-                for (int j = 0; j < 8; ++j) {
+                for (short j = 0; j < 8; ++j) {
                     sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
                 }
             }
@@ -5130,10 +5135,10 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
-        all_sum = simd_sum(sumf[row]);
+    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+        float sum_all = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst_f32[first_row + row] = all_sum * 0.25f;
+            dst_f32[first_row + row] = sum_all * 0.25f;
         }
     }
 }
@@ -5148,10 +5153,10 @@ kernel void kernel_mul_mv_iq2_xxs_f32(
         uint3  tgpig[[threadgroup_position_in_grid]],
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-    kernel_mul_mv_iq2_xxs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq2_xxs_f32_impl<N_R0_IQ2_XXS, N_SG_IQ2_XXS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
-template<typename args_t>
+template<int nr0, int nsg, int nw, typename args_t>
 void kernel_mul_mv_iq2_xs_f32_impl(
         args_t args,
         device const char * src0,
@@ -5167,7 +5172,7 @@ void kernel_mul_mv_iq2_xs_f32_impl(
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int first_row = (r0 * nsg + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -5179,7 +5184,7 @@ void kernel_mul_mv_iq2_xs_f32_impl(
     device const float        * y = (device const float        *) (src1 + offset1);
 
     float yl[32];
-    float sumf[N_DST]={0.f}, all_sum;
+    float sumf[nr0]={0.f};
 
     const int nb32 = nb * (QK_K / 32);
 
@@ -5200,8 +5205,7 @@ void kernel_mul_mv_iq2_xs_f32_impl(
     device const float * y4 = y + 32 * ix;
 
     for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
-
-        for (int i = 0; i < 32; ++i) {
+        for (short i = 0; i < 32; ++i) {
             yl[i] = y4[i];
         }
 
@@ -5213,8 +5217,7 @@ void kernel_mul_mv_iq2_xs_f32_impl(
         device const uint8_t  * sc = xr->scales + ib;
         device const half * dh = &xr->d;
 
-        for (int row = 0; row < N_DST; row++) {
-
+        for (short row = 0; row < nr0; row++) {
             const float db = dh[0];
             const uint8_t ls1 = sc[0] & 0xf;
             const uint8_t ls2 = sc[0] >>  4;
@@ -5222,17 +5225,17 @@ void kernel_mul_mv_iq2_xs_f32_impl(
             const float d2 = db * (0.5f + ls2);
 
             float sum1 = 0, sum2 = 0;
-            for (int l = 0; l < 2; ++l) {
+            for (short l = 0; l < 2; ++l) {
                 const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511));
                 const uint8_t signs = ssigns[(q2[l] >> 9)];
-                for (int j = 0; j < 8; ++j) {
+                for (short j = 0; j < 8; ++j) {
                     sum1 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
                 }
             }
-            for (int l = 2; l < 4; ++l) {
+            for (short l = 2; l < 4; ++l) {
                 const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511));
                 const uint8_t signs = ssigns[(q2[l] >> 9)];
-                for (int j = 0; j < 8; ++j) {
+                for (short j = 0; j < 8; ++j) {
                     sum2 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
                 }
             }
@@ -5248,10 +5251,10 @@ void kernel_mul_mv_iq2_xs_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
-        all_sum = simd_sum(sumf[row]);
+    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+        float sum_all = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst_f32[first_row + row] = all_sum * 0.25f;
+            dst_f32[first_row + row] = sum_all * 0.25f;
         }
     }
 }
@@ -5267,10 +5270,10 @@ kernel void kernel_mul_mv_iq2_xs_f32(
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_iq2_xs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq2_xs_f32_impl<N_R0_IQ2_XS, N_SG_IQ2_XS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
-template <typename args_t>
+template<int nr0, int nsg, int nw, typename args_t>
 void kernel_mul_mv_iq3_xxs_f32_impl(
         args_t args,
         device const char * src0,
@@ -5286,7 +5289,7 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int first_row = (r0 * nsg + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -5298,7 +5301,7 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
     device const float         * y = (device const float         *) (src1 + offset1);
 
     float yl[32];
-    float sumf[N_DST]={0.f}, all_sum;
+    float sumf[nr0]={0.f};
 
     const int nb32 = nb * (QK_K / 32);
 
@@ -5319,7 +5322,7 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
     device const float * y4 = y + 32 * ix;
 
     for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
-        for (int i = 0; i < 32; ++i) {
+        for (short i = 0; i < 32; ++i) {
             yl[i] = y4[i];
         }
 
@@ -5331,17 +5334,17 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
         device const uint16_t * gas = (device const uint16_t *)(xr->qs + QK_K/4) + 2 * ib;
         device const half * dh = &xr->d;
 
-        for (int row = 0; row < N_DST; row++) {
+        for (short row = 0; row < nr0; row++) {
             const float db = dh[0];
             const uint32_t aux32 = gas[0] | (gas[1] << 16);
             const float d = db * (0.5f + (aux32 >> 28));
 
             float2 sum = {0};
-            for (int l = 0; l < 4; ++l) {
+            for (short l = 0; l < 4; ++l) {
                 const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + q3[2*l+0]);
                 const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + q3[2*l+1]);
                 const uint8_t signs = ssigns[(aux32 >> 7*l) & 127];
-                for (int j = 0; j < 4; ++j) {
+                for (short j = 0; j < 4; ++j) {
                     sum[0] += yl[8*l + j + 0] * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
                     sum[1] += yl[8*l + j + 4] * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
                 }
@@ -5358,10 +5361,10 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
-        all_sum = simd_sum(sumf[row]);
+    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+        float sum_all = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst_f32[first_row + row] = all_sum * 0.5f;
+            dst_f32[first_row + row] = sum_all * 0.5f;
         }
     }
 }
@@ -5377,10 +5380,10 @@ kernel void kernel_mul_mv_iq3_xxs_f32(
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_iq3_xxs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq3_xxs_f32_impl<N_R0_IQ3_XXS, N_SG_IQ3_XXS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
-template<typename args_t>
+template<int nr0, int nsg, int nw, typename args_t>
 void kernel_mul_mv_iq3_s_f32_impl(
         args_t args,
         device const char * src0,
@@ -5396,7 +5399,7 @@ void kernel_mul_mv_iq3_s_f32_impl(
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int first_row = (r0 * nsg + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -5408,7 +5411,7 @@ void kernel_mul_mv_iq3_s_f32_impl(
     device const float       * y = (device const float       *) (src1 + offset1);
 
     float yl[32];
-    float sumf[N_DST]={0.f}, all_sum;
+    float sumf[nr0]={0.f};
 
     const int nb32 = nb * (QK_K / 32);
 
@@ -5425,8 +5428,7 @@ void kernel_mul_mv_iq3_s_f32_impl(
     device const float * y4 = y + 32 * ix;
 
     for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
-
-        for (int i = 0; i < 32; ++i) {
+        for (short i = 0; i < 32; ++i) {
             yl[i] = y4[i];
         }
 
@@ -5440,18 +5442,17 @@ void kernel_mul_mv_iq3_s_f32_impl(
         device const uint8_t * signs = xr->signs + 4 * ib;
         device const half * dh = &xr->d;
 
-        for (int row = 0; row < N_DST; row++) {
-
+        for (short row = 0; row < nr0; row++) {
             const float db = dh[0];
             const float d = db * (1 + 2*((sc[0] >> 4*(ib%2)) & 0xf));
 
             float2 sum = {0};
-            for (int l = 0; l < 4; ++l) {
+            for (short l = 0; l < 4; ++l) {
                 const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? svalues + 256 : svalues;
                 const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? svalues + 256 : svalues;
                 const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(table1 + qs[2*l+0]);
                 const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(table2 + qs[2*l+1]);
-                for (int j = 0; j < 4; ++j) {
+                for (short j = 0; j < 4; ++j) {
                     sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]);
                     sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]);
                 }
@@ -5470,10 +5471,10 @@ void kernel_mul_mv_iq3_s_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
-        all_sum = simd_sum(sumf[row]);
+    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+        float sum_all = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst_f32[first_row + row] = all_sum;
+            dst_f32[first_row + row] = sum_all;
         }
     }
 }
@@ -5489,10 +5490,10 @@ kernel void kernel_mul_mv_iq3_s_f32(
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_iq3_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq3_s_f32_impl<N_R0_IQ3_S, N_SG_IQ3_S, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
-template <typename args_t>
+template<int nr0, int nsg, int nw, typename args_t>
 void kernel_mul_mv_iq2_s_f32_impl(
         args_t args,
         device const char * src0,
@@ -5508,7 +5509,7 @@ void kernel_mul_mv_iq2_s_f32_impl(
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int first_row = (r0 * nsg + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -5520,7 +5521,7 @@ void kernel_mul_mv_iq2_s_f32_impl(
     device const float       * y = (device const float       *) (src1 + offset1);
 
     float yl[32];
-    float sumf[N_DST]={0.f}, all_sum;
+    float sumf[nr0]={0.f};
 
     const int nb32 = nb * (QK_K / 32);
 
@@ -5532,13 +5533,12 @@ void kernel_mul_mv_iq2_s_f32_impl(
     //    threadgroup_barrier(mem_flags::mem_threadgroup);
     //}
 
-    const int ix = tiisg;
+    const short ix = tiisg;
 
     device const float * y4 = y + 32 * ix;
 
     for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
-
-        for (int i = 0; i < 32; ++i) {
+        for (short i = 0; i < 32; ++i) {
             yl[i] = y4[i];
         }
 
@@ -5552,19 +5552,18 @@ void kernel_mul_mv_iq2_s_f32_impl(
         device const uint8_t * signs = qs + QK_K/8;
         device const half * dh = &xr->d;
 
-        for (int row = 0; row < N_DST; row++) {
-
+        for (short row = 0; row < nr0; row++) {
             const float db = dh[0];
             const float d1 = db * (0.5f + (sc[0] & 0xf));
             const float d2 = db * (0.5f + (sc[0] >>  4));
 
             float2 sum = {0};
-            for (int l = 0; l < 2; ++l) {
+            for (short l = 0; l < 2; ++l) {
                 //const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
                 //const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
                 constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
                 constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
-                for (int j = 0; j < 8; ++j) {
+                for (short j = 0; j < 8; ++j) {
                     sum[0] += yl[8*l + j +  0] * grid1[j] * select(1, -1, signs[l+0] & kmask_iq2xs[j]);
                     sum[1] += yl[8*l + j + 16] * grid2[j] * select(1, -1, signs[l+2] & kmask_iq2xs[j]);
                 }
@@ -5583,10 +5582,10 @@ void kernel_mul_mv_iq2_s_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
-        all_sum = simd_sum(sumf[row]);
+    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+        float sum_all = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst_f32[first_row + row] = all_sum * 0.25f;
+            dst_f32[first_row + row] = sum_all * 0.25f;
         }
     }
 }
@@ -5602,10 +5601,10 @@ kernel void kernel_mul_mv_iq2_s_f32(
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_iq2_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq2_s_f32_impl<N_R0_IQ2_S, N_SG_IQ2_S, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
-template<typename args_t>
+template<int nr0, int nsg, int nw, typename args_t>
 void kernel_mul_mv_iq1_s_f32_impl(
         args_t args,
         device const char * src0,
@@ -5621,7 +5620,7 @@ void kernel_mul_mv_iq1_s_f32_impl(
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int first_row = (r0 * nsg + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -5633,18 +5632,17 @@ void kernel_mul_mv_iq1_s_f32_impl(
     device const float       * y = (device const float       *) (src1 + offset1);
 
     float yl[32];
-    float sumf[N_DST]={0.f}, all_sum;
+    float sumf[nr0]={0.f};
 
     const int nb32 = nb * (QK_K / 32);
 
-    const int ix = tiisg;
+    const short ix = tiisg;
 
     device const float * y4 = y + 32 * ix;
 
     for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
-
         float sumy = 0;
-        for (int i = 0; i < 32; ++i) {
+        for (short i = 0; i < 32; ++i) {
             yl[i] = y4[i];
             sumy += yl[i];
         }
@@ -5657,15 +5655,14 @@ void kernel_mul_mv_iq1_s_f32_impl(
         device const uint16_t * qh = xr->qh + ib;
         device const half     * dh = &xr->d;
 
-        for (int row = 0; row < N_DST; row++) {
-
+        for (short row = 0; row < nr0; row++) {
             constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
             constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 5) & 0x700)));
             constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[0] << 2) & 0x700)));
             constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[0] >> 1) & 0x700)));
 
             float sum = 0;
-            for (int j = 0; j < 4; ++j) {
+            for (short j = 0; j < 4; ++j) {
                 sum += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)
                      + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4)
                      + yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)
@@ -5683,15 +5680,28 @@ void kernel_mul_mv_iq1_s_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
-        all_sum = simd_sum(sumf[row]);
+    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+        float sum_all = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst_f32[first_row + row] = all_sum;
+            dst_f32[first_row + row] = sum_all;
         }
     }
 }
 
-template <typename args_t>
+[[host_name("kernel_mul_mv_iq1_s_f32")]]
+kernel void kernel_mul_mv_iq1_s_f32(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_iq1_s_f32_impl<N_R0_IQ1_S, N_SG_IQ1_S, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+}
+
+template<int nr0, int nsg, int nw, typename args_t>
 void kernel_mul_mv_iq1_m_f32_impl(
         args_t args,
         device const char * src0,
@@ -5703,11 +5713,12 @@ void kernel_mul_mv_iq1_m_f32_impl(
         ushort sgitg) {
 
     const int nb = args.ne00/QK_K;
+
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+    const int first_row = (r0 * nsg + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -5719,20 +5730,19 @@ void kernel_mul_mv_iq1_m_f32_impl(
     device const float       * y = (device const float       *) (src1 + offset1);
 
     float yl[32];
-    float sumf[N_DST]={0.f}, all_sum;
+    float sumf[nr0]={0.f};
 
     const int nb32 = nb * (QK_K / 32);
 
-    const int ix = tiisg;
+    const short ix = tiisg;
 
     device const float * y4 = y + 32 * ix;
 
     iq1m_scale_t scale;
 
     for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
-
         float4 sumy = {0.f};
-        for (int i = 0; i < 8; ++i) {
+        for (short i = 0; i < 8; ++i) {
             yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
             yl[i+ 8] = y4[i+ 8]; sumy[1] += yl[i+ 8];
             yl[i+16] = y4[i+16]; sumy[2] += yl[i+16];
@@ -5747,7 +5757,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
         device const uint8_t  * qh = xr->qh + 2 * ib;
         device const uint16_t * sc = (device const uint16_t *)xr->scales;
 
-        for (int row = 0; row < N_DST; row++) {
+        for (short row = 0; row < nr0; row++) {
             scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
 
             constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
@@ -5756,7 +5766,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
             constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[1] << 4) & 0x700)));
 
             float2 sum = {0.f};
-            for (int j = 0; j < 4; ++j) {
+            for (short j = 0; j < 4; ++j) {
                 sum[0] += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)
                         + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4);
                 sum[1] += yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)
@@ -5778,15 +5788,28 @@ void kernel_mul_mv_iq1_m_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
-        all_sum = simd_sum(sumf[row]);
+    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+        float sum_all = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst_f32[first_row + row] = all_sum;
+            dst_f32[first_row + row] = sum_all;
         }
     }
 }
 
-template<typename args_t>
+[[host_name("kernel_mul_mv_iq1_m_f32")]]
+kernel void kernel_mul_mv_iq1_m_f32(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_iq1_m_f32_impl<N_R0_IQ1_M, N_SG_IQ1_M, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
+}
+
+template<int nr0, int nsg, int nw, typename args_t>
 void kernel_mul_mv_iq4_nl_f32_impl(
         args_t args,
         device const char * src0,
@@ -5799,10 +5822,12 @@ void kernel_mul_mv_iq4_nl_f32_impl(
 
     threadgroup float * shmem_f32 = (threadgroup float *) shmem;
     const int nb = args.ne00/QK4_NL;
+
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int im = tgpig.z;
-    const int first_row = (r0 * 2 + sgitg) * 2;
+
+    const int first_row = (r0 * nsg + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -5813,14 +5838,14 @@ void kernel_mul_mv_iq4_nl_f32_impl(
     device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0);
     device const float        * y = (device const float        *) (src1 + offset1);
 
-    const int ix = tiisg/2;  // 0...15
-    const int it = tiisg%2;  // 0 or 1
+    const short ix = tiisg/2;  // 0...15
+    const short it = tiisg%2;  // 0 or 1
 
     shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16];
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
     float4 yl[4];
-    float sumf[2]={0.f}, all_sum;
+    float sumf[nr0]={0.f};
 
     device const float * yb = y + ix * QK4_NL + it * 8;
 
@@ -5830,12 +5855,13 @@ void kernel_mul_mv_iq4_nl_f32_impl(
     float4 qf1, qf2;
 
     for (int ib = ix; ib < nb; ib += 16) {
-
         device const float4 * y4 = (device const float4 *)yb;
-        yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
-
-        for (int row = 0; row < 2 && first_row + row < args.ne01; ++row) {
+        yl[0] = y4[0];
+        yl[1] = y4[4];
+        yl[2] = y4[1];
+        yl[3] = y4[5];
 
+        for (short row = 0; row < nr0; row++) {
             device const block_iq4_nl & xb = x[row*nb + ib];
             device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it);
 
@@ -5860,7 +5886,6 @@ void kernel_mul_mv_iq4_nl_f32_impl(
             acc1 += acc2;
 
             sumf[row] += (float)xb.d * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
-
         }
 
         yb += 16 * QK4_NL;
@@ -5868,15 +5893,29 @@ void kernel_mul_mv_iq4_nl_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
-        all_sum = simd_sum(sumf[row]);
+    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+        float sum_all = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst_f32[first_row + row] = all_sum;
+            dst_f32[first_row + row] = sum_all;
         }
     }
 }
 
-template<typename args_t>
+[[host_name("kernel_mul_mv_iq4_nl_f32")]]
+kernel void kernel_mul_mv_iq4_nl_f32(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem [[threadgroup(0)]],
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_iq4_nl_f32_impl<N_R0_IQ4_NL, N_SG_IQ4_NL, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+}
+
+template<int nr0, int nsg, int nw, typename args_t>
 void kernel_mul_mv_iq4_xs_f32_impl(
         args_t args,
         device const char * src0,
@@ -5892,7 +5931,7 @@ void kernel_mul_mv_iq4_xs_f32_impl(
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int im = tgpig.z;
-    const int first_row = (r0 * 2 + sgitg) * 2;
+    const int first_row = (r0 * nsg + sgitg) * nr0;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
@@ -5903,16 +5942,16 @@ void kernel_mul_mv_iq4_xs_f32_impl(
     device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0);
     device const float        * y = (device const float        *) (src1 + offset1);
 
-    const int ix = tiisg/16;  // 0 or 1
-    const int it = tiisg%16;  // 0...15
-    const int ib = it/2;
-    const int il = it%2;
+    const short ix = tiisg/16;  // 0 or 1
+    const short it = tiisg%16;  // 0...15
+    const short ib = it/2;
+    const short il = it%2;
 
     shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16];
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
     float4 yl[4];
-    float sumf[2]={0.f}, all_sum;
+    float sumf[nr0]={0.f};
 
     device const float * yb = y + ix * QK_K + ib * 32 + il * 8;
 
@@ -5923,9 +5962,12 @@ void kernel_mul_mv_iq4_xs_f32_impl(
 
     for (int ibl = ix; ibl < nb; ibl += 2) {
         device const float4 * y4 = (device const float4 *)yb;
-        yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
+        yl[0] = y4[0];
+        yl[1] = y4[4];
+        yl[2] = y4[1];
+        yl[3] = y4[5];
 
-        for (int row = 0; row < 2; ++row) {
+        for (short row = 0; row < nr0; ++row) {
             device const block_iq4_xs & xb = x[row*nb + ibl];
             device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il);
 
@@ -5949,7 +5991,6 @@ void kernel_mul_mv_iq4_xs_f32_impl(
 
             const int ls = (((xb.scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((xb.scales_h >> 2*ib) & 3) << 4)) - 32;
             sumf[row] += (float)xb.d * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
-
         }
 
         yb += 2 * QK_K;
@@ -5957,54 +5998,14 @@ void kernel_mul_mv_iq4_xs_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
-        all_sum = simd_sum(sumf[row]);
+    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+        float sum_all = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst_f32[first_row + row] = all_sum;
+            dst_f32[first_row + row] = sum_all;
         }
     }
 }
 
-[[host_name("kernel_mul_mv_iq1_s_f32")]]
-kernel void kernel_mul_mv_iq1_s_f32(
-        constant ggml_metal_kargs_mul_mv & args,
-        device const char * src0,
-        device const char * src1,
-        device       char * dst,
-        uint3  tgpig[[threadgroup_position_in_grid]],
-        ushort tiisg[[thread_index_in_simdgroup]],
-        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-
-    kernel_mul_mv_iq1_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
-}
-
-[[host_name("kernel_mul_mv_iq1_m_f32")]]
-kernel void kernel_mul_mv_iq1_m_f32(
-        constant ggml_metal_kargs_mul_mv & args,
-        device const char * src0,
-        device const char * src1,
-        device       char * dst,
-        uint3  tgpig[[threadgroup_position_in_grid]],
-        ushort tiisg[[thread_index_in_simdgroup]],
-        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-
-    kernel_mul_mv_iq1_m_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
-}
-
-[[host_name("kernel_mul_mv_iq4_nl_f32")]]
-kernel void kernel_mul_mv_iq4_nl_f32(
-        constant ggml_metal_kargs_mul_mv & args,
-        device const char * src0,
-        device const char * src1,
-        device       char * dst,
-        threadgroup  char * shmem [[threadgroup(0)]],
-        uint3  tgpig[[threadgroup_position_in_grid]],
-        ushort tiisg[[thread_index_in_simdgroup]],
-        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-
-    kernel_mul_mv_iq4_nl_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
-}
-
 [[host_name("kernel_mul_mv_iq4_xs_f32")]]
 kernel void kernel_mul_mv_iq4_xs_f32(
         constant ggml_metal_kargs_mul_mv & args,
@@ -6016,7 +6017,7 @@ kernel void kernel_mul_mv_iq4_xs_f32(
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_iq4_xs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq4_xs_f32_impl<N_R0_IQ4_XS, N_SG_IQ4_XS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
@@ -6660,25 +6661,27 @@ template [[host_name("kernel_mul_mv_id_f16_f32")]]     kernel kernel_mul_mv_id_t
 #if defined(GGML_METAL_USE_BF16)
 template [[host_name("kernel_mul_mv_id_bf16_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<bfloat, bfloat4, float, float4>>>;
 #endif
-template [[host_name("kernel_mul_mv_id_q8_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_q4_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_q4_1_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_q5_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_q5_1_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
-template [[host_name("kernel_mul_mv_id_q2_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_q3_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_q4_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_q5_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q5_K_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_q6_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q6_K_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_iq1_s_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_s_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_iq1_m_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_m_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xxs_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xs_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_xxs_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_iq3_s_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_s_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_iq2_s_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_q8_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0, N_SG_Q8_0, N_SIMDWIDTH>>>;
+
+template [[host_name("kernel_mul_mv_id_q4_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0, N_SG_Q4_0, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q4_1_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1, N_SG_Q4_1, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q5_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0, N_SG_Q5_0, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q5_1_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1, N_SG_Q5_1, N_SIMDWIDTH>>>;
+
+template [[host_name("kernel_mul_mv_id_q2_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl   <N_R0_Q2_K,    N_SG_Q2_K,    N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q3_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl   <N_R0_Q3_K,    N_SG_Q3_K,    N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q4_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl   <N_R0_Q4_K,    N_SG_Q4_K,    N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q5_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q5_K_f32_impl   <N_R0_Q5_K,    N_SG_Q5_K,    N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_q6_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q6_K_f32_impl   <N_R0_Q6_K,    N_SG_Q6_K,    N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_iq1_s_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_s_f32_impl  <N_R0_IQ1_S,   N_SG_IQ1_S,   N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_iq1_m_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_m_f32_impl  <N_R0_IQ1_M,   N_SG_IQ1_M,   N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xxs_f32_impl<N_R0_IQ2_XXS, N_SG_IQ2_XXS, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xs_f32_impl <N_R0_IQ2_XS,  N_SG_IQ2_XS,  N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_xxs_f32_impl<N_R0_IQ3_XXS, N_SG_IQ3_XXS, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_iq3_s_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_s_f32_impl  <N_R0_IQ3_S,   N_SG_IQ3_S,   N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_iq2_s_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl  <N_R0_IQ2_S,   N_SG_IQ2_S,   N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl <N_R0_IQ4_NL,  N_SG_IQ4_NL,  N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl <N_R0_IQ4_XS,  N_SG_IQ4_XS,  N_SIMDWIDTH>>>;
 
 kernel void kernel_pool_2d_max_f32(
         device  const float * src0,