]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
metal : various optimizations + refactoring (#16446)
authorGeorgi Gerganov <redacted>
Tue, 7 Oct 2025 05:21:40 +0000 (08:21 +0300)
committerGitHub <redacted>
Tue, 7 Oct 2025 05:21:40 +0000 (08:21 +0300)
* metal : ssm_scan minor opts

* metal : get_rows optimize

* metal : cpy optimize

* metal : ssm_conv opt

* metal : ssm_scan simplify

* metal : ssm_Scan opt

ggml/src/ggml-metal/ggml-metal-device.cpp
ggml/src/ggml-metal/ggml-metal-device.m
ggml/src/ggml-metal/ggml-metal-impl.h
ggml/src/ggml-metal/ggml-metal-ops.cpp
ggml/src/ggml-metal/ggml-metal.metal

index 819f31c8a300cb0cc2c61789dab6c273817c20a6..d9e92044275ae07ec84371f4f685783973b38f20 100644 (file)
@@ -338,7 +338,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_librar
     char base[256];
     char name[256];
 
-    snprintf(base, 256, "kernel_ssm_conv_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type));
+    const char * suffix = "";
+
+    if (op->src[1]->ne[0] % 4 == 0) {
+        suffix = "_4";
+    }
+
+    snprintf(base, 256, "kernel_ssm_conv_%s_%s%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type), suffix);
     snprintf(name, 256, "%s", base);
 
     ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
@@ -352,15 +358,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_librar
 }
 
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_library_t lib, const ggml_tensor * op)  {
+    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+
     char base[256];
     char name[256];
 
-    if (op->src[3]->ne[0] == 1) {
-        snprintf(base, 256, "kernel_ssm_scan_group_%s", ggml_type_name(op->src[0]->type));
-    } else {
-        snprintf(base, 256, "kernel_ssm_scan_%s", ggml_type_name(op->src[0]->type));
-    }
-    snprintf(name, 256, "%s", base);
+    const int nsg = (ne00 + 31)/32;
+
+    snprintf(base, 256, "kernel_ssm_scan_%s", ggml_type_name(op->src[0]->type));
+    snprintf(name, 256, "%s_nsg=%d", base, nsg);
 
     ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
     if (res) {
@@ -369,7 +375,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_librar
 
     res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
 
-    ggml_metal_pipeline_set_smem(res, 32*sizeof(float));
+    ggml_metal_pipeline_set_smem(res, 32*sizeof(float)*nsg);
 
     return res;
 }
index 523f9d71ba14ea96f89ae67b8df5541213b85822..95279730152455a9a2839fbc9b83dea681281194 100644 (file)
@@ -776,9 +776,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
                 };
             }
         case GGML_OP_GET_ROWS:
-            {
-                return op->ne[3] == 1;
-            }
+            return true;
         case GGML_OP_SET_ROWS:
             {
                 if (op->src[0]->type != GGML_TYPE_F32) {
index 88c98423ebec0a9fc06001b0b5d923e04b3b32ed..908e2e1cf5ccb48d6d4306a8ae265acce4bf72cb 100644 (file)
@@ -178,6 +178,7 @@ typedef struct {
 } ggml_metal_kargs_clamp;
 
 typedef struct {
+    int64_t  nk0;
     int64_t  ne00;
     int64_t  ne01;
     int64_t  ne02;
@@ -572,32 +573,45 @@ typedef struct {
     int64_t  n_seq_tokens;
     int64_t  n_seqs;
     uint64_t s_off;
+    uint64_t nb00;
     uint64_t nb01;
     uint64_t nb02;
     uint64_t nb03;
+    uint64_t nb10;
     uint64_t nb11;
     uint64_t nb12;
+    uint64_t ns12;
     uint64_t nb13;
+    uint64_t nb20;
     uint64_t nb21;
+    uint64_t ns21;
     uint64_t nb22;
+    int64_t  ne30;
     uint64_t nb31;
     uint64_t nb41;
     uint64_t nb42;
+    uint64_t ns42;
     uint64_t nb43;
     uint64_t nb51;
     uint64_t nb52;
+    uint64_t ns52;
     uint64_t nb53;
+    uint64_t nb0;
 } ggml_metal_kargs_ssm_scan;
 
 typedef struct {
-    int64_t  ne00;
+    int32_t  ne00t;
+    int32_t  ne00;
     uint64_t nb01;
     uint64_t nb02;
-    int64_t  ne10;
+    uint64_t nb03;
+    int32_t  ne10;
     uint64_t nb10;
     uint64_t nb11;
+    uint64_t nb12;
     uint64_t nb1;
     uint64_t nb2;
+    uint64_t nb3;
 } ggml_metal_kargs_get_rows;
 
 typedef struct {
index e85a223c01dc38cde843fc56b5a03e970d8c69f0..7497d7c1da09839bbc3724a6cf22555ad4b1a4cf 100644 (file)
@@ -577,6 +577,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
         ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
 
         ggml_metal_kargs_cpy args = {
+            /*.nk0  =*/ ne00,
             /*.ne00 =*/ ne00,
             /*.ne01 =*/ ne01,
             /*.ne02 =*/ ne02,
@@ -906,23 +907,31 @@ int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) {
     ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type);
 
     ggml_metal_kargs_get_rows args = {
-        /*.ne00 =*/ ne00,
-        /*.nb01 =*/ nb01,
-        /*.nb02 =*/ nb02,
-        /*.ne10 =*/ ne10,
-        /*.nb10 =*/ nb10,
-        /*.nb11 =*/ nb11,
-        /*.nb1  =*/ nb1,
-        /*.nb2  =*/ nb2,
+        /*.ne00t =*/ ggml_is_quantized(op->src[0]->type) ? ne00/16 : ne00,
+        /*.ne00  =*/ ne00,
+        /*.nb01  =*/ nb01,
+        /*.nb02  =*/ nb02,
+        /*.nb03  =*/ nb03,
+        /*.ne10  =*/ ne10,
+        /*.nb10  =*/ nb10,
+        /*.nb11  =*/ nb11,
+        /*.nb12  =*/ nb12,
+        /*.nb1   =*/ nb1,
+        /*.nb2   =*/ nb2,
+        /*.nb3   =*/ nb3,
     };
 
+    const int nth = std::min(args.ne00t, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+
+    const int nw0 = (args.ne00t + nth - 1)/nth;
+
     ggml_metal_encoder_set_pipeline(enc, pipeline);
     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         3);
 
-    ggml_metal_encoder_dispatch_threadgroups(enc, ne10, ne11, ne12, 32, 1, 1);
+    ggml_metal_encoder_dispatch_threadgroups(enc, nw0*ne10, ne11, ne12, nth, 1, 1);
 
     return 1;
 }
@@ -1117,7 +1126,7 @@ int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) {
     ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
     ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
     ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
-    ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);
+    ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op),         3);
 
     ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1);
 
@@ -1172,25 +1181,36 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
         /*.n_seq_tokens =*/ n_seq_tokens,
         /*.n_seqs       =*/ n_seqs,
         /*.s_off        =*/ ggml_nelements(op->src[1]) * sizeof(float),
+        /*.nb00         =*/ nb00,
         /*.nb01         =*/ nb01,
         /*.nb02         =*/ nb02,
         /*.nb03         =*/ nb03,
+        /*.nb10         =*/ nb10,
         /*.nb11         =*/ nb11,
         /*.nb12         =*/ nb12,
+        /*.ns12         =*/ nb12/nb10,
         /*.nb13         =*/ nb13,
+        /*.nb20         =*/ nb20,
         /*.nb21         =*/ nb21,
+        /*.ns21         =*/ nb21/nb20,
         /*.nb22         =*/ nb22,
+        /*.ne30         =*/ ne30,
         /*.nb31         =*/ nb31,
         /*.nb41         =*/ nb41,
         /*.nb42         =*/ nb42,
+        /*.ns42         =*/ nb42/nb40,
         /*.nb43         =*/ nb43,
         /*.nb51         =*/ nb51,
         /*.nb52         =*/ nb52,
+        /*.ns52         =*/ nb52/nb50,
         /*.nb53         =*/ nb53,
+        /*.nb0          =*/ nb0,
     };
 
     ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_ssm_scan(lib, op);
 
+    GGML_ASSERT(d_state <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+
     const size_t sms = ggml_metal_pipeline_get_smem(pipeline);
 
     ggml_metal_encoder_set_pipeline(enc, pipeline);
@@ -1206,13 +1226,7 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
 
     ggml_metal_encoder_set_threadgroup_memory_size(enc, sms, 0);
 
-    if (ne30 == 1) {
-        // Mamba-2
-        ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1);
-    } else {
-        GGML_ASSERT(d_inner == 1);
-        ggml_metal_encoder_dispatch_threadgroups(enc, n_head, n_seqs, 1, d_state, 1, 1);
-    }
+    ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1);
 
     return 1;
 }
@@ -1273,26 +1287,23 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
 
     GGML_ASSERT(ne00 % ggml_blck_size(op->src[0]->type) == 0);
 
-    // TODO: support
-    //const int32_t nk00 = ne00/ggml_blck_size(op->type);
-    const int32_t nk00 = ne00;
-
-    int nth = 32; // SIMD width
-
-    while (nth < nk00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
-        nth *= 2;
+    int64_t nk0 = ne00;
+    if (ggml_is_quantized(op->src[0]->type)) {
+        nk0 = ne00/16;
+    } else if (ggml_is_quantized(op->type)) {
+        nk0 = ne00/ggml_blck_size(op->type);
     }
 
-    nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+    int nth = std::min<int>(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
 
     // when rows are small, we can batch them together in a single threadgroup
     int nrptg = 1;
 
     // TODO: relax this constraint in the future
     if (ggml_blck_size(op->src[0]->type) == 1 && ggml_blck_size(op->type) == 1) {
-        if (nth > nk00) {
-            nrptg = (nth + nk00 - 1)/nk00;
-            nth   = nk00;
+        if (nth > nk0) {
+            nrptg = (nth + nk0 - 1)/nk0;
+            nth   = nk0;
 
             if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
                 nrptg--;
@@ -1300,10 +1311,11 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
         }
     }
 
-    nth = std::min(nth, nk00);
+    nth = std::min<int>(nth, nk0);
 
     ggml_metal_kargs_cpy args = {
-        /*.ne00 =*/ nk00,
+        /*.nk0  =*/ nk0,
+        /*.ne00 =*/ ne00,
         /*.ne01 =*/ ne01,
         /*.ne02 =*/ ne02,
         /*.ne03 =*/ ne03,
@@ -1321,12 +1333,14 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
         /*.nb3  =*/ nb3,
     };
 
+    const int nw0 = nrptg == 1 ? (nk0 + nth - 1)/nth : 1;
+
     ggml_metal_encoder_set_pipeline(enc, pipeline);
     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         2);
 
-    ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, nrptg, 1);
+    ggml_metal_encoder_dispatch_threadgroups(enc, nw0*(ne01 + nrptg - 1)/nrptg, ne02, ne03, nth, nrptg, 1);
 
     return 1;
 }
index 96df6f0ce62de18fd0e3c66aad0487a8a39d9725..f454ceada37b373b55085d01e5115735a4dff84c 100644 (file)
@@ -2032,7 +2032,38 @@ kernel void kernel_ssm_conv_f32_f32(
     x[0] = sumf;
 }
 
-// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-1 part
+kernel void kernel_ssm_conv_f32_f32_4(
+        constant ggml_metal_kargs_ssm_conv & args,
+        device const  void * src0,
+        device const  void * src1,
+        device       float * dst,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t ir = tgpig.x;
+    const int64_t i2 = tgpig.y;
+    const int64_t i3 = tgpig.z;
+
+    const int64_t nc  = args.ne10;
+  //const int64_t ncs = args.ne00;
+  //const int64_t nr  = args.ne01;
+  //const int64_t n_t = args.ne1;
+  //const int64_t n_s = args.ne2;
+
+    device const float4 * s = (device const float4 *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);
+    device const float4 * c = (device const float4 *) ((device const char *) src1 + ir*args.nb11);
+    device       float  * x = (device       float  *) ((device       char *) dst  + ir*args.nb0  + i2*args.nb1  + i3*args.nb2);
+
+    float sumf = 0.0f;
+
+    for (int64_t i0 = 0; i0 < nc/4; ++i0) {
+        sumf += dot(s[i0], c[i0]);
+    }
+
+    x[0] = sumf;
+}
+
+// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
 kernel void kernel_ssm_scan_f32(
         constant ggml_metal_kargs_ssm_scan & args,
         device const void * src0,
@@ -2044,219 +2075,88 @@ kernel void kernel_ssm_scan_f32(
         device const void * src6,
         device      float * dst,
         threadgroup float * shared [[threadgroup(0)]],
-        uint3  tgpig[[threadgroup_position_in_grid]],
-        uint3  tpitg[[thread_position_in_threadgroup]],
-        ushort sgitg[[simdgroup_index_in_threadgroup]],
-        ushort tiisg[[thread_index_in_simdgroup]],
-        ushort sgptg[[simdgroups_per_threadgroup]],
-        uint3   tgpg[[threadgroups_per_grid]]) {
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort  sgitg[[simdgroup_index_in_threadgroup]],
+        ushort  tiisg[[thread_index_in_simdgroup]],
+        ushort  sgptg[[simdgroups_per_threadgroup]],
+        uint3    tgpg[[threadgroups_per_grid]]) {
+    constexpr short NW = N_SIMDWIDTH;
 
-    const int64_t i0 = tpitg.x;
-    const int64_t i1 = 0;
-    const int64_t ir = tgpig.x; // current head
-    const int64_t i3 = tgpig.y; // current seq
+    shared[tpitg.x] = 0.0f;
 
-    const uint64_t nb00 = sizeof(float);
-    const uint64_t nb10 = sizeof(float);
-    const uint64_t nb20 = sizeof(float);
+    const int32_t i0 = tpitg.x;
+    const int32_t i1 = tgpig.x;
+    const int32_t ir = tgpig.y; // current head
+    const int32_t i3 = tgpig.z; // current seq
 
-    const int64_t nc  = args.d_state;
-    const int64_t nr  = args.d_inner;
-    const int64_t nh  = args.n_head;
-    const int64_t ng  = args.n_group;
-    const int64_t n_t = args.n_seq_tokens;
+    const int32_t nc  = args.d_state;
+    const int32_t nr  = args.d_inner;
+    const int32_t nh  = args.n_head;
+    const int32_t ng  = args.n_group;
+    const int32_t n_t = args.n_seq_tokens;
 
-    const int64_t s_off = args.s_off;
+    const int32_t s_off = args.s_off;
 
     device const int32_t * ids = (device const int32_t *) src6;
 
     device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
     device       float * s_buff  = (device       float *) ((device       char *) dst  + ir*args.nb02 +      i3*args.nb03 + s_off);
-    const int64_t i = i0 + i1*nc;
-    const int64_t g = ir / (nh / ng); // repeat_interleave
+
+    const int32_t i = i0 + i1*nc;
+    const int32_t g = ir / (nh / ng); // repeat_interleave
+
     float s0 = s0_buff[i];
-    float s  = s_buff[i];
-
-        device const float * A        = (device const float *) ((device const char *) src3 + ir*args.nb31);
-        device const float * x_block  = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
-        device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22);
-        device const float * B_block  = (device const float *) ((device const char *) src4 + g*args.nb41 + i3*args.nb43);
-        device const float * C_block  = (device const float *) ((device const char *) src5 + g*args.nb51 + i3*args.nb53);
-        device       float * y_block  = (device       float *) ((device       char *) dst  + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00);
-
-    for (int64_t i2 = 0; i2 < n_t; ++i2) {
-        device const float * x  = (device const float *) ((device const char *) x_block + i2*args.nb12);    // {dim, nh, nt, ns}
-        device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21);   // {nh, nt, ns}
-        device const float * B  = (device const float *) ((device const char *) B_block + i2*args.nb42);    // {d_state, ng, nt, ns}
-        device const float * C  = (device const float *) ((device const char *) C_block + i2*args.nb52);    // {d_state, ng, nt, ns}
-        device       float * y  = (device       float *) ((device       char *) y_block + i2*(nh*nr*nb00)); // {dim, nh, nt, ns}
-
-        const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
-        const float x_dt = x[0] * dt_soft_plus;
-
-        const float state = (s0 * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt);
-        s = state;
-
-        // Parallel sum: This relies on the fact that this kernel will be
-        // dispatched with each threadgroup having (d_state, 1, 1) threads which
-        // are subdivided into SIMD groups of size `sgptg`. The goal is to
-        // compute y = sum({state * C[i] for i in range(d_state)}).
-        // To parallelize this effectively, we first use simd_sum over each SIMD
-        // group to compute the sum of each SIMD group, then place the result in
-        // the SIMD group's indexed bucket in the shared memory. We then sum
-        // over the individual group sums to compute the final sum.
-
-        // Computed for each thread
-        float sumf = state * C[i0];
-
-        // Sum the threads in the simd group => simd sum
-        sumf = simd_sum(sumf);
-
-        if (sgptg > 1) {
-
-            // Once per simd group, place the group sum into the shared buffer
-            if (tiisg == 0) {
-                shared[sgitg] = sumf;
-            }
+    float s  = 0.0f;
 
-            // Wait for all threads in the threadgroup to reach this point. This
-            // ensures that all elements of the shared buffer are populated with the
-            // sum of the individual simd groups.
-            threadgroup_barrier(mem_flags::mem_threadgroup);
+    device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {ne30, nh}
 
-            // For simd group 0 at indices < num simd groups, extract the shared
-            // simd sum
-            sumf = 0.0f;
-            if (sgitg == 0) {
-                if (tiisg < sgptg) {
-                    sumf = shared[tiisg];
-                }
-                sumf = simd_sum(sumf);
-                if (tiisg == 0) {
-                    y[0] = sumf;
-                }
-            }
-        } else if (tiisg == 0) {
-            y[0] = sumf;
-        }
+    const float A0 = A[i0%args.ne30];
 
-        // recurse
-        s0 = s;
-    }
+    device const float * x  = (device const float *)((device const char *) src1 + i1*args.nb10  + ir*args.nb11 + i3*args.nb13); // {dim, nh, nt, ns}
+    device const float * dt = (device const float *)((device const char *) src2 + ir*args.nb20  + i3*args.nb22);                // {nh, nt, ns}
+    device const float * B  = (device const float *)((device const char *) src4 +  g*args.nb41  + i3*args.nb43);                // {d_state, ng, nt, ns}
+    device const float * C  = (device const float *)((device const char *) src5 +  g*args.nb51  + i3*args.nb53);                // {d_state, ng, nt, ns}
 
-    // Assign the final state to the output buffer
-    s_buff[i] = s;
-}
+    device float * y = dst + (i1 + ir*(nr) + i3*(n_t*nh*nr)); // {dim, nh, nt, ns}
 
-// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
-kernel void kernel_ssm_scan_group_f32(
-        constant ggml_metal_kargs_ssm_scan & args,
-        device const void * src0,
-        device const void * src1,
-        device const void * src2,
-        device const void * src3,
-        device const void * src4,
-        device const void * src5,
-        device const void * src6,
-        device      float * dst,
-        threadgroup float * shared [[threadgroup(0)]],
-        uint3  tgpig[[threadgroup_position_in_grid]],
-        uint3  tpitg[[thread_position_in_threadgroup]],
-        ushort sgitg[[simdgroup_index_in_threadgroup]],
-        ushort tiisg[[thread_index_in_simdgroup]],
-        ushort sgptg[[simdgroups_per_threadgroup]],
-        uint3   tgpg[[threadgroups_per_grid]]) {
+    for (int i2 = 0; i2 < n_t; i2 += sgptg) {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
 
-    const int64_t i0 = tpitg.x;
-    const int64_t i1 = tgpig.x;
-    const int64_t ir = tgpig.y; // current head
-    const int64_t i3 = tgpig.z; // current seq
+        for (int t = 0; t < sgptg && i2 + t < n_t; t++) {
+            const float dt0  = dt[0];
+            const float dtsp = dt0 <= 20.0f ? log(1.0f + exp(dt0)) : dt0;
+            const float x_dt = x[0] * dtsp;
+            const float dA   = exp(dtsp * A0);
 
-    const uint64_t nb00 = sizeof(float);
-    const uint64_t nb10 = sizeof(float);
-    const uint64_t nb20 = sizeof(float);
+            s = (s0 * dA) + (B[i0] * x_dt);
 
-    const int64_t nc  = args.d_state;
-    const int64_t nr  = args.d_inner;
-    const int64_t nh  = args.n_head;
-    const int64_t ng  = args.n_group;
-    const int64_t n_t = args.n_seq_tokens;
+            const float sumf = simd_sum(s * C[i0]);
 
-    const int64_t s_off = args.s_off;
+            if (tiisg == 0) {
+                shared[t*NW + sgitg] = sumf;
+            }
 
-    device const int32_t * ids = (device const int32_t *) src6;
+            // recurse
+            s0 = s;
 
-    device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
-    device       float * s_buff  = (device       float *) ((device       char *) dst  + ir*args.nb02 +      i3*args.nb03 + s_off);
-    const int64_t i = i0 + i1*nc;
-    const int64_t g = ir / (nh / ng); // repeat_interleave
-    float s0 = s0_buff[i];
-    float s  = s_buff[i];
-
-    device const float * A        = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
-    device const float * x_block  = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
-    device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22);
-    device const float * B_block  = (device const float *) ((device const char *) src4 + g*args.nb41 + i3*args.nb43);
-    device const float * C_block  = (device const float *) ((device const char *) src5 + g*args.nb51 + i3*args.nb53);
-    device       float * y_block  = (device       float *) ((device       char *) dst  + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00);
-
-    for (int64_t i2 = 0; i2 < n_t; ++i2) {
-        device const float * x  = (device const float *) ((device const char *) x_block  + i2*args.nb12);    // {dim, nh, nt, ns}
-        device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21);    // {nh, nt, ns}
-        device const float * B  = (device const float *) ((device const char *) B_block  + i2*args.nb42);    // {d_state, ng, nt, ns}
-        device const float * C  = (device const float *) ((device const char *) C_block  + i2*args.nb52);    // {d_state, ng, nt, ns}
-        device       float * y  = (device       float *) ((device       char *) y_block  + i2*(nh*nr*nb00)); // {dim, nh, nt, ns}
-
-        const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
-        const float x_dt = x[0] * dt_soft_plus;
-        const float dA = exp(dt_soft_plus * A[0]);
-
-        const float state = (s0 * dA) + (B[i0] * x_dt);
-        s = state;
-
-        // Parallel sum: This relies on the fact that this kernel will be
-        // dispatched with each threadgroup having (d_state, 1, 1) threads which
-        // are subdivided into SIMD groups of size `sgptg`. The goal is to
-        // compute y = sum({state * C[i] for i in range(d_state)}).
-        // To parallelize this effectively, we first use simd_sum over each SIMD
-        // group to compute the sum of each SIMD group, then place the result in
-        // the SIMD group's indexed bucket in the shared memory. We then sum
-        // over the individual group sums to compute the final sum.
-
-        // Computed for each thread
-        float sumf = state * C[i0];
-
-        // Sum the threads in the simd group => simd sum
-        sumf = simd_sum(sumf);
-
-        // Once per simd group, place the group sum into the shared buffer
-        if (tiisg == 0) {
-            shared[sgitg] = sumf;
+            x  += args.ns12;
+            dt += args.ns21;
+            B  += args.ns42;
+            C  += args.ns52;
         }
 
-        // Wait for all threads in the threadgroup to reach this point. This
-        // ensures that all elements of the shared buffer are populated with the
-        // sum of the individual simd groups.
         threadgroup_barrier(mem_flags::mem_threadgroup);
 
-        // For simd group 0 at indices < num simd groups, extract the shared
-        // simd sum
-        sumf = 0.0f;
-        if (sgitg == 0) {
-            if (tiisg < sgptg) {
-                sumf = shared[tiisg];
-            }
-            sumf = simd_sum(sumf);
-            if (tiisg == 0) {
-                y[0] = sumf;
-            }
+        const float sumf = simd_sum(shared[sgitg*NW + tiisg]);
+
+        if (tiisg == 0 && i2 + sgitg < n_t) {
+            y[sgitg*nh*nr] = sumf;
         }
 
-        // recurse
-        s0 = s;
+        y += sgptg*nh*nr;
     }
 
-    // Assign the final state to the output buffer
     s_buff[i] = s;
 }
 
@@ -5770,21 +5670,17 @@ kernel void kernel_flash_attn_ext_vec_reduce(
 }
 
 template<typename T0, typename T1>
-kernel void kernel_cpy(
+kernel void kernel_cpy_t_t(
         constant ggml_metal_kargs_cpy & args,
         device  const char * src0,
         device        char * dst,
         uint3   tgpig[[threadgroup_position_in_grid]],
-        uint    tiitg[[thread_index_in_threadgroup]],
-        ushort3 tpitg[[thread_position_in_threadgroup]],
-        ushort3  tptg[[threads_per_threadgroup]]) {
+        ushort  tiitg[[thread_index_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
     const int i03 = tgpig[2];
     const int i02 = tgpig[1];
-    const int i01 = tgpig[0]*tptg.y + tiitg/tptg.x;
-
-    if (i01 >= args.ne01) {
-        return;
-    }
+    const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0];
+    const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
 
     const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
 
@@ -5795,190 +5691,70 @@ kernel void kernel_cpy(
 
     device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
 
-    for (int64_t i00 = tiitg%tptg.x; i00 < args.ne00; i00 += tptg.x) {
+    for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.ne00; ) {
         device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
         dst_data[i00] = (T1) src[0];
+        break;
     }
 }
 
-typedef decltype(kernel_cpy<float, float>) kernel_cpy_t;
+typedef decltype(kernel_cpy_t_t<float, float>) kernel_cpy_t;
 
-template [[host_name("kernel_cpy_f32_f32")]]   kernel kernel_cpy_t kernel_cpy<float,  float>;
-template [[host_name("kernel_cpy_f32_f16")]]   kernel kernel_cpy_t kernel_cpy<float,  half>;
-template [[host_name("kernel_cpy_f32_i32")]]   kernel kernel_cpy_t kernel_cpy<float,  int32_t>;
-template [[host_name("kernel_cpy_i32_f32")]]   kernel kernel_cpy_t kernel_cpy<int32_t, float>;
+template [[host_name("kernel_cpy_f32_f32")]]   kernel kernel_cpy_t kernel_cpy_t_t<float,   float>;
+template [[host_name("kernel_cpy_f32_f16")]]   kernel kernel_cpy_t kernel_cpy_t_t<float,   half>;
+template [[host_name("kernel_cpy_f32_i32")]]   kernel kernel_cpy_t kernel_cpy_t_t<float,   int32_t>;
+template [[host_name("kernel_cpy_i32_f32")]]   kernel kernel_cpy_t kernel_cpy_t_t<int32_t, float>;
 #if defined(GGML_METAL_HAS_BF16)
-template [[host_name("kernel_cpy_f32_bf16")]]  kernel kernel_cpy_t kernel_cpy<float,  bfloat>;
+template [[host_name("kernel_cpy_f32_bf16")]]  kernel kernel_cpy_t kernel_cpy_t_t<float,   bfloat>;
 #endif
-template [[host_name("kernel_cpy_f16_f32")]]   kernel kernel_cpy_t kernel_cpy<half,   float>;
-template [[host_name("kernel_cpy_f16_f16")]]   kernel kernel_cpy_t kernel_cpy<half,   half>;
+template [[host_name("kernel_cpy_f16_f32")]]   kernel kernel_cpy_t kernel_cpy_t_t<half,    float>;
+template [[host_name("kernel_cpy_f16_f16")]]   kernel kernel_cpy_t kernel_cpy_t_t<half,    half>;
 #if defined(GGML_METAL_HAS_BF16)
-template [[host_name("kernel_cpy_bf16_f32")]]  kernel kernel_cpy_t kernel_cpy<bfloat, float>;
-template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy<bfloat, bfloat>;
+template [[host_name("kernel_cpy_bf16_f32")]]  kernel kernel_cpy_t kernel_cpy_t_t<bfloat,  float>;
+template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy_t_t<bfloat,  bfloat>;
 #endif
 
-// TODO: templetify these kernels
-kernel void kernel_cpy_f32_q8_0(
+template<short QK,
+         typename block_q,
+         void (*quantize_func)(device const float *, device block_q &)>
+kernel void kernel_cpy_f32_q(
         constant ggml_metal_kargs_cpy & args,
         device const char * src0,
-        device       char * dst,
+        device char * dst,
         uint3   tgpig[[threadgroup_position_in_grid]],
-        ushort3 tpitg[[thread_position_in_threadgroup]],
-        ushort3   ntg[[threads_per_threadgroup]]) {
-    const int i03 = tgpig[2];
-    const int i02 = tgpig[1];
-    const int i01 = tgpig[0];
-
-    const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
-
-    const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
-    const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
-    const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
-    const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK8_0;
-
-    device block_q8_0 * dst_data = (device block_q8_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
-
-    for (int64_t i00 = tpitg.x*QK8_0; i00 < args.ne00; i00 += ntg.x*QK8_0) {
-        device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
-
-        quantize_q8_0(src, dst_data[i00/QK8_0]);
-    }
-}
-
-kernel void kernel_cpy_f32_q4_0(
-        constant ggml_metal_kargs_cpy & args,
-        device const char * src0,
-        device       char * dst,
-        uint3   tgpig[[threadgroup_position_in_grid]],
-        ushort3 tpitg[[thread_position_in_threadgroup]],
-        ushort3   ntg[[threads_per_threadgroup]]) {
-    const int i03 = tgpig[2];
-    const int i02 = tgpig[1];
-    const int i01 = tgpig[0];
-
-    const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
-
-    const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
-    const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
-    const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
-    const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_0;
-
-    device block_q4_0 * dst_data = (device block_q4_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
-
-    for (int64_t i00 = tpitg.x*QK4_0; i00 < args.ne00; i00 += ntg.x*QK4_0) {
-        device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
-
-        quantize_q4_0(src, dst_data[i00/QK4_0]);
-    }
-}
-
-kernel void kernel_cpy_f32_q4_1(
-        constant ggml_metal_kargs_cpy & args,
-        device const char * src0,
-        device       char * dst,
-        uint3   tgpig[[threadgroup_position_in_grid]],
-        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort  tiitg[[thread_index_in_threadgroup]],
         ushort3   ntg[[threads_per_threadgroup]]) {
     const int i03 = tgpig[2];
     const int i02 = tgpig[1];
-    const int i01 = tgpig[0];
+    const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0];
+    const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
 
     const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
 
     const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
     const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
     const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
-    const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_1;
-
-    device block_q4_1 * dst_data = (device block_q4_1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
-
-    for (int64_t i00 = tpitg.x*QK4_1; i00 < args.ne00; i00 += ntg.x*QK4_1) {
-        device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
-
-        quantize_q4_1(src, dst_data[i00/QK4_1]);
-    }
-}
-
-kernel void kernel_cpy_f32_q5_0(
-        constant ggml_metal_kargs_cpy & args,
-        device const char * src0,
-        device       char * dst,
-        uint3   tgpig[[threadgroup_position_in_grid]],
-        ushort3 tpitg[[thread_position_in_threadgroup]],
-        ushort3   ntg[[threads_per_threadgroup]]) {
-    const int i03 = tgpig[2];
-    const int i02 = tgpig[1];
-    const int i01 = tgpig[0];
-
-    const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
+    const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK;
 
-    const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
-    const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
-    const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
-    const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK5_0;
+    device block_q * dst_data = (device block_q *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
 
-    device block_q5_0 * dst_data = (device block_q5_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
+    for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) {
+        device const float * src = (device const float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + (i00*QK)*args.nb00);
 
-    for (int64_t i00 = tpitg.x*QK5_0; i00 < args.ne00; i00 += ntg.x*QK5_0) {
-        device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
+        quantize_func(src, dst_data[i00]);
 
-        quantize_q5_0(src, dst_data[i00/QK5_0]);
+        break;
     }
 }
 
-kernel void kernel_cpy_f32_q5_1(
-        constant ggml_metal_kargs_cpy & args,
-        device const char * src0,
-        device       char * dst,
-        uint3   tgpig[[threadgroup_position_in_grid]],
-        ushort3 tpitg[[thread_position_in_threadgroup]],
-        ushort3   ntg[[threads_per_threadgroup]]) {
-    const int i03 = tgpig[2];
-    const int i02 = tgpig[1];
-    const int i01 = tgpig[0];
+typedef decltype(kernel_cpy_f32_q<QK8_0,  block_q8_0,  quantize_q8_0>)  cpy_f_q_t;
 
-    const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
-
-    const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
-    const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
-    const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
-    const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK5_1;
-
-    device block_q5_1 * dst_data = (device block_q5_1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
-
-    for (int64_t i00 = tpitg.x*QK5_1; i00 < args.ne00; i00 += ntg.x*QK5_1) {
-        device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
-
-        quantize_q5_1(src, dst_data[i00/QK5_1]);
-    }
-}
-
-kernel void kernel_cpy_f32_iq4_nl(
-        constant ggml_metal_kargs_cpy & args,
-        device const char * src0,
-        device       char * dst,
-        uint3   tgpig[[threadgroup_position_in_grid]],
-        ushort3 tpitg[[thread_position_in_threadgroup]],
-        ushort3   ntg[[threads_per_threadgroup]]) {
-    const int i03 = tgpig[2];
-    const int i02 = tgpig[1];
-    const int i01 = tgpig[0];
-
-    const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
-
-    const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
-    const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
-    const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
-    const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_NL;
-
-    device block_iq4_nl * dst_data = (device block_iq4_nl *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
-
-    for (int64_t i00 = tpitg.x*QK4_NL; i00 < args.ne00; i00 += ntg.x*QK4_NL) {
-        device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
-
-        quantize_iq4_nl(src, dst_data[i00/QK4_NL]);
-    }
-}
+template [[host_name("kernel_cpy_f32_q8_0")]]   kernel cpy_f_q_t kernel_cpy_f32_q<QK8_0,  block_q8_0,   quantize_q8_0>;
+template [[host_name("kernel_cpy_f32_q4_0")]]   kernel cpy_f_q_t kernel_cpy_f32_q<QK4_0,  block_q4_0,   quantize_q4_0>;
+template [[host_name("kernel_cpy_f32_q4_1")]]   kernel cpy_f_q_t kernel_cpy_f32_q<QK4_1,  block_q4_1,   quantize_q4_1>;
+template [[host_name("kernel_cpy_f32_q5_0")]]   kernel cpy_f_q_t kernel_cpy_f32_q<QK5_0,  block_q5_0,   quantize_q5_0>;
+template [[host_name("kernel_cpy_f32_q5_1")]]   kernel cpy_f_q_t kernel_cpy_f32_q<QK5_1,  block_q5_1,   quantize_q5_1>;
+template [[host_name("kernel_cpy_f32_iq4_nl")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_NL, block_iq4_nl, quantize_iq4_nl>;
 
 template<typename T4x4, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
 kernel void kernel_cpy_q_f32(
@@ -5986,11 +5762,12 @@ kernel void kernel_cpy_q_f32(
         device  const char * src0,
         device        char * dst,
         uint3   tgpig[[threadgroup_position_in_grid]],
-        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort  tiitg[[thread_index_in_threadgroup]],
         ushort3   ntg[[threads_per_threadgroup]]) {
     const int i03 = tgpig[2];
     const int i02 = tgpig[1];
-    const int i01 = tgpig[0];
+    const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0];
+    const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
 
     const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
 
@@ -6002,10 +5779,12 @@ kernel void kernel_cpy_q_f32(
     device const block_q * src_data = (device const block_q *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
     device       T4x4    * dst_data = (device       T4x4    *)(dst  +  i3*args.nb3  +  i2*args.nb2  +  i1*args.nb1 + i0*args.nb0);
 
-    for (int64_t i00 = tpitg.x; i00 < args.ne00/16; i00 += ntg.x) {
+    for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) {
         T4x4 temp;
         dequantize_func(src_data + i00/nl, i00%nl, temp);
         dst_data[i00] = temp;
+
+        break;
     }
 }
 
@@ -7765,66 +7544,60 @@ kernel void kernel_mul_mv_mxfp4_f32(
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
 kernel void kernel_get_rows_q(
         constant ggml_metal_kargs_get_rows & args,
-        device const  void * src0,
-        device const  void * src1,
-        device       float * dst,
-        uint3                tgpig[[threadgroup_position_in_grid]],
-        uint                 tiitg[[thread_index_in_threadgroup]],
-        uint3                tptg [[threads_per_threadgroup]]) {
-    const int64_t i10 = tgpig.x;
-    const int64_t i11 = tgpig.y;
+        device const void * src0,
+        device const void * src1,
+        device       void * dst,
+        uint3               tgpig[[threadgroup_position_in_grid]],
+        ushort              tiitg[[thread_index_in_threadgroup]],
+        ushort3             ntg  [[threads_per_threadgroup]]) {
+    const int32_t iw0 = tgpig.x/args.ne10;
+    const int32_t i10 = tgpig.x%args.ne10;
+    const int32_t i11 = tgpig.y;
+    const int32_t i12 = tgpig.z;
+
+    const int32_t r = ((const device int32_t *) ((const device char *) src1 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10))[0];
 
-    const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0];
+    const int32_t i02 = i11;
+    const int32_t i03 = i12;
 
-    const int64_t i02 = i11;
+    auto psrc = (device const block_q *) ((const device char *) src0 + i03*args.nb03 + i02*args.nb02 +   r*args.nb01);
+    auto pdst = (device      float4x4 *) ((      device char *) dst  + i12*args.nb3  + i11*args.nb2  + i10*args.nb1);
 
-    for (int64_t ind = tiitg; ind < args.ne00/16; ind += tptg.x) {
+    for (int ind = iw0*ntg.x + tiitg; ind < args.ne00t;) {
         float4x4 temp;
-        dequantize_func(((device const block_q *) ((const device char *) src0 + r*args.nb01 + i02*args.nb02)) + ind/nl, ind%nl, temp);
-        *(((device float4x4 *) ((device char *) dst + i11*args.nb2 + i10*args.nb1)) + ind) = temp;
+        dequantize_func(psrc + ind/nl, ind%nl, temp);
+        pdst[ind] = temp;
+
+        break;
     }
 }
 
-template<typename T>
+template<typename T0, typename T>
 kernel void kernel_get_rows_f(
         constant ggml_metal_kargs_get_rows & args,
-        device const  void * src0,
-        device const  void * src1,
-        device       float * dst,
-        uint3                tgpig[[threadgroup_position_in_grid]],
-        uint                 tiitg[[thread_index_in_threadgroup]],
-        uint3                tptg [[threads_per_threadgroup]]) {
-    const int64_t i10 = tgpig.x;
-    const int64_t i11 = tgpig.y;
-
-    const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0];
-
-    const int64_t i02 = i11;
+        device const void * src0,
+        device const void * src1,
+        device       void * dst,
+        uint3               tgpig[[threadgroup_position_in_grid]],
+        ushort              tiitg[[thread_index_in_threadgroup]],
+        ushort3             ntg [[threads_per_threadgroup]]) {
+    const int32_t iw0 = tgpig.x/args.ne10;
+    const int32_t i10 = tgpig.x%args.ne10;
+    const int32_t i11 = tgpig.y;
+    const int32_t i12 = tgpig.z;
 
-    for (int ind = tiitg; ind < args.ne00; ind += tptg.x) {
-        ((      device float *) ((      device char *)  dst + i11*args.nb2  + i10*args.nb1))[ind] =
-        ((const device T     *) ((const device char *) src0 + i02*args.nb02 +  r*args.nb01))[ind];
-    }
-}
+    const int32_t r = ((const device int32_t *) ((const device char *) src1 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10))[0];
 
-kernel void kernel_get_rows_i32(
-        constant ggml_metal_kargs_get_rows & args,
-        device const  void * src0,
-        device const  void * src1,
-        device     int32_t * dst,
-        uint3                tgpig[[threadgroup_position_in_grid]],
-        uint                 tiitg[[thread_index_in_threadgroup]],
-        uint3                tptg [[threads_per_threadgroup]]) {
-    const int64_t i10 = tgpig.x;
-    const int64_t i11 = tgpig.y;
+    const int32_t i02 = i11;
+    const int32_t i03 = i12;
 
-    const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0];
+    auto psrc = (const device T0 *) ((const device char *) src0 + i03*args.nb03 + i02*args.nb02 +   r*args.nb01);
+    auto pdst = (      device T  *) ((      device char *)  dst + i12*args.nb3  + i11*args.nb2  + i10*args.nb1);
 
-    const int64_t i02 = i11;
+    for (int ind = iw0*ntg.x + tiitg; ind < args.ne00t;) {
+        pdst[ind] = psrc[ind];
 
-    for (int ind = tiitg; ind < args.ne00; ind += tptg.x) {
-        ((      device int32_t *) ((      device char *) dst  + i11*args.nb2 + i10*args.nb1))[ind] =
-        ((const device int32_t *) ((const device char *) src0 + i02*args.nb02 + r*args.nb01))[ind];
+        break;
     }
 }
 
@@ -8310,12 +8083,13 @@ kernel void kernel_mul_mm_id(
 // get rows
 //
 
-typedef decltype(kernel_get_rows_f<float>) get_rows_f_t;
+typedef decltype(kernel_get_rows_f<float, float>) get_rows_f_t;
 
-template [[host_name("kernel_get_rows_f32")]]  kernel get_rows_f_t kernel_get_rows_f<float>;
-template [[host_name("kernel_get_rows_f16")]]  kernel get_rows_f_t kernel_get_rows_f<half>;
+template [[host_name("kernel_get_rows_f32")]]  kernel get_rows_f_t kernel_get_rows_f<float, float>;
+template [[host_name("kernel_get_rows_f16")]]  kernel get_rows_f_t kernel_get_rows_f<half,  float>;
+template [[host_name("kernel_get_rows_i32")]]  kernel get_rows_f_t kernel_get_rows_f<int32_t, int32_t>;
 #if defined(GGML_METAL_HAS_BF16)
-template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_rows_f<bfloat>;
+template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_rows_f<bfloat, float>;
 #endif
 
 typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t;