]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
metal : refactor kernel args into structs (llama/10238)
authorGeorgi Gerganov <redacted>
Sun, 17 Nov 2024 09:23:01 +0000 (11:23 +0200)
committerGeorgi Gerganov <redacted>
Wed, 20 Nov 2024 19:00:08 +0000 (21:00 +0200)
* metal : add kernel arg structs (wip)

* metal : fattn args

ggml-ci

* metal : cont + avoid potential int overflow [no ci]

* metal : mul mat struct (wip)

* cont : mul mat vec

* cont : pass by reference

* cont : args is first argument

* cont : use char ptr

* cont : shmem style

* cont : thread counters style

* cont : mul mm id

ggml-ci

* cont : int safety + register optimizations

ggml-ci

* metal : GGML_OP_CONCAT

ggml-ci

* metal : GGML_OP_ADD, GGML_OP_SUB, GGML_OP_MUL, GGML_OP_DIV

* metal : GGML_OP_REPEAT

* metal : GGML_OP_CPY

* metal : GGML_OP_RMS_NORM

* metal : GGML_OP_NORM

* metal : add TODOs for rest of ops

* ggml : add ggml-metal-impl.h

ggml-ci

ggml/src/ggml-metal/CMakeLists.txt
ggml/src/ggml-metal/ggml-metal-impl.h [new file with mode: 0644]
ggml/src/ggml-metal/ggml-metal.m
ggml/src/ggml-metal/ggml-metal.metal

index e0992c7449b6252a3d2a7f1e9cbe47be1033847d..b237d79f47ddb635cfb1b569039bb590bdfe99f8 100644 (file)
@@ -25,9 +25,10 @@ if (GGML_METAL_USE_BF16)
     add_compile_definitions(GGML_METAL_USE_BF16)
 endif()
 
-# copy ggml-common.h and ggml-metal.metal to bin directory
-configure_file(../ggml-common.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h    COPYONLY)
-configure_file(ggml-metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY)
+# copy metal files to bin directory
+configure_file(../ggml-common.h  ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h     COPYONLY)
+configure_file(ggml-metal.metal  ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal  COPYONLY)
+configure_file(ggml-metal-impl.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal-impl.h COPYONLY)
 
 if (GGML_METAL_EMBED_LIBRARY)
     enable_language(ASM)
@@ -36,24 +37,27 @@ if (GGML_METAL_EMBED_LIBRARY)
 
     set(METALLIB_COMMON "${CMAKE_CURRENT_SOURCE_DIR}/../ggml-common.h")
     set(METALLIB_SOURCE "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal.metal")
+    set(METALLIB_IMPL   "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal-impl.h")
 
     file(MAKE_DIRECTORY "${CMAKE_BINARY_DIR}/autogenerated")
 
     # merge ggml-common.h and ggml-metal.metal into a single file
-    set(METALLIB_EMBED_ASM    "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.s")
-    set(METALLIB_SOURCE_EMBED "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.metal")
+    set(METALLIB_EMBED_ASM        "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.s")
+    set(METALLIB_SOURCE_EMBED     "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.metal")
+    set(METALLIB_SOURCE_EMBED_TMP "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.metal.tmp")
 
     add_custom_command(
         OUTPUT ${METALLIB_EMBED_ASM}
         COMMAND echo "Embedding Metal library"
-        COMMAND sed -e '/__embed_ggml-common.h__/r ${METALLIB_COMMON}' -e '/__embed_ggml-common.h__/d' < ${METALLIB_SOURCE} > ${METALLIB_SOURCE_EMBED}
+        COMMAND sed -e '/__embed_ggml-common.h__/r         ${METALLIB_COMMON}' -e '/__embed_ggml-common.h__/d'         < ${METALLIB_SOURCE}           > ${METALLIB_SOURCE_EMBED_TMP}
+        COMMAND sed -e '/\#include \"ggml-metal-impl.h\"/r ${METALLIB_IMPL}'   -e '/\#include \"ggml-metal-impl.h\"/d' < ${METALLIB_SOURCE_EMBED_TMP} > ${METALLIB_SOURCE_EMBED}
         COMMAND echo ".section __DATA,__ggml_metallib"          >  ${METALLIB_EMBED_ASM}
         COMMAND echo ".globl _ggml_metallib_start"              >> ${METALLIB_EMBED_ASM}
         COMMAND echo "_ggml_metallib_start:"                    >> ${METALLIB_EMBED_ASM}
         COMMAND echo ".incbin \\\"${METALLIB_SOURCE_EMBED}\\\"" >> ${METALLIB_EMBED_ASM}
         COMMAND echo ".globl _ggml_metallib_end"                >> ${METALLIB_EMBED_ASM}
         COMMAND echo "_ggml_metallib_end:"                      >> ${METALLIB_EMBED_ASM}
-        DEPENDS ggml-metal.metal ../ggml-common.h
+        DEPENDS ../ggml-common.h ggml-metal.metal ggml-metal-impl.h
         COMMENT "Generate assembly for embedded Metal library"
     )
 
diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h
new file mode 100644 (file)
index 0000000..53c1354
--- /dev/null
@@ -0,0 +1,249 @@
+#ifndef GGML_METAL_IMPL
+#define GGML_METAL_IMPL
+
+// kernel argument structs
+//
+// - element counters (e.g. ne00) typically use int32_t to reduce register usage
+//   however, be careful from int overflows when using those in the kernel implementation
+//
+// - strides (e.g. nb00) use uint64_t
+
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne01;
+    int32_t  ne02;
+    int32_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne10;
+    int32_t  ne11;
+    int32_t  ne12;
+    int32_t  ne13;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    int32_t  ne0;
+    int32_t  ne1;
+    int32_t  ne2;
+    int32_t  ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+    int32_t  dim;
+} ggml_metal_kargs_concat;
+
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne01;
+    int32_t  ne02;
+    int32_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne10;
+    int32_t  ne11;
+    int32_t  ne12;
+    int32_t  ne13;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    int32_t  ne0;
+    int32_t  ne1;
+    int32_t  ne2;
+    int32_t  ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+    uint64_t offs;
+} ggml_metal_kargs_bin;
+
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne01;
+    int32_t  ne02;
+    int32_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne0;
+    int32_t  ne1;
+    int32_t  ne2;
+    int32_t  ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+} ggml_metal_kargs_repeat;
+
+typedef struct {
+    int64_t  ne00;
+    int64_t  ne01;
+    int64_t  ne02;
+    int64_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int64_t  ne0;
+    int64_t  ne1;
+    int64_t  ne2;
+    int64_t  ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+} ggml_metal_kargs_cpy;
+
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne01;
+    int32_t  ne02;
+    int32_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne0;
+    int32_t  ne1;
+    int32_t  ne2;
+    int32_t  ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+    int32_t  n_past;
+    int32_t  n_dims;
+    int32_t  n_ctx_orig;
+    float    freq_base;
+    float    freq_scale;
+    float    ext_factor;
+    float    attn_factor;
+    float    beta_fast;
+    float    beta_slow;
+} ggml_metal_kargs_rope;
+
+typedef struct {
+    int32_t  ne01;
+    int32_t  ne02;
+    int32_t  ne03;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne11;
+    int32_t  ne_12_2; // assume K and V are same shape
+    int32_t  ne_12_3;
+    uint64_t nb_12_1;
+    uint64_t nb_12_2;
+    uint64_t nb_12_3;
+    uint64_t nb31;
+    int32_t  ne1;
+    int32_t  ne2;
+    float    scale;
+    float    max_bias;
+    float    m0;
+    float    m1;
+    uint16_t n_head_log2;
+    float    logit_softcap;
+} ggml_metal_kargs_flash_attn_ext;
+
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne02;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne12;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    int32_t  ne0;
+    int32_t  ne1;
+    int16_t  r2;
+    int16_t  r3;
+} ggml_metal_kargs_mul_mm;
+
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne01;
+    int32_t  ne02;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne10;
+    int32_t  ne11;
+    int32_t  ne12;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    int32_t  ne0;
+    int32_t  ne1;
+    int16_t  r2;
+    int16_t  r3;
+} ggml_metal_kargs_mul_mv;
+
+typedef struct {
+    int32_t  nei0;
+    int32_t  nei1;
+    uint64_t nbi1;
+    int32_t  ne00;
+    int32_t  ne02;
+    uint64_t nb01;
+    uint64_t nb02;
+    int32_t  ne11;
+    int32_t  ne12;
+    int32_t  ne13;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    int32_t  ne0;
+    int32_t  ne1;
+} ggml_metal_kargs_mul_mm_id;
+
+typedef struct {
+    int32_t  nei0;
+    int32_t  nei1;
+    uint64_t nbi1;
+    int32_t  ne00;
+    int32_t  ne01;
+    int32_t  ne02;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    int32_t  ne10;
+    int32_t  ne11;
+    int32_t  ne12;
+    int32_t  ne13;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    int32_t  ne0;
+    int32_t  ne1;
+    uint64_t nb1;
+} ggml_metal_kargs_mul_mv_id;
+
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne00_4;
+    uint64_t nb01;
+    float    eps;
+} ggml_metal_kargs_norm;
+
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne00_4;
+    uint64_t nb01;
+    float    eps;
+} ggml_metal_kargs_rms_norm;
+
+#endif // GGML_METAL_IMPL
index 18aa1e7b9088a5f24d4af48737f6eedbfe19ec0c..cc0726fb7c1c72def9aec3a05b30516cea00faec 100644 (file)
@@ -2,6 +2,7 @@
 
 #import "ggml-impl.h"
 #import "ggml-backend-impl.h"
+#import "ggml-metal-impl.h"
 
 #import <Foundation/Foundation.h>
 
@@ -1193,35 +1194,39 @@ static void ggml_metal_encode_node(
 
                 const int32_t dim = ((const int32_t *) dst->op_params)[0];
 
+                ggml_metal_kargs_concat args = {
+                    /*.ne00 =*/ ne00,
+                    /*.ne01 =*/ ne01,
+                    /*.ne02 =*/ ne02,
+                    /*.ne03 =*/ ne03,
+                    /*.nb00 =*/ nb00,
+                    /*.nb01 =*/ nb01,
+                    /*.nb02 =*/ nb02,
+                    /*.nb03 =*/ nb03,
+                    /*.ne10 =*/ ne10,
+                    /*.ne11 =*/ ne11,
+                    /*.ne12 =*/ ne12,
+                    /*.ne13 =*/ ne13,
+                    /*.nb10 =*/ nb10,
+                    /*.nb11 =*/ nb11,
+                    /*.nb12 =*/ nb12,
+                    /*.nb13 =*/ nb13,
+                    /*.ne0  =*/ ne0,
+                    /*.ne1  =*/ ne1,
+                    /*.ne2  =*/ ne2,
+                    /*.ne3  =*/ ne3,
+                    /*.nb0  =*/ nb0,
+                    /*.nb1  =*/ nb1,
+                    /*.nb2  =*/ nb2,
+                    /*.nb3  =*/ nb3,
+                    /*.dim  =*/ dim,
+                };
+
                 [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
-                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
-                [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
-                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
-                [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
-                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
-                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
-                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
-                [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
-                [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
-                [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
-                [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
-                [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
-                [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
-                [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
-                [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
-                [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
-                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:19];
-                [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:20];
-                [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:21];
-                [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:22];
-                [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:23];
-                [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:24];
-                [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:25];
-                [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:26];
-                [encoder setBytes:&dim  length:sizeof(dim)  atIndex:27];
+                [encoder setBytes:&args length:sizeof(args) atIndex:0];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:3];
 
                 const int nth = MIN(1024, ne0);
 
@@ -1239,8 +1244,6 @@ static void ggml_metal_encode_node(
 
                 bool bcast_row = false;
 
-                int64_t nb = ne00; // used by the "row" kernels
-
                 id<MTLComputePipelineState> pipeline = nil;
 
                 if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
@@ -1249,7 +1252,6 @@ static void ggml_metal_encode_node(
                     // src1 is a row
                     GGML_ASSERT(ne11 == 1);
 
-                    nb = ne00 / 4;
                     switch (dst->op) {
                         case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break;
                         case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline; break;
@@ -1269,36 +1271,39 @@ static void ggml_metal_encode_node(
                     }
                 }
 
+                ggml_metal_kargs_bin args = {
+                    /*.ne00 =*/ ne00,
+                    /*.ne01 =*/ ne01,
+                    /*.ne02 =*/ ne02,
+                    /*.ne03 =*/ ne03,
+                    /*.nb00 =*/ nb00,
+                    /*.nb01 =*/ nb01,
+                    /*.nb02 =*/ nb02,
+                    /*.nb03 =*/ nb03,
+                    /*.ne10 =*/ ne10,
+                    /*.ne11 =*/ ne11,
+                    /*.ne12 =*/ ne12,
+                    /*.ne13 =*/ ne13,
+                    /*.nb10 =*/ nb10,
+                    /*.nb11 =*/ nb11,
+                    /*.nb12 =*/ nb12,
+                    /*.nb13 =*/ nb13,
+                    /*.ne0  =*/ ne0,
+                    /*.ne1  =*/ ne1,
+                    /*.ne2  =*/ ne2,
+                    /*.ne3  =*/ ne3,
+                    /*.nb0  =*/ nb0,
+                    /*.nb1  =*/ nb1,
+                    /*.nb2  =*/ nb2,
+                    /*.nb3  =*/ nb3,
+                    /*.offs =*/ offs,
+                };
+
                 [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
-                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
-                [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
-                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
-                [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
-                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
-                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
-                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
-                [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
-                [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
-                [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
-                [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
-                [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
-                [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
-                [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
-                [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
-                [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
-                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:19];
-                [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:20];
-                [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:21];
-                [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:22];
-                [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:23];
-                [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:24];
-                [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:25];
-                [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:26];
-                [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
-                [encoder setBytes:&nb   length:sizeof(nb)   atIndex:28];
+                [encoder setBytes:&args length:sizeof(args) atIndex:0];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:3];
 
                 if (bcast_row) {
                     const int64_t n = ggml_nelements(dst)/4;
@@ -1322,25 +1327,29 @@ static void ggml_metal_encode_node(
                     default: GGML_ABORT("fatal error");
                 }
 
+                ggml_metal_kargs_repeat args = {
+                    /*.ne00 =*/ ne00,
+                    /*.ne01 =*/ ne01,
+                    /*.ne02 =*/ ne02,
+                    /*.ne03 =*/ ne03,
+                    /*.nb00 =*/ nb00,
+                    /*.nb01 =*/ nb01,
+                    /*.nb02 =*/ nb02,
+                    /*.nb03 =*/ nb03,
+                    /*.ne0  =*/ ne0,
+                    /*.ne1  =*/ ne1,
+                    /*.ne2  =*/ ne2,
+                    /*.ne3  =*/ ne3,
+                    /*.nb0  =*/ nb0,
+                    /*.nb1  =*/ nb1,
+                    /*.nb2  =*/ nb2,
+                    /*.nb3  =*/ nb3,
+                };
+
                 [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
-                [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
-                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
-                [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
-                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
-                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
-                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
-                [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
-                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:10];
-                [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:11];
-                [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:12];
-                [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:13];
-                [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:14];
-                [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:15];
-                [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:16];
-                [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:17];
+                [encoder setBytes:&args length:sizeof(args) atIndex:0];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
 
                 const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
 
@@ -1369,25 +1378,29 @@ static void ggml_metal_encode_node(
 
                     const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline;
 
+                    ggml_metal_kargs_cpy args = {
+                        /*.ne00 =*/ ne00,
+                        /*.ne01 =*/ ne01,
+                        /*.ne02 =*/ ne02,
+                        /*.ne03 =*/ ne03,
+                        /*.nb00 =*/ nb00,
+                        /*.nb01 =*/ nb01,
+                        /*.nb02 =*/ nb02,
+                        /*.nb03 =*/ nb03,
+                        /*.ne0  =*/ ne0,
+                        /*.ne1  =*/ ne1,
+                        /*.ne2  =*/ ne2,
+                        /*.ne3  =*/ ne3,
+                        /*.nb0  =*/ nb0,
+                        /*.nb1  =*/ nb1,
+                        /*.nb2  =*/ nb2,
+                        /*.nb3  =*/ nb3,
+                    };
+
                     [encoder setComputePipelineState:pipeline];
-                    [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
-                    [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
-                    [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
-                    [encoder setBytes:&ne01    length:sizeof( int64_t) atIndex:3];
-                    [encoder setBytes:&ne02    length:sizeof( int64_t) atIndex:4];
-                    [encoder setBytes:&ne03    length:sizeof( int64_t) atIndex:5];
-                    [encoder setBytes:&nb00    length:sizeof(uint64_t) atIndex:6];
-                    [encoder setBytes:&nb01    length:sizeof(uint64_t) atIndex:7];
-                    [encoder setBytes:&nb02    length:sizeof(uint64_t) atIndex:8];
-                    [encoder setBytes:&nb03    length:sizeof(uint64_t) atIndex:9];
-                    [encoder setBytes:&ne0     length:sizeof( int64_t) atIndex:10];
-                    [encoder setBytes:&ne1     length:sizeof( int64_t) atIndex:11];
-                    [encoder setBytes:&ne2     length:sizeof( int64_t) atIndex:12];
-                    [encoder setBytes:&ne3     length:sizeof( int64_t) atIndex:13];
-                    [encoder setBytes:&nb0     length:sizeof(uint64_t) atIndex:14];
-                    [encoder setBytes:&nb1     length:sizeof(uint64_t) atIndex:15];
-                    [encoder setBytes:&nb2     length:sizeof(uint64_t) atIndex:16];
-                    [encoder setBytes:&nb3     length:sizeof(uint64_t) atIndex:17];
+                    [encoder setBytes:&args length:sizeof(args) atIndex:0];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
 
                     const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
 
@@ -1396,35 +1409,39 @@ static void ggml_metal_encode_node(
 
                 const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline;
 
+                ggml_metal_kargs_bin args = {
+                    /*.ne00 =*/ ne00,
+                    /*.ne01 =*/ ne01,
+                    /*.ne02 =*/ ne02,
+                    /*.ne03 =*/ ne03,
+                    /*.nb00 =*/ nb00,
+                    /*.nb01 =*/ pnb1,
+                    /*.nb02 =*/ pnb2,
+                    /*.nb03 =*/ pnb3,
+                    /*.ne10 =*/ ne10,
+                    /*.ne11 =*/ ne11,
+                    /*.ne12 =*/ ne12,
+                    /*.ne13 =*/ ne13,
+                    /*.nb10 =*/ nb10,
+                    /*.nb11 =*/ nb11,
+                    /*.nb12 =*/ nb12,
+                    /*.nb13 =*/ nb13,
+                    /*.ne0  =*/ ne0,
+                    /*.ne1  =*/ ne1,
+                    /*.ne2  =*/ ne2,
+                    /*.ne3  =*/ ne3,
+                    /*.nb0  =*/ nb0,
+                    /*.nb1  =*/ pnb1,
+                    /*.nb2  =*/ pnb2,
+                    /*.nb3  =*/ pnb3,
+                    /*.offs =*/ offs,
+                };
+
                 [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
-                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
-                [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
-                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
-                [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
-                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
-                [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8];
-                [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9];
-                [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10];
-                [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
-                [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
-                [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
-                [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
-                [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
-                [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
-                [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
-                [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
-                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:19];
-                [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:20];
-                [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:21];
-                [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:22];
-                [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:23];
-                [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24];
-                [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25];
-                [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26];
-                [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
+                [encoder setBytes:&args length:sizeof(args) atIndex:0];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:3];
 
                 const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
 
@@ -1465,10 +1482,10 @@ static void ggml_metal_encode_node(
                 memcpy(&max, ((const int32_t *) dst->op_params) + 1, sizeof(float));
 
                 [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0   offset:offs_src0 atIndex:0];
-                [encoder setBuffer:id_dst    offset:offs_dst  atIndex:1];
-                [encoder setBytes:&min length:sizeof(min) atIndex:2];
-                [encoder setBytes:&max length:sizeof(max) atIndex:3];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                [encoder setBytes:&min   length:sizeof(min) atIndex:2];
+                [encoder setBytes:&max   length:sizeof(max) atIndex:3];
 
                 const int64_t n = ggml_nelements(dst);
 
@@ -1640,6 +1657,7 @@ static void ggml_metal_encode_node(
 
                 id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
 
+                // TODO: add ggml_metal_kargs struct
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
@@ -1715,6 +1733,8 @@ static void ggml_metal_encode_node(
                 const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
                 const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
+                // TODO: add ggml_metal_kargs struct
+                // TODO: optimize (see https://github.com/ggerganov/llama.cpp/pull/10238/commits/7941b6b9ec29a2866fec6fa6c51612515ca509f6)
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0 offset:offs_src0   atIndex:0];
                 if (id_src1) {
@@ -1731,6 +1751,7 @@ static void ggml_metal_encode_node(
                 [encoder setBytes:&m0          length:sizeof(m0)          atIndex:8];
                 [encoder setBytes:&m1          length:sizeof(m1)          atIndex:9];
                 [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10];
+
                 [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
 
                 [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
@@ -1747,6 +1768,7 @@ static void ggml_metal_encode_node(
                     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline;
                 }
 
+                // TODO: add ggml_metal_kargs struct
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
@@ -1771,6 +1793,7 @@ static void ggml_metal_encode_node(
 
                 id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline;
 
+                // TODO: add ggml_metal_kargs struct
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0 offset:offs_src0    atIndex:0];
                 [encoder setBuffer:id_src1 offset:offs_src1    atIndex:1];
@@ -1841,6 +1864,7 @@ static void ggml_metal_encode_node(
 
                 id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
 
+                // TODO: add ggml_metal_kargs struct
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                 [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
@@ -1959,24 +1983,29 @@ static void ggml_metal_encode_node(
                                 default: GGML_ABORT("MUL MAT-MAT not implemented");
                             }
 
+                            ggml_metal_kargs_mul_mm args = {
+                                /*.ne00 =*/ ne00,
+                                /*.ne02 =*/ ne02,
+                                /*.nb01 =*/ nb01,
+                                /*.nb02 =*/ nb02,
+                                /*.nb03 =*/ nb03,
+                                /*.ne12 =*/ ne12,
+                                /*.nb10 =*/ nb10,
+                                /*.nb11 =*/ nb11,
+                                /*.nb12 =*/ nb12,
+                                /*.nb13 =*/ nb13,
+                                /*.ne0  =*/ ne0,
+                                /*.ne1  =*/ ne1,
+                                /*.r2   =*/ r2,
+                                /*.r3   =*/ r3,
+                            };
+
                             [encoder setComputePipelineState:pipeline];
-                            [encoder setBuffer:id_src0 offset:offs_src0    atIndex:0];
-                            [encoder setBuffer:id_src1 offset:offs_src1    atIndex:1];
-                            [encoder setBuffer:id_dst  offset:offs_dst     atIndex:2];
-                            [encoder setBytes:&ne00    length:sizeof(ne00) atIndex:3];
-                            [encoder setBytes:&ne02    length:sizeof(ne02) atIndex:4];
-                            [encoder setBytes:&nb01    length:sizeof(nb01) atIndex:5];
-                            [encoder setBytes:&nb02    length:sizeof(nb02) atIndex:6];
-                            [encoder setBytes:&nb03    length:sizeof(nb03) atIndex:7];
-                            [encoder setBytes:&ne12    length:sizeof(ne12) atIndex:8];
-                            [encoder setBytes:&nb10    length:sizeof(nb10) atIndex:9];
-                            [encoder setBytes:&nb11    length:sizeof(nb11) atIndex:10];
-                            [encoder setBytes:&nb12    length:sizeof(nb12) atIndex:11];
-                            [encoder setBytes:&nb13    length:sizeof(nb13) atIndex:12];
-                            [encoder setBytes:&ne0     length:sizeof(ne0)  atIndex:13];
-                            [encoder setBytes:&ne1     length:sizeof(ne1)  atIndex:14];
-                            [encoder setBytes:&r2      length:sizeof(r2)   atIndex:15];
-                            [encoder setBytes:&r3      length:sizeof(r3)   atIndex:16];
+                            [encoder setBytes:&args    length:sizeof(args) atIndex:0];
+                            [encoder setBuffer:id_src0 offset:offs_src0    atIndex:1];
+                            [encoder setBuffer:id_src1 offset:offs_src1    atIndex:2];
+                            [encoder setBuffer:id_dst  offset:offs_dst     atIndex:3];
+
                             [encoder setThreadgroupMemoryLength:8192 atIndex:0];
                             [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
                         } else {
@@ -2154,28 +2183,32 @@ static void ggml_metal_encode_node(
                                     }
                             };
 
+                            ggml_metal_kargs_mul_mv args = {
+                                /*.ne00 =*/ ne00,
+                                /*.ne01 =*/ ne01,
+                                /*.ne02 =*/ ne02,
+                                /*.nb00 =*/ nb00,
+                                /*.nb01 =*/ nb01,
+                                /*.nb02 =*/ nb02,
+                                /*.nb03 =*/ nb03,
+                                /*.ne10 =*/ ne10,
+                                /*.ne11 =*/ ne11,
+                                /*.ne12 =*/ ne12,
+                                /*.nb10 =*/ nb10,
+                                /*.nb11 =*/ nb11,
+                                /*.nb12 =*/ nb12,
+                                /*.nb13 =*/ nb13,
+                                /*.ne0  =*/ ne0,
+                                /*.ne1  =*/ ne1,
+                                /*.r2   =*/ r2,
+                                /*.r3   =*/ r3,
+                            };
+
                             [encoder setComputePipelineState:pipeline];
-                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
-                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-                            [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
-                            [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
-                            [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
-                            [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
-                            [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
-                            [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
-                            [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
-                            [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
-                            [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
-                            [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
-                            [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:13];
-                            [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:14];
-                            [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:15];
-                            [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:16];
-                            [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:17];
-                            [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:18];
-                            [encoder setBytes:&r2   length:sizeof(r2)   atIndex:19];
-                            [encoder setBytes:&r3   length:sizeof(r3)   atIndex:20];
+                            [encoder setBytes:&args length:sizeof(args) atIndex:0];
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+                            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
+                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:3];
 
                             if (src0t == GGML_TYPE_Q4_0  || src0t == GGML_TYPE_Q4_1  || src0t == GGML_TYPE_Q5_0 ||
                                 src0t == GGML_TYPE_Q5_1  || src0t == GGML_TYPE_Q8_0  || src0t == GGML_TYPE_Q2_K ||
@@ -2288,27 +2321,30 @@ static void ggml_metal_encode_node(
                         default: GGML_ABORT("MUL_MAT_ID not implemented");
                     }
 
+                    ggml_metal_kargs_mul_mm_id args = {
+                        /*.nei0 =*/ ne20,
+                        /*.nei1 =*/ ne21,
+                        /*.nbi1 =*/ nb21,
+                        /*.ne00 =*/ ne00,
+                        /*.ne02 =*/ ne02,
+                        /*.nb01 =*/ nb01,
+                        /*.nb02 =*/ nb02,
+                        /*.ne11 =*/ ne11,
+                        /*.ne12 =*/ ne12,
+                        /*.ne13 =*/ ne13,
+                        /*.nb10 =*/ nb10,
+                        /*.nb11 =*/ nb11,
+                        /*.nb12 =*/ nb12,
+                        /*.ne0  =*/ ne0,
+                        /*.ne1  =*/ ne1,
+                    };
+
                     [encoder setComputePipelineState:pipeline];
-                    [encoder setBuffer:id_src0 offset:offs_src0    atIndex:0];
-                    [encoder setBuffer:id_src1 offset:offs_src1    atIndex:1];
-                    [encoder setBuffer:id_dst  offset:offs_dst     atIndex:2];
-                    [encoder setBuffer:id_src2 offset:offs_src2    atIndex:3];
-                    [encoder setBytes:&ne20    length:sizeof(ne20) atIndex:4];
-                    [encoder setBytes:&ne21    length:sizeof(ne21) atIndex:5];
-                    [encoder setBytes:&nb21    length:sizeof(nb21) atIndex:6];
-                    [encoder setBytes:&ne00    length:sizeof(ne00) atIndex:7];
-                    [encoder setBytes:&ne02    length:sizeof(ne02) atIndex:8];
-                    [encoder setBytes:&nb01    length:sizeof(nb01) atIndex:9];
-                    [encoder setBytes:&nb02    length:sizeof(nb02) atIndex:10];
-                    [encoder setBytes:&ne11    length:sizeof(ne11) atIndex:11];
-                    [encoder setBytes:&ne12    length:sizeof(ne12) atIndex:12];
-                    [encoder setBytes:&ne13    length:sizeof(ne13) atIndex:13];
-                    [encoder setBytes:&nb10    length:sizeof(nb10) atIndex:14];
-                    [encoder setBytes:&nb11    length:sizeof(nb11) atIndex:15];
-                    [encoder setBytes:&nb12    length:sizeof(nb12) atIndex:16];
-                    [encoder setBytes:&ne0     length:sizeof(ne0)  atIndex:17];
-                    [encoder setBytes:&ne1     length:sizeof(ne1)  atIndex:18];
-                    [encoder setBytes:&nb1     length:sizeof(nb1)  atIndex:19];
+                    [encoder setBytes:&args    length:sizeof(args) atIndex:0];
+                    [encoder setBuffer:id_src0 offset:offs_src0    atIndex:1];
+                    [encoder setBuffer:id_src1 offset:offs_src1    atIndex:2];
+                    [encoder setBuffer:id_dst  offset:offs_dst     atIndex:3];
+                    [encoder setBuffer:id_src2 offset:offs_src2    atIndex:4];
 
                     [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0];
 
@@ -2467,30 +2503,34 @@ static void ggml_metal_encode_node(
                         GGML_ASSERT(ne00 >= nth0*nth1);
                     }
 
+                    ggml_metal_kargs_mul_mv_id args = {
+                        /*.nei0 =*/ ne20,
+                        /*.nei1 =*/ ne21,
+                        /*.nbi1 =*/ nb21,
+                        /*.ne00 =*/ ne00,
+                        /*.ne01 =*/ ne01,
+                        /*.ne02 =*/ ne02,
+                        /*.nb00 =*/ nb00,
+                        /*.nb01 =*/ nb01,
+                        /*.nb02 =*/ nb02,
+                        /*.ne10 =*/ ne10,
+                        /*.ne11 =*/ ne11,
+                        /*.ne12 =*/ ne12,
+                        /*.ne13 =*/ ne13,
+                        /*.nb10 =*/ nb10,
+                        /*.nb11 =*/ nb11,
+                        /*.nb12 =*/ nb12,
+                        /*.ne0  =*/ ne0,
+                        /*.ne1  =*/ ne1,
+                        /*.nb1  =*/ nb1,
+                    };
+
                     [encoder setComputePipelineState:pipeline];
-                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
-                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-                    [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
-                    [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
-                    [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
-                    [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
-                    [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7];
-                    [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:8];
-                    [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:9];
-                    [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:10];
-                    [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:11];
-                    [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:12];
-                    [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:13];
-                    [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:14];
-                    [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:15];
-                    [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:16];
-                    [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:17];
-                    [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:18];
-                    [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:19];
-                    [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:20];
-                    [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:21];
-                    [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:22];
+                    [encoder setBytes:&args length:sizeof(args) atIndex:0];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:3];
+                    [encoder setBuffer:id_src2 offset:offs_src2 atIndex:4];
 
                     const int64_t _ne1 = 1;
                     const int tgz = dst_rows;
@@ -2563,6 +2603,7 @@ static void ggml_metal_encode_node(
                     default: GGML_ABORT("not implemented");
                 }
 
+                // TODO: add ggml_metal_kargs struct
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0     offset:offs_src0 atIndex:0];
                 [encoder setBuffer:id_src1     offset:offs_src1 atIndex:1];
@@ -2586,20 +2627,28 @@ static void ggml_metal_encode_node(
                 float eps;
                 memcpy(&eps, dst->op_params, sizeof(float));
 
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline;
+
                 int nth = 32; // SIMD width
 
-                while (nth < ne00/4 && nth < 1024) {
+                while (nth < ne00/4 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
                     nth *= 2;
                 }
 
-                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline;
+                nth = MIN(nth, ne00/4);
+
+                ggml_metal_kargs_rms_norm args = {
+                    /*.ne00   =*/ ne00,
+                    /*.ne00_4 =*/ ne00/4,
+                    /*.nb01   =*/ nb01,
+                    /*.eps    =*/ eps,
+                };
 
                 [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
-                [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
-                [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
-                [encoder setBytes:&nb01    length:sizeof(uint64_t) atIndex:3];
-                [encoder setBytes:&eps     length:sizeof(   float) atIndex:4];
+                [encoder setBytes:&args length:sizeof(args) atIndex:0];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+
                 [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
 
                 const int64_t nrows = ggml_nrows(src0);
@@ -2624,6 +2673,7 @@ static void ggml_metal_encode_node(
 
                 id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline;
 
+                // TODO: add ggml_metal_kargs struct
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0  offset:offs_src0        atIndex:0];
                 [encoder setBuffer:id_dst   offset:offs_dst         atIndex:1];
@@ -2641,22 +2691,35 @@ static void ggml_metal_encode_node(
             } break;
         case GGML_OP_NORM:
             {
+                GGML_ASSERT(ne00 % 4 == 0);
                 GGML_ASSERT(ggml_is_contiguous_1(src0));
 
                 float eps;
                 memcpy(&eps, dst->op_params, sizeof(float));
 
-                const int nth = MIN(256, ne00);
-
                 id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NORM].pipeline;
 
+                int nth = 32; // SIMD width
+
+                while (nth < ne00/4 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
+                    nth *= 2;
+                }
+
+                nth = MIN(nth, ne00/4);
+
+                ggml_metal_kargs_norm args = {
+                    /*.ne00   =*/ ne00,
+                    /*.ne00_4 =*/ ne00/4,
+                    /*.nb01   =*/ nb01,
+                    /*.eps    =*/ eps,
+                };
+
                 [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
-                [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
-                [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
-                [encoder setBytes:&nb01    length:sizeof(uint64_t) atIndex:3];
-                [encoder setBytes:&eps     length:sizeof(   float) atIndex:4];
-                [encoder setThreadgroupMemoryLength:GGML_PAD(nth*sizeof(float), 16) atIndex:0];
+                [encoder setBytes:&args length:sizeof(args) atIndex:0];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+
+                [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
 
                 const int64_t nrows = ggml_nrows(src0);
 
@@ -2706,40 +2769,44 @@ static void ggml_metal_encode_node(
                     };
                 }
 
+                ggml_metal_kargs_rope args = {
+                    /*.ne00        =*/ ne00,
+                    /*.ne01        =*/ ne01,
+                    /*.ne02        =*/ ne02,
+                    /*.ne03        =*/ ne03,
+                    /*.nb00        =*/ nb00,
+                    /*.nb01        =*/ nb01,
+                    /*.nb02        =*/ nb02,
+                    /*.nb03        =*/ nb03,
+                    /*.ne0         =*/ ne0,
+                    /*.ne1         =*/ ne1,
+                    /*.ne2         =*/ ne2,
+                    /*.ne3         =*/ ne3,
+                    /*.nb0         =*/ nb0,
+                    /*.nb1         =*/ nb1,
+                    /*.nb2         =*/ nb2,
+                    /*.nb3         =*/ nb3,
+                    /*.n_past      =*/ n_past,
+                    /*.n_dims      =*/ n_dims,
+                    /*.n_ctx_orig  =*/ n_ctx_orig,
+                    /*.freq_base   =*/ freq_base,
+                    /*.freq_scale  =*/ freq_scale,
+                    /*.ext_factor  =*/ ext_factor,
+                    /*.attn_factor =*/ attn_factor,
+                    /*.beta_fast   =*/ beta_fast,
+                    /*.beta_slow   =*/ beta_slow,
+                };
+
                 [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0     offset:offs_src0        atIndex:0];
-                [encoder setBuffer:id_src1     offset:offs_src1        atIndex:1];
+                [encoder setBytes:&args length:sizeof(args)     atIndex:0];
+                [encoder setBuffer:id_src0 offset:offs_src0     atIndex:1];
+                [encoder setBuffer:id_src1 offset:offs_src1     atIndex:2];
                 if (id_src2 != nil) {
-                    [encoder setBuffer:id_src2 offset:offs_src2        atIndex:2];
+                    [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
                 } else {
-                    [encoder setBuffer:id_src0 offset:offs_src0        atIndex:2];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:3];
                 }
-                [encoder setBuffer:id_dst      offset:offs_dst         atIndex:3];
-                [encoder setBytes:&ne00        length:sizeof( int64_t) atIndex:4];
-                [encoder setBytes:&ne01        length:sizeof( int64_t) atIndex:5];
-                [encoder setBytes:&ne02        length:sizeof( int64_t) atIndex:6];
-                [encoder setBytes:&ne03        length:sizeof( int64_t) atIndex:7];
-                [encoder setBytes:&nb00        length:sizeof(uint64_t) atIndex:8];
-                [encoder setBytes:&nb01        length:sizeof(uint64_t) atIndex:9];
-                [encoder setBytes:&nb02        length:sizeof(uint64_t) atIndex:10];
-                [encoder setBytes:&nb03        length:sizeof(uint64_t) atIndex:11];
-                [encoder setBytes:&ne0         length:sizeof( int64_t) atIndex:12];
-                [encoder setBytes:&ne1         length:sizeof( int64_t) atIndex:13];
-                [encoder setBytes:&ne2         length:sizeof( int64_t) atIndex:14];
-                [encoder setBytes:&ne3         length:sizeof( int64_t) atIndex:15];
-                [encoder setBytes:&nb0         length:sizeof(uint64_t) atIndex:16];
-                [encoder setBytes:&nb1         length:sizeof(uint64_t) atIndex:17];
-                [encoder setBytes:&nb2         length:sizeof(uint64_t) atIndex:18];
-                [encoder setBytes:&nb3         length:sizeof(uint64_t) atIndex:19];
-                [encoder setBytes:&n_past      length:sizeof(     int) atIndex:20];
-                [encoder setBytes:&n_dims      length:sizeof(     int) atIndex:21];
-                [encoder setBytes:&n_ctx_orig  length:sizeof(     int) atIndex:22];
-                [encoder setBytes:&freq_base   length:sizeof(   float) atIndex:23];
-                [encoder setBytes:&freq_scale  length:sizeof(   float) atIndex:24];
-                [encoder setBytes:&ext_factor  length:sizeof(   float) atIndex:25];
-                [encoder setBytes:&attn_factor length:sizeof(   float) atIndex:26];
-                [encoder setBytes:&beta_fast   length:sizeof(   float) atIndex:27];
-                [encoder setBytes:&beta_slow   length:sizeof(   float) atIndex:28];
+                [encoder setBuffer:id_dst  offset:offs_dst      atIndex:4];
 
                 [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
             } break;
@@ -2796,6 +2863,7 @@ static void ggml_metal_encode_node(
                     default: GGML_ABORT("fatal error");
                 };
 
+                // TODO: add ggml_metal_kargs struct
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src1 offset:offs_src1       atIndex:0];
                 [encoder setBuffer:id_dst  offset:offs_dst        atIndex:1];
@@ -2836,6 +2904,7 @@ static void ggml_metal_encode_node(
 
                 const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline;
 
+                // TODO: add ggml_metal_kargs struct
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
@@ -2870,6 +2939,7 @@ static void ggml_metal_encode_node(
 
                 id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline;
 
+                // TODO: add ggml_metal_kargs struct
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
@@ -2906,6 +2976,7 @@ static void ggml_metal_encode_node(
 
                 id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline;
 
+                // TODO: add ggml_metal_kargs struct
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_dst  offset:offs_dst    atIndex:0];
                 [encoder setBytes:&ne0   length:sizeof(ne0)   atIndex:1];
@@ -2927,6 +2998,7 @@ static void ggml_metal_encode_node(
 
                 id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline;
 
+                // TODO: add ggml_metal_kargs struct
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
@@ -2965,6 +3037,7 @@ static void ggml_metal_encode_node(
                     default: GGML_ABORT("fatal error");
                 };
 
+                // TODO: add ggml_metal_kargs struct
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0     offset:offs_src0        atIndex:0];
                 [encoder setBuffer:id_dst      offset:offs_dst         atIndex:1];
@@ -2983,6 +3056,7 @@ static void ggml_metal_encode_node(
 
                 id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline;
 
+                // TODO: add ggml_metal_kargs struct
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0 offset:offs_src0   atIndex:0];
                 [encoder setBuffer:id_dst  offset:offs_dst    atIndex:1];
@@ -3224,37 +3298,41 @@ static void ggml_metal_encode_node(
                     }
                 }
 
+                ggml_metal_kargs_flash_attn_ext args = {
+                    /*.ne01          =*/ ne01,
+                    /*.ne02          =*/ ne02,
+                    /*.ne03          =*/ ne03,
+                    /*.nb01          =*/ nb01,
+                    /*.nb02          =*/ nb02,
+                    /*.nb03          =*/ nb03,
+                    /*.ne11          =*/ ne11,
+                    /*.ne_12_2       =*/ ne12,
+                    /*.ne_12_3       =*/ ne13,
+                    /*.nb_12_1       =*/ nb11,
+                    /*.nb_12_2       =*/ nb12,
+                    /*.nb_12_3       =*/ nb13,
+                    /*.nb31          =*/ nb31,
+                    /*.ne1           =*/ ne1,
+                    /*.ne2           =*/ ne2,
+                    /*.scale         =*/ scale,
+                    /*.max_bias      =*/ max_bias,
+                    /*.m0            =*/ m0,
+                    /*.m1            =*/ m1,
+                    /*.n_head_log2   =*/ n_head_log2,
+                    /*.logit_softcap =*/ logit_softcap,
+                };
+
                 [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0     offset:offs_src0           atIndex:0];
-                [encoder setBuffer:id_src1     offset:offs_src1           atIndex:1];
-                [encoder setBuffer:id_src2     offset:offs_src2           atIndex:2];
+                [encoder setBytes:&args length:sizeof(args)     atIndex:0];
+                [encoder setBuffer:id_src0 offset:offs_src0     atIndex:1];
+                [encoder setBuffer:id_src1 offset:offs_src1     atIndex:2];
+                [encoder setBuffer:id_src2 offset:offs_src2     atIndex:3];
                 if (id_src3) {
-                    [encoder setBuffer:id_src3     offset:offs_src3           atIndex:3];
+                    [encoder setBuffer:id_src3 offset:offs_src3 atIndex:4];
                 } else {
-                    [encoder setBuffer:id_src0     offset:offs_src0           atIndex:3];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:4];
                 }
-                [encoder setBuffer:id_dst        offset:offs_dst              atIndex:4];
-                [encoder setBytes:&ne01          length:sizeof( int64_t)      atIndex:5];
-                [encoder setBytes:&ne02          length:sizeof( int64_t)      atIndex:6];
-                [encoder setBytes:&ne03          length:sizeof( int64_t)      atIndex:7];
-                [encoder setBytes:&nb01          length:sizeof(uint64_t)      atIndex:8];
-                [encoder setBytes:&nb02          length:sizeof(uint64_t)      atIndex:9];
-                [encoder setBytes:&nb03          length:sizeof(uint64_t)      atIndex:10];
-                [encoder setBytes:&ne11          length:sizeof( int64_t)      atIndex:11];
-                [encoder setBytes:&ne12          length:sizeof( int64_t)      atIndex:12];
-                [encoder setBytes:&ne13          length:sizeof( int64_t)      atIndex:13];
-                [encoder setBytes:&nb11          length:sizeof(uint64_t)      atIndex:14];
-                [encoder setBytes:&nb12          length:sizeof(uint64_t)      atIndex:15];
-                [encoder setBytes:&nb13          length:sizeof(uint64_t)      atIndex:16];
-                [encoder setBytes:&nb31          length:sizeof(uint64_t)      atIndex:17];
-                [encoder setBytes:&ne1           length:sizeof( int64_t)      atIndex:18];
-                [encoder setBytes:&ne2           length:sizeof( int64_t)      atIndex:19];
-                [encoder setBytes:&scale         length:sizeof(   float)      atIndex:20];
-                [encoder setBytes:&max_bias      length:sizeof(   float)      atIndex:21];
-                [encoder setBytes:&m0            length:sizeof(m0)            atIndex:22];
-                [encoder setBytes:&m1            length:sizeof(m1)            atIndex:23];
-                [encoder setBytes:&n_head_log2   length:sizeof(n_head_log2)   atIndex:24];
-                [encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:25];
+                [encoder setBuffer:id_dst offset:offs_dst       atIndex:5];
 
                 if (!use_vec_kernel) {
                     // half8x8 kernel
@@ -3389,25 +3467,29 @@ static void ggml_metal_encode_node(
                     default: GGML_ABORT("not implemented");
                 }
 
+                ggml_metal_kargs_cpy args = {
+                    /*.ne00 =*/ ne00,
+                    /*.ne01 =*/ ne01,
+                    /*.ne02 =*/ ne02,
+                    /*.ne03 =*/ ne03,
+                    /*.nb00 =*/ nb00,
+                    /*.nb01 =*/ nb01,
+                    /*.nb02 =*/ nb02,
+                    /*.nb03 =*/ nb03,
+                    /*.ne0  =*/ ne0,
+                    /*.ne1  =*/ ne1,
+                    /*.ne2  =*/ ne2,
+                    /*.ne3  =*/ ne3,
+                    /*.nb0  =*/ nb0,
+                    /*.nb1  =*/ nb1,
+                    /*.nb2  =*/ nb2,
+                    /*.nb3  =*/ nb3,
+                };
+
                 [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
-                [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
-                [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
-                [encoder setBytes:&ne01    length:sizeof( int64_t) atIndex:3];
-                [encoder setBytes:&ne02    length:sizeof( int64_t) atIndex:4];
-                [encoder setBytes:&ne03    length:sizeof( int64_t) atIndex:5];
-                [encoder setBytes:&nb00    length:sizeof(uint64_t) atIndex:6];
-                [encoder setBytes:&nb01    length:sizeof(uint64_t) atIndex:7];
-                [encoder setBytes:&nb02    length:sizeof(uint64_t) atIndex:8];
-                [encoder setBytes:&nb03    length:sizeof(uint64_t) atIndex:9];
-                [encoder setBytes:&ne0     length:sizeof( int64_t) atIndex:10];
-                [encoder setBytes:&ne1     length:sizeof( int64_t) atIndex:11];
-                [encoder setBytes:&ne2     length:sizeof( int64_t) atIndex:12];
-                [encoder setBytes:&ne3     length:sizeof( int64_t) atIndex:13];
-                [encoder setBytes:&nb0     length:sizeof(uint64_t) atIndex:14];
-                [encoder setBytes:&nb1     length:sizeof(uint64_t) atIndex:15];
-                [encoder setBytes:&nb2     length:sizeof(uint64_t) atIndex:16];
-                [encoder setBytes:&nb3     length:sizeof(uint64_t) atIndex:17];
+                [encoder setBytes:&args length:sizeof(args) atIndex:0];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
 
                 [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
             } break;
@@ -3452,6 +3534,7 @@ static void ggml_metal_encode_node(
                 const int64_t n_threads = MIN((int64_t)[pipeline maxTotalThreadsPerThreadgroup], parallel_elements);
                 const int64_t n_tg = (parallel_elements + n_threads - 1) / n_threads;
 
+                // TODO: add ggml_metal_kargs struct
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0 offset:offs_src0       atIndex:0];
                 [encoder setBuffer:id_dst  offset:offs_dst        atIndex:1];
index 8c7fcb11303b80745266369d66a036f3dce75830..86fdf1c18cfb692f82a6c2400d6712777e8c9f27 100644 (file)
@@ -6,6 +6,7 @@ __embed_ggml-common.h__
 // TODO: this should not be a relative path, but can't figure out how to set Metal include paths in Package.swift
 #include "../ggml-common.h"
 #endif
+#include "ggml-metal-impl.h"
 
 #include <metal_stdlib>
 
@@ -497,240 +498,131 @@ enum ggml_sort_order {
 // pros: works for non-contiguous tensors, supports broadcast across all dims
 // cons: not very efficient
 kernel void kernel_add(
+        constant ggml_metal_kargs_bin & args,
         device const char * src0,
         device const char * src1,
         device       char * dst,
-        constant  int64_t & ne00,
-        constant  int64_t & ne01,
-        constant  int64_t & ne02,
-        constant  int64_t & ne03,
-        constant uint64_t & nb00,
-        constant uint64_t & nb01,
-        constant uint64_t & nb02,
-        constant uint64_t & nb03,
-        constant  int64_t & ne10,
-        constant  int64_t & ne11,
-        constant  int64_t & ne12,
-        constant  int64_t & ne13,
-        constant uint64_t & nb10,
-        constant uint64_t & nb11,
-        constant uint64_t & nb12,
-        constant uint64_t & nb13,
-        constant  int64_t & ne0,
-        constant  int64_t & ne1,
-        constant  int64_t & ne2,
-        constant  int64_t & ne3,
-        constant uint64_t & nb0,
-        constant uint64_t & nb1,
-        constant uint64_t & nb2,
-        constant uint64_t & nb3,
-        constant  int64_t & offs,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3   ntg[[threads_per_threadgroup]]) {
-    const int64_t i03 = tgpig.z;
-    const int64_t i02 = tgpig.y;
-    const int64_t i01 = tgpig.x;
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int i03 = tgpig.z;
+    const int i02 = tgpig.y;
+    const int i01 = tgpig.x;
 
-    const int64_t i13 = i03 % ne13;
-    const int64_t i12 = i02 % ne12;
-    const int64_t i11 = i01 % ne11;
+    const int i13 = i03%args.ne13;
+    const int i12 = i02%args.ne12;
+    const int i11 = i01%args.ne11;
 
-    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs;
-    device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
-    device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1  + offs;
+    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
+    device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11;
+    device       char * dst_ptr  = dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1  + args.offs;
 
-    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
-        const int i10 = i0 % ne10;
-        *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) + *((device float *)(src1_ptr + i10*nb10));
+    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+        const int i10 = i0%args.ne10;
+        *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) + *((device float *)(src1_ptr + i10*args.nb10));
     }
 }
 
 kernel void kernel_sub(
+        constant ggml_metal_kargs_bin & args,
         device const char * src0,
         device const char * src1,
         device       char * dst,
-        constant  int64_t & ne00,
-        constant  int64_t & ne01,
-        constant  int64_t & ne02,
-        constant  int64_t & ne03,
-        constant uint64_t & nb00,
-        constant uint64_t & nb01,
-        constant uint64_t & nb02,
-        constant uint64_t & nb03,
-        constant  int64_t & ne10,
-        constant  int64_t & ne11,
-        constant  int64_t & ne12,
-        constant  int64_t & ne13,
-        constant uint64_t & nb10,
-        constant uint64_t & nb11,
-        constant uint64_t & nb12,
-        constant uint64_t & nb13,
-        constant  int64_t & ne0,
-        constant  int64_t & ne1,
-        constant  int64_t & ne2,
-        constant  int64_t & ne3,
-        constant uint64_t & nb0,
-        constant uint64_t & nb1,
-        constant uint64_t & nb2,
-        constant uint64_t & nb3,
-        constant  int64_t & offs,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3   ntg[[threads_per_threadgroup]]) {
-    const int64_t i03 = tgpig.z;
-    const int64_t i02 = tgpig.y;
-    const int64_t i01 = tgpig.x;
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int i03 = tgpig.z;
+    const int i02 = tgpig.y;
+    const int i01 = tgpig.x;
 
-    const int64_t i13 = i03 % ne13;
-    const int64_t i12 = i02 % ne12;
-    const int64_t i11 = i01 % ne11;
+    const int i13 = i03%args.ne13;
+    const int i12 = i02%args.ne12;
+    const int i11 = i01%args.ne11;
 
-    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs;
-    device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
-    device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1  + offs;
+    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
+    device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11;
+    device       char * dst_ptr  = dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1  + args.offs;
 
-    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
-        const int i10 = i0 % ne10;
-        *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) - *((device float *)(src1_ptr + i10*nb10));
+    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+        const int i10 = i0%args.ne10;
+        *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) - *((device float *)(src1_ptr + i10*args.nb10));
     }
 }
 
 kernel void kernel_mul(
+        constant ggml_metal_kargs_bin & args,
         device const char * src0,
         device const char * src1,
         device       char * dst,
-        constant  int64_t & ne00,
-        constant  int64_t & ne01,
-        constant  int64_t & ne02,
-        constant  int64_t & ne03,
-        constant uint64_t & nb00,
-        constant uint64_t & nb01,
-        constant uint64_t & nb02,
-        constant uint64_t & nb03,
-        constant  int64_t & ne10,
-        constant  int64_t & ne11,
-        constant  int64_t & ne12,
-        constant  int64_t & ne13,
-        constant uint64_t & nb10,
-        constant uint64_t & nb11,
-        constant uint64_t & nb12,
-        constant uint64_t & nb13,
-        constant  int64_t & ne0,
-        constant  int64_t & ne1,
-        constant  int64_t & ne2,
-        constant  int64_t & ne3,
-        constant uint64_t & nb0,
-        constant uint64_t & nb1,
-        constant uint64_t & nb2,
-        constant uint64_t & nb3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3   ntg[[threads_per_threadgroup]]) {
-    const int64_t i03 = tgpig.z;
-    const int64_t i02 = tgpig.y;
-    const int64_t i01 = tgpig.x;
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int i03 = tgpig.z;
+    const int i02 = tgpig.y;
+    const int i01 = tgpig.x;
 
-    const int64_t i13 = i03 % ne13;
-    const int64_t i12 = i02 % ne12;
-    const int64_t i11 = i01 % ne11;
+    const int i13 = i03%args.ne13;
+    const int i12 = i02%args.ne12;
+    const int i11 = i01%args.ne11;
 
-    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
-    device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
-    device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1;
+    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01;
+    device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11;
+    device       char * dst_ptr  = dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1;
 
-    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
-        const int i10 = i0 % ne10;
-        *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) * *((device float *)(src1_ptr + i10*nb10));
+    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+        const int i10 = i0%args.ne10;
+        *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * *((device float *)(src1_ptr + i10*args.nb10));
     }
 }
 
 kernel void kernel_div(
+        constant ggml_metal_kargs_bin & args,
         device const char * src0,
         device const char * src1,
         device       char * dst,
-        constant  int64_t & ne00,
-        constant  int64_t & ne01,
-        constant  int64_t & ne02,
-        constant  int64_t & ne03,
-        constant uint64_t & nb00,
-        constant uint64_t & nb01,
-        constant uint64_t & nb02,
-        constant uint64_t & nb03,
-        constant  int64_t & ne10,
-        constant  int64_t & ne11,
-        constant  int64_t & ne12,
-        constant  int64_t & ne13,
-        constant uint64_t & nb10,
-        constant uint64_t & nb11,
-        constant uint64_t & nb12,
-        constant uint64_t & nb13,
-        constant  int64_t & ne0,
-        constant  int64_t & ne1,
-        constant  int64_t & ne2,
-        constant  int64_t & ne3,
-        constant uint64_t & nb0,
-        constant uint64_t & nb1,
-        constant uint64_t & nb2,
-        constant uint64_t & nb3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3   ntg[[threads_per_threadgroup]]) {
-    const int64_t i03 = tgpig.z;
-    const int64_t i02 = tgpig.y;
-    const int64_t i01 = tgpig.x;
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int i03 = tgpig.z;
+    const int i02 = tgpig.y;
+    const int i01 = tgpig.x;
 
-    const int64_t i13 = i03 % ne13;
-    const int64_t i12 = i02 % ne12;
-    const int64_t i11 = i01 % ne11;
+    const int i13 = i03%args.ne13;
+    const int i12 = i02%args.ne12;
+    const int i11 = i01%args.ne11;
 
-    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
-    device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
-    device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1;
+    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01;
+    device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11;
+    device       char * dst_ptr  = dst  + i03*args.nb3  + i02*args.nb2  + i01*args.nb1;
 
-    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
-        const int i10 = i0 % ne10;
-        *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) / *((device float *)(src1_ptr + i10*nb10));
+    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+        const int i10 = i0%args.ne10;
+        *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10));
     }
 }
 
 template<typename T>
 kernel void kernel_repeat(
+        constant ggml_metal_kargs_repeat & args,
         device const char * src0,
         device       char * dst,
-        constant  int64_t & ne00,
-        constant  int64_t & ne01,
-        constant  int64_t & ne02,
-        constant  int64_t & ne03,
-        constant uint64_t & nb00,
-        constant uint64_t & nb01,
-        constant uint64_t & nb02,
-        constant uint64_t & nb03,
-        constant  int64_t & ne0,
-        constant  int64_t & ne1,
-        constant  int64_t & ne2,
-        constant  int64_t & ne3,
-        constant uint64_t & nb0,
-        constant uint64_t & nb1,
-        constant uint64_t & nb2,
-        constant uint64_t & nb3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3   ntg[[threads_per_threadgroup]]) {
-    const int64_t i3 = tgpig.z;
-    const int64_t i2 = tgpig.y;
-    const int64_t i1 = tgpig.x;
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int i3 = tgpig.z;
+    const int i2 = tgpig.y;
+    const int i1 = tgpig.x;
 
-    const int64_t i03 = i3 % ne03;
-    const int64_t i02 = i2 % ne02;
-    const int64_t i01 = i1 % ne01;
+    const int i03 = i3%args.ne03;
+    const int i02 = i2%args.ne02;
+    const int i01 = i1%args.ne01;
 
-    device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
-    device       char * dst_ptr  = dst  +  i3*nb3  +  i2*nb2  +  i1*nb1 ;
+    device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01;
+    device       char * dst_ptr  = dst  +  i3*args.nb3  +  i2*args.nb2  +  i1*args.nb1;
 
-    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
-        const int i00 = i0 % ne00;
-        *((device T *)(dst_ptr + i0*nb0)) = *((device T *)(src0_ptr + i00*nb00));
+    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+        const int i00 = i0%args.ne00;
+        *((device T *)(dst_ptr + i0*args.nb0)) = *((device T *)(src0_ptr + i00*args.nb00));
     }
 }
 
@@ -744,38 +636,42 @@ template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat
 // assumption: src1 is a row
 // broadcast src1 into src0
 kernel void kernel_add_row(
+        constant ggml_metal_kargs_bin & args,
         device const float4 * src0,
         device const float4 * src1,
         device       float4 * dst,
-        constant   uint64_t & nb [[buffer(28)]],
         uint tpig[[thread_position_in_grid]]) {
+    const uint nb = args.ne00/4;
     dst[tpig] = src0[tpig] + src1[tpig % nb];
 }
 
 kernel void kernel_sub_row(
+        constant ggml_metal_kargs_bin & args,
         device const float4 * src0,
         device const float4 * src1,
         device       float4 * dst,
-        constant   uint64_t & nb [[buffer(28)]],
         uint tpig[[thread_position_in_grid]]) {
+    const uint nb = args.ne00/4;
     dst[tpig] = src0[tpig] - src1[tpig % nb];
 }
 
 kernel void kernel_mul_row(
+        constant ggml_metal_kargs_bin & args,
         device const float4 * src0,
         device const float4 * src1,
         device       float4 * dst,
-        constant   uint64_t & nb  [[buffer(28)]],
         uint tpig[[thread_position_in_grid]]) {
+    const uint nb = args.ne00/4;
     dst[tpig] = src0[tpig] * src1[tpig % nb];
 }
 
 kernel void kernel_div_row(
+        constant ggml_metal_kargs_bin & args,
         device const float4 * src0,
         device const float4 * src1,
         device       float4 * dst,
-        constant   uint64_t & nb  [[buffer(28)]],
         uint tpig[[thread_position_in_grid]]) {
+    const uint nb = args.ne00/4;
     dst[tpig] = src0[tpig] / src1[tpig % nb];
 }
 
@@ -1345,102 +1241,112 @@ kernel void kernel_ssm_scan_f32(
 }
 
 kernel void kernel_norm(
-        device const  void * src0,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant  uint64_t & nb01,
-        constant     float & eps,
-        threadgroup float  * sum [[threadgroup(0)]],
-        uint tgpig[[threadgroup_position_in_grid]],
-        uint tpitg[[thread_position_in_threadgroup]],
-        uint   ntg[[threads_per_threadgroup]]) {
-    device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01);
-    // MEAN
-    // parallel sum
-    sum[tpitg] = 0.0f;
-    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
-        sum[tpitg] += x[i00];
+        constant ggml_metal_kargs_norm & args,
+        device const char * src0,
+        device       char * dst,
+        threadgroup float * shmem_f32 [[threadgroup(0)]],
+        uint   tgpig[[threadgroup_position_in_grid]],
+        ushort tpitg[[thread_position_in_threadgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort   ntg[[threads_per_threadgroup]]) {
+    if (sgitg == 0) {
+        shmem_f32[tiisg] = 0.0f;
+    }
+
+    device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01);
+
+    float4 sumf4(0.0f);
+
+    float sumf = 0.0f;
+
+    for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
+        sumf4 += x[i00];
     }
-    // reduce
+    sumf = sumf4[0] + sumf4[1] + sumf4[2] + sumf4[3];
+    sumf = simd_sum(sumf);
+
     threadgroup_barrier(mem_flags::mem_threadgroup);
-    for (uint i = ntg/2; i > 0; i /= 2) {
-        if (tpitg < i) {
-            sum[tpitg] += sum[tpitg + i];
-        }
-        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    if (tiisg == 0) {
+        shmem_f32[sgitg] = sumf;
     }
-    const float mean  = sum[0] / ne00;
 
-    // recenter and VARIANCE
     threadgroup_barrier(mem_flags::mem_threadgroup);
-    device float * y = dst + tgpig*ne00;
-    sum[tpitg] = 0.0f;
-    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+
+    sumf = shmem_f32[tiisg];
+    sumf = simd_sum(sumf);
+
+    const float mean = sumf/args.ne00;
+
+    device float4 * y = (device float4 *) dst + tgpig*args.ne00_4;
+
+    sumf = 0.0f;
+    for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
         y[i00] = x[i00] - mean;
-        sum[tpitg] += y[i00] * y[i00];
+        sumf += dot(y[i00], y[i00]);
     }
+    sumf = simd_sum(sumf);
 
-    // reduce
     threadgroup_barrier(mem_flags::mem_threadgroup);
-    for (uint i = ntg/2; i > 0; i /= 2) {
-        if (tpitg < i) {
-            sum[tpitg] += sum[tpitg + i];
-        }
-        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    if (tiisg == 0) {
+        shmem_f32[sgitg] = sumf;
     }
-    const float variance = sum[0] / ne00;
 
-    const float scale = 1.0f/sqrt(variance + eps);
-    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    sumf = shmem_f32[tiisg];
+    sumf = simd_sum(sumf);
+
+    const float variance = sumf/args.ne00;
+
+    const float scale = 1.0f/sqrt(variance + args.eps);
+    for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
         y[i00] = y[i00] * scale;
     }
 }
 
 kernel void kernel_rms_norm(
-        device const  void * src0,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant  uint64_t & nb01,
-        constant     float & eps,
-        threadgroup float  * buf [[threadgroup(0)]],
-        uint tgpig[[threadgroup_position_in_grid]],
-        uint tpitg[[thread_position_in_threadgroup]],
-        uint sgitg[[simdgroup_index_in_threadgroup]],
-        uint tiisg[[thread_index_in_simdgroup]],
-        uint   ntg[[threads_per_threadgroup]]) {
-    device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
+        constant ggml_metal_kargs_rms_norm & args,
+        device const char * src0,
+        device       char * dst,
+        threadgroup float * shmem_f32 [[threadgroup(0)]],
+        uint   tgpig[[threadgroup_position_in_grid]],
+        ushort tpitg[[thread_position_in_threadgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort   ntg[[threads_per_threadgroup]]) {
+    if (sgitg == 0) {
+        shmem_f32[tiisg] = 0.0f;
+    }
+
+    device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01);
 
-    float4 sumf = 0;
-    float all_sum = 0;
+    float sumf = 0.0f;
 
     // parallel sum
-    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
-        sumf += x[i00] * x[i00];
+    for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
+        sumf += dot(x[i00], x[i00]);
     }
-    all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3];
-    all_sum = simd_sum(all_sum);
-    if (ntg > N_SIMDWIDTH) {
-        if (sgitg == 0) {
-            buf[tiisg] = 0.0f;
-        }
+    sumf = simd_sum(sumf);
 
-        threadgroup_barrier(mem_flags::mem_threadgroup);
+    threadgroup_barrier(mem_flags::mem_threadgroup);
 
-        if (tiisg == 0) {
-            buf[sgitg] = all_sum;
-        }
+    if (tiisg == 0) {
+        shmem_f32[sgitg] = sumf;
+    }
 
-        threadgroup_barrier(mem_flags::mem_threadgroup);
+    threadgroup_barrier(mem_flags::mem_threadgroup);
 
-        all_sum = buf[tiisg];
-        all_sum = simd_sum(all_sum);
-    }
+    sumf = shmem_f32[tiisg];
+    sumf = simd_sum(sumf);
 
-    const float mean  = all_sum/ne00;
-    const float scale = 1.0f/sqrt(mean + eps);
+    const float mean  = sumf/args.ne00;
+    const float scale = 1.0f/sqrt(mean + args.eps);
 
-    device float4 * y = (device float4 *) (dst + tgpig*ne00);
-    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
+    device float4 * y = (device float4 *) dst + tgpig*args.ne00_4;
+    for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
         y[i00] = x[i00] * scale;
     }
 }
@@ -1628,31 +1534,17 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
 //      quantizations where the block size is 32. It also does not
 //      guard against the number of rows not being divisible by
 //      N_DST, so this is another explicit assumption of the implementation.
-template<typename block_q_type, int nr, int nsg, int nw>
+template<typename block_q_type, int nr, int nsg, int nw, typename args_t>
 void mul_vec_q_n_f32_impl(
-        device const void  * src0,
-        device const float * src1,
-        device       float * dst,
-                   int64_t   ne00,
-                   int64_t   ne01,
-                   int64_t   ne02,
-                  uint64_t   nb01,
-                  uint64_t   nb02,
-                  uint64_t   nb03,
-                   int64_t   ne10,
-                   int64_t   ne12,
-                  uint64_t   nb11,
-                  uint64_t   nb12,
-                  uint64_t   nb13,
-                   int64_t   ne0,
-                   int64_t   ne1,
-                   uint      r2,
-                   uint      r3,
-        threadgroup int8_t * shared_values,
-                     uint3   tgpig,
-                     uint    tiisg,
-                     uint    sgitg) {
-    const int nb = ne00/QK4_0;
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiisg,
+        ushort sgitg) {
+    const int nb = args.ne00/QK4_0;
 
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
@@ -1660,19 +1552,19 @@ void mul_vec_q_n_f32_impl(
 
     const int first_row = (r0 * nsg + sgitg) * nr;
 
-    const uint i12 = im%ne12;
-    const uint i13 = im/ne12;
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
 
-  //const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
-    const uint offset1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
+  //const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
 
-  //device const block_q_type * x = (device const block_q_type *) ((device char *) src0 + offset0);
-    device const float        * y = (device const float        *) ((device char *) src1 + offset1);
+  //device const block_q_type * x = (device const block_q_type *) (src0 + offset0);
+    device const float        * y = (device const float        *) (src1 + offset1);
 
     // pointers to src0 rows
     device const block_q_type * ax[nr];
     for (int row = 0; row < nr; ++row) {
-        const uint offset0 = (first_row + row)*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+        const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
 
         ax[row] = (device const block_q_type *) ((device char *) src0 + offset0);
     }
@@ -1680,10 +1572,10 @@ void mul_vec_q_n_f32_impl(
     float yl[16]; // src1 vector cache
     float sumf[nr] = {0.f};
 
-    const int ix = (tiisg/2);
-    const int il = (tiisg%2)*8;
+    const short ix = (tiisg/2);
+    const short il = (tiisg%2)*8;
 
-    device const float * yb = y + ix * QK4_0 + il;
+    device const float * yb = y + ix*QK4_0 + il;
 
     // each thread in a SIMD group deals with half a block.
     for (int ib = ix; ib < nb; ib += nw/2) {
@@ -1708,324 +1600,216 @@ void mul_vec_q_n_f32_impl(
         yb += QK4_0 * 16;
     }
 
+    device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
+
     for (int row = 0; row < nr; ++row) {
         const float tot = simd_sum(sumf[row]);
-        if (tiisg == 0 && first_row + row < ne01) {
-            dst[im*ne0*ne1 + r1*ne0 + first_row + row] = tot;
+
+        if (tiisg == 0 && first_row + row < args.ne01) {
+            dst_f32[first_row + row] = tot;
         }
     }
 }
 
 kernel void kernel_mul_mv_q4_0_f32(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne10,
-        constant   int64_t & ne11,
-        constant   int64_t & ne12,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb12,
-        constant  uint64_t & nb13,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint  tiisg[[thread_index_in_simdgroup]],
-        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+    mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
 kernel void kernel_mul_mv_q4_1_f32(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne10,
-        constant   int64_t & ne11,
-        constant   int64_t & ne12,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb12,
-        constant  uint64_t & nb13,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint tiisg[[thread_index_in_simdgroup]],
-        uint sgitg[[simdgroup_index_in_threadgroup]]) {
-     mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+     mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
 kernel void kernel_mul_mv_q5_0_f32(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne10,
-        constant   int64_t & ne11,
-        constant   int64_t & ne12,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb12,
-        constant  uint64_t & nb13,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint  tiisg[[thread_index_in_simdgroup]],
-        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+    mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
 kernel void kernel_mul_mv_q5_1_f32(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne10,
-        constant   int64_t & ne11,
-        constant   int64_t & ne12,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb12,
-        constant  uint64_t & nb13,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint  tiisg[[thread_index_in_simdgroup]],
-        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+    mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
-
 #define NB_Q8_0 8
 
+template<typename args_t>
 void kernel_mul_mv_q8_0_f32_impl(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-                   int64_t   ne00,
-                   int64_t   ne01,
-                   int64_t   ne02,
-                  uint64_t   nb01,
-                  uint64_t   nb02,
-                  uint64_t   nb03,
-                   int64_t   ne10,
-                   int64_t   ne12,
-                  uint64_t   nb11,
-                  uint64_t   nb12,
-                  uint64_t   nb13,
-                   int64_t   ne0,
-                   int64_t   ne1,
-                   uint      r2,
-                   uint      r3,
-        threadgroup int8_t * shared_values,
-                   uint3     tgpig,
-                   uint      tiisg,
-                   uint      sgitg) {
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiisg,
+        ushort sgitg) {
     const int nr  = N_DST;
     const int nsg = N_SIMDGROUP;
     const int nw  = N_SIMDWIDTH;
 
-    const int nb = ne00/QK8_0;
+    const int nb = args.ne00/QK8_0;
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
-    const int first_row = (r0 * nsg + sgitg) * nr;
+    const int first_row = (r0*nsg + sgitg)*nr;
 
-    const uint i12 = im%ne12;
-    const uint i13 = im/ne12;
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
 
-  //const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
-    const uint offset1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
+  //const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
 
-  //device const block_q8_0 * x = (device const block_q8_0 *) ((device char *) src0 + offset0);
-    device const float      * y = (device const float      *) ((device char *) src1 + offset1);
+  //device const block_q8_0 * x = (device const block_q8_0 *) (src0 + offset0);
+    device const float      * y = (device const float      *) (src1 + offset1);
 
     // pointers to src0 rows
     device const block_q8_0 * ax[nr];
     for (int row = 0; row < nr; ++row) {
-        const uint offset0 = (first_row + row)*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+        const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
 
         ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0);
     }
 
     float yl[NB_Q8_0];
-    float sumf[nr]={0.f};
+    float sumf[nr] = { 0.f };
 
-    const int ix = tiisg/4;
-    const int il = tiisg%4;
+    const short ix = tiisg/4;
+    const short il = tiisg%4;
 
-    device const float * yb = y + ix * QK8_0 + NB_Q8_0*il;
+    device const float * yb = y + ix*QK8_0 + il*NB_Q8_0;
 
     // each thread in a SIMD group deals with NB_Q8_0 quants at a time
     for (int ib = ix; ib < nb; ib += nw/4) {
-        for (int i = 0; i < NB_Q8_0; ++i) {
+        for (short i = 0; i < NB_Q8_0; ++i) {
             yl[i] = yb[i];
         }
 
         for (int row = 0; row < nr; row++) {
-            device const int8_t * qs = ax[row][ib].qs + NB_Q8_0*il;
+            device const int8_t * qs = ax[row][ib].qs + il*NB_Q8_0;
             float sumq = 0.f;
-            for (int iq = 0; iq < NB_Q8_0; ++iq) {
+            for (short iq = 0; iq < NB_Q8_0; ++iq) {
                 sumq += qs[iq] * yl[iq];
             }
             sumf[row] += sumq*ax[row][ib].d;
         }
 
-        yb += NB_Q8_0 * nw;
+        yb += nw*NB_Q8_0;
     }
 
+    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
     for (int row = 0; row < nr; ++row) {
         const float tot = simd_sum(sumf[row]);
-        if (tiisg == 0 && first_row + row < ne01) {
-            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
+
+        if (tiisg == 0 && first_row + row < args.ne01) {
+            dst_f32[first_row + row] = tot;
         }
     }
 }
 
 [[host_name("kernel_mul_mv_q8_0_f32")]]
 kernel void kernel_mul_mv_q8_0_f32(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne10,
-        constant   int64_t & ne11,
-        constant   int64_t & ne12,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb12,
-        constant  uint64_t & nb13,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint  tiisg[[thread_index_in_simdgroup]],
-        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
-    kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+    kernel_mul_mv_q8_0_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
 #define N_MV_T_T 4
 
-template<typename T0, typename T04, typename T1, typename T14>
+template<typename T0, typename T04, typename T1, typename T14, typename args_t>
 void kernel_mul_mv_impl(
-        device const  char * src0,
-        device const  char * src1,
-        device       float * dst,
-                   int64_t   ne00,
-                   int64_t   ne01,
-                   int64_t   ne02,
-                  uint64_t   nb00,
-                  uint64_t   nb01,
-                  uint64_t   nb02,
-                  uint64_t   nb03,
-                   int64_t   ne10,
-                   int64_t   ne11,
-                   int64_t   ne12,
-                  uint64_t   nb10,
-                  uint64_t   nb11,
-                  uint64_t   nb12,
-                  uint64_t   nb13,
-                   int64_t   ne0,
-                   int64_t   ne1,
-                   uint      r2,
-                   uint      r3,
-                   uint3     tgpig,
-                   uint      tiisg) {
-    const int64_t r0 = tgpig.x;
-    const int64_t rb = tgpig.y*N_MV_T_T;
-    const int64_t im = tgpig.z;
-
-    const uint i12 = im%ne12;
-    const uint i13 = im/ne12;
-
-    const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig,
+        ushort tiisg) {
+    const int r0 = tgpig.x;
+    const int rb = tgpig.y*N_MV_T_T;
+    const int im = tgpig.z;
+
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
+
+    const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
 
     device const T0 * x = (device const T0 *) (src0 + offset0);
 
-    if (ne00 < 128) {
+    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1;
+
+    if (args.ne00 < 128) {
         for (int row = 0; row < N_MV_T_T; ++row) {
             int r1 = rb + row;
-            if (r1 >= ne11) {
+            if (r1 >= args.ne11) {
                 break;
             }
 
-            const uint offset1 = r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
+            const uint64_t offset1 = r1*args.nb11 + (i12   )*args.nb12 + (i13   )*args.nb13;
 
             device const T1 * y = (device const T1 *) (src1 + offset1);
 
             float sumf = 0;
-            for (int i = tiisg; i < ne00; i += 32) {
+            for (int i = tiisg; i < args.ne00; i += 32) {
                 sumf += (T0) x[i] * (T1) y[i];
             }
 
             float all_sum = simd_sum(sumf);
             if (tiisg == 0) {
-                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+                dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum;
             }
         }
     } else {
         device const T04 * x4 = (device const T04 *) x;
         for (int row = 0; row < N_MV_T_T; ++row) {
             int r1 = rb + row;
-            if (r1 >= ne11) {
+            if (r1 >= args.ne11) {
                 break;
             }
 
-            const uint offset1 = r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
+            const uint64_t offset1 = r1*args.nb11 + (i12   )*args.nb12 + (i13   )*args.nb13;
 
             device const T1  * y  = (device const T1  *) (src1 + offset1);
             device const T14 * y4 = (device const T14 *) y;
 
             float sumf = 0;
-            for (int i = tiisg; i < ne00/4; i += 32) {
-                for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]);
+            for (int i = tiisg; i < args.ne00/4; i += 32) {
+                sumf += dot((float4) x4[i], (float4) y4[i]);
             }
 
             float all_sum = simd_sum(sumf);
             if (tiisg == 0) {
-                for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) (x[i] * y[i]);
-                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+                for (int i = 4*(args.ne00/4); i < args.ne00; ++i) all_sum += (float) (x[i] * y[i]);
+                dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum;
             }
         }
     }
@@ -2033,51 +1817,17 @@ void kernel_mul_mv_impl(
 
 template<typename T0, typename T04, typename T1, typename T14>
 kernel void kernel_mul_mv(
-        device const  char * src0,
-        device const  char * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne10,
-        constant   int64_t & ne11,
-        constant   int64_t & ne12,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb12,
-        constant  uint64_t & nb13,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint  tiisg[[thread_index_in_simdgroup]]) {
-    kernel_mul_mv_impl<T0, T04, T1, T14>(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]]) {
+    kernel_mul_mv_impl<T0, T04, T1, T14, constant ggml_metal_kargs_mul_mv &>(
+        args,
         src0,
         src1,
         dst,
-        ne00,
-        ne01,
-        ne02,
-        nb00,
-        nb01,
-        nb02,
-        nb03,
-        ne10,
-        ne11,
-        ne12,
-        nb10,
-        nb11,
-        nb12,
-        nb13,
-        ne0,
-        ne1,
-        r2,
-        r3,
         tgpig,
         tiisg);
 }
@@ -2094,65 +1844,50 @@ template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t kernel_mul_mv<
 
 template<typename T, typename T4>
 kernel void kernel_mul_mv_1row(
-        device const  char * src0,
-        device const  char * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne10,
-        constant   int64_t & ne11,
-        constant   int64_t & ne12,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb12,
-        constant  uint64_t & nb13,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint  tiisg[[thread_index_in_simdgroup]]) {
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]]) {
 
-    const int64_t r0 = tgpig.x;
-    const int64_t r1 = tgpig.y;
-    const int64_t im = tgpig.z;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
 
-    const uint i12 = im%ne12;
-    const uint i13 = im/ne12;
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
 
-    const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
-    const uint offset1 = r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
+    const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 = r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
 
     device const T     * x = (device const T     *) (src0 + offset0);
     device const float * y = (device const float *) (src1 + offset1);
 
+    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
     float sumf = 0;
-    if (ne00 < 128) {
-        for (int i = tiisg; i < ne00; i += 32) {
+    if (args.ne00 < 128) {
+        for (int i = tiisg; i < args.ne00; i += 32) {
             sumf += (float) x[i] * (float) y[i];
         }
         float all_sum = simd_sum(sumf);
         if (tiisg == 0) {
-            dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            dst_f32[r0] = all_sum;
         }
     } else {
         device const T4     * x4 = (device const T4     *) x;
         device const float4 * y4 = (device const float4 *) y;
 
-        for (int i = tiisg; i < ne00/4; i += 32) {
-            for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]);
+        for (int i = tiisg; i < args.ne00/4; i += 32) {
+            sumf += dot((float4) x4[i], y4[i]);
         }
 
         float all_sum = simd_sum(sumf);
 
         if (tiisg == 0) {
-            for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) (x[i] * y[i]);
-            dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            for (int i = 4*(args.ne00/4); i < args.ne00; ++i) all_sum += (float) (x[i] * y[i]);
+            dst_f32[r0] = all_sum;
         }
     }
 }
@@ -2167,54 +1902,39 @@ template [[host_name("kernel_mul_mv_bf16_f32_1row")]] kernel mul_mv_1row_t kerne
 // Assumes row size (ne00) is a multiple of 4
 template<typename T, typename T4>
 kernel void kernel_mul_mv_l4(
-        device const  char * src0,
-        device const  char * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne10,
-        constant   int64_t & ne11,
-        constant   int64_t & ne12,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb12,
-        constant  uint64_t & nb13,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint tiisg[[thread_index_in_simdgroup]]) {
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]]) {
 
-    const int nrows = ne11;
-    const int64_t r0 = tgpig.x;
-    const int64_t im = tgpig.z;
+    const int nrows = args.ne11;
+    const int r0 = tgpig.x;
+    const int im = tgpig.z;
 
-    const uint i12 = im%ne12;
-    const uint i13 = im/ne12;
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
 
-    const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+    const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
 
     device const T4 * x4 = (device const T4 *) (src0 + offset0);
 
+    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1;
+
     for (int r1 = 0; r1 < nrows; ++r1) {
-        const uint offset1 = r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
+        const uint64_t offset1 = r1*args.nb11 + (i12   )*args.nb12 + (i13   )*args.nb13;
 
         device const float4 * y4 = (device const float4 *) (src1 + offset1);
 
         float sumf = 0;
-        for (int i = tiisg; i < ne00/4; i += 32) {
-            for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]);
+        for (int i = tiisg; i < args.ne00/4; i += 32) {
+            sumf += dot((float4) x4[i], y4[i]);
         }
 
         float all_sum = simd_sum(sumf);
         if (tiisg == 0) {
-            dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum;
         }
     }
 }
@@ -2234,7 +1954,7 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) {
 // YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
 // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
 static void rope_yarn(
-    float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
+    float theta_extrap, float freq_scale, float corr_dims[2], int i0, float ext_factor, float mscale,
     thread float * cos_theta, thread float * sin_theta) {
     // Get n-d rotational scaling corrected for extrapolation
     float theta_interp = freq_scale * theta_extrap;
@@ -2266,65 +1986,41 @@ static void rope_yarn_corr_dims(
 
 template<typename T>
 kernel void kernel_rope_norm(
-        device const    void * src0,
-        device const int32_t * src1,
-        device const   float * src2,
-        device         float * dst,
-        constant     int64_t & ne00,
-        constant     int64_t & ne01,
-        constant     int64_t & ne02,
-        constant     int64_t & ne03,
-        constant    uint64_t & nb00,
-        constant    uint64_t & nb01,
-        constant    uint64_t & nb02,
-        constant    uint64_t & nb03,
-        constant     int64_t & ne0,
-        constant     int64_t & ne1,
-        constant     int64_t & ne2,
-        constant     int64_t & ne3,
-        constant    uint64_t & nb0,
-        constant    uint64_t & nb1,
-        constant    uint64_t & nb2,
-        constant    uint64_t & nb3,
-        constant         int & n_past,
-        constant         int & n_dims,
-        constant         int & n_ctx_orig,
-        constant       float & freq_base,
-        constant       float & freq_scale,
-        constant       float & ext_factor,
-        constant       float & attn_factor,
-        constant       float & beta_fast,
-        constant       float & beta_slow,
-        uint  tiitg[[thread_index_in_threadgroup]],
-        uint3 tptg[[threads_per_threadgroup]],
-        uint3 tgpig[[threadgroup_position_in_grid]]) {
-    const int64_t i3 = tgpig[2];
-    const int64_t i2 = tgpig[1];
-    const int64_t i1 = tgpig[0];
+        constant ggml_metal_kargs_rope & args,
+        device const char * src0,
+        device const char * src1,
+        device const char * src2,
+        device       char * dst,
+        ushort  tiitg[[thread_index_in_threadgroup]],
+        ushort3 tptg [[threads_per_threadgroup]],
+        uint3   tgpig[[threadgroup_position_in_grid]]) {
+    const int i3 = tgpig[2];
+    const int i2 = tgpig[1];
+    const int i1 = tgpig[0];
 
     float corr_dims[2];
-    rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
+    rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
 
-    device const int32_t * pos = src1;
+    device const int32_t * pos = (device const int32_t *) src1;
 
     const float theta_base = (float) pos[i2];
-    const float inv_ndims = -1.f/n_dims;
+    const float inv_ndims = -1.f/args.n_dims;
 
     float cos_theta;
     float sin_theta;
 
-    for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
-        if (i0 < n_dims) {
-            const int64_t ic = i0/2;
+    for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
+        if (i0 < args.n_dims) {
+            const int ic = i0/2;
 
-            const float theta = theta_base * pow(freq_base, inv_ndims*i0);
+            const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
 
-            const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
+            const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
 
-            rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+            rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
 
-            device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-            device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+            device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
+            device       T * dst_data  = (device T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + i0*args.nb0);
 
             const float x0 = src[0];
             const float x1 = src[1];
@@ -2332,8 +2028,8 @@ kernel void kernel_rope_norm(
             dst_data[0] = x0*cos_theta - x1*sin_theta;
             dst_data[1] = x0*sin_theta + x1*cos_theta;
         } else {
-            device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-            device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+            device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
+            device       T * dst_data  = (device T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + i0*args.nb0);
 
             dst_data[0] = src[0];
             dst_data[1] = src[1];
@@ -2343,74 +2039,50 @@ kernel void kernel_rope_norm(
 
 template<typename T>
 kernel void kernel_rope_neox(
-        device const    void * src0,
-        device const int32_t * src1,
-        device const   float * src2,
-        device         float * dst,
-        constant     int64_t & ne00,
-        constant     int64_t & ne01,
-        constant     int64_t & ne02,
-        constant     int64_t & ne03,
-        constant    uint64_t & nb00,
-        constant    uint64_t & nb01,
-        constant    uint64_t & nb02,
-        constant    uint64_t & nb03,
-        constant     int64_t & ne0,
-        constant     int64_t & ne1,
-        constant     int64_t & ne2,
-        constant     int64_t & ne3,
-        constant    uint64_t & nb0,
-        constant    uint64_t & nb1,
-        constant    uint64_t & nb2,
-        constant    uint64_t & nb3,
-        constant         int & n_past,
-        constant         int & n_dims,
-        constant         int & n_ctx_orig,
-        constant       float & freq_base,
-        constant       float & freq_scale,
-        constant       float & ext_factor,
-        constant       float & attn_factor,
-        constant       float & beta_fast,
-        constant       float & beta_slow,
-        uint  tiitg[[thread_index_in_threadgroup]],
-        uint3 tptg[[threads_per_threadgroup]],
-        uint3 tgpig[[threadgroup_position_in_grid]]) {
-    const int64_t i3 = tgpig[2];
-    const int64_t i2 = tgpig[1];
-    const int64_t i1 = tgpig[0];
+        constant ggml_metal_kargs_rope & args,
+        device const char * src0,
+        device const char * src1,
+        device const char * src2,
+        device       char * dst,
+        ushort  tiitg[[thread_index_in_threadgroup]],
+        ushort3 tptg [[threads_per_threadgroup]],
+        uint3   tgpig[[threadgroup_position_in_grid]]) {
+    const int i3 = tgpig[2];
+    const int i2 = tgpig[1];
+    const int i1 = tgpig[0];
 
     float corr_dims[2];
-    rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
+    rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
 
-    device const int32_t * pos = src1;
+    device const int32_t * pos = (device const int32_t *) src1;
 
     const float theta_base = (float) pos[i2];
-    const float inv_ndims = -1.f/n_dims;
+    const float inv_ndims = -1.f/args.n_dims;
 
     float cos_theta;
     float sin_theta;
 
-    for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
-        if (i0 < n_dims) {
-            const int64_t ic = i0/2;
+    for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
+        if (i0 < args.n_dims) {
+            const int ic = i0/2;
 
-            const float theta = theta_base * pow(freq_base, inv_ndims*i0);
+            const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
 
-            const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
+            const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
 
-            rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+            rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
 
-            device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
-            device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + ic*nb0);
+            device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
+            device       T * dst_data  = (device T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + ic*args.nb0);
 
             const float x0 = src[0];
-            const float x1 = src[n_dims/2];
+            const float x1 = src[args.n_dims/2];
 
-            dst_data[0]        = x0*cos_theta - x1*sin_theta;
-            dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
+            dst_data[0]             = x0*cos_theta - x1*sin_theta;
+            dst_data[args.n_dims/2] = x0*sin_theta + x1*cos_theta;
         } else {
-            device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
-            device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0);
+            device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
+            device       T * dst_data  = (device T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + i0*args.nb0);
 
             dst_data[0] = src[0];
             dst_data[1] = src[1];
@@ -2808,37 +2480,17 @@ template<
     short KV = 8,    // key/value processed per each simdgroup
     short C  = 32>   // cache items per threadgroup
 kernel void kernel_flash_attn_ext(
-        device const  char * q,
-        device const  char * k,
-        device const  char * v,
-        device const  char * mask,
-        device       float * dst,
-        constant   int32_t & ne01,
-        constant   int32_t & ne02,
-        constant   int32_t & ne03,
-        constant  uint32_t & nb01,
-        constant  uint32_t & nb02,
-        constant  uint32_t & nb03,
-        constant   int32_t & ne11,
-        constant   int32_t & ne_12_2, // assume K and V are same shape
-        constant   int32_t & ne_12_3,
-        constant  uint32_t & nb_12_1,
-        constant  uint32_t & nb_12_2,
-        constant  uint32_t & nb_12_3,
-        constant  uint32_t & nb31,
-        constant   int32_t & ne1,
-        constant   int32_t & ne2,
-        constant     float & scale,
-        constant     float & max_bias,
-        constant     float & m0,
-        constant     float & m1,
-        constant  uint16_t & n_head_log2,
-        constant     float & logit_softcap,
-        threadgroup   half * shared [[threadgroup(0)]],
-        ushort3  tgpig[[threadgroup_position_in_grid]],
-        ushort3    ntg[[threads_per_threadgroup]],
-        ushort   tiisg[[thread_index_in_simdgroup]],
-        ushort   sgitg[[simdgroup_index_in_threadgroup]]) {
+        constant ggml_metal_kargs_flash_attn_ext & args,
+        device const char * q,
+        device const char * k,
+        device const char * v,
+        device const char * mask,
+        device       char * dst,
+        threadgroup  half * shmem_f16 [[threadgroup(0)]],
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3   ntg[[threads_per_threadgroup]],
+        ushort  tiisg[[thread_index_in_simdgroup]],
+        ushort  sgitg[[simdgroup_index_in_threadgroup]]) {
     const short nsg = ntg.y; // number of simdgroups
 
     const int iq3 = tgpig[2];
@@ -2854,27 +2506,27 @@ kernel void kernel_flash_attn_ext(
     const short TS = nsg*SH;   // shared memory size per query in (s_t == float)
     const short T  = D + 2*TS; // shared memory size per query in (half)
 
-    threadgroup q_t  * sq  = (threadgroup q_t  *) (shared +              0*D); // holds the query data
-    threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared +              0*D); // same as above but in q4_t
-    threadgroup o_t  * so  = (threadgroup o_t  *) (shared +              0*D); // reuse query data for accumulation
-    threadgroup o4_t * so4 = (threadgroup o4_t *) (shared +              0*D); // same as above but in o4_t
-    threadgroup s_t  * ss  = (threadgroup s_t  *) (shared + 2*sgitg*SH + Q*D); // scratch buffer for attention, mask and diagonal matrix
+    threadgroup q_t  * sq  = (threadgroup q_t  *) (shmem_f16 +              0*D); // holds the query data
+    threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 +              0*D); // same as above but in q4_t
+    threadgroup o_t  * so  = (threadgroup o_t  *) (shmem_f16 +              0*D); // reuse query data for accumulation
+    threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 +              0*D); // same as above but in o4_t
+    threadgroup s_t  * ss  = (threadgroup s_t  *) (shmem_f16 + 2*sgitg*SH + Q*D); // scratch buffer for attention, mask and diagonal matrix
 
-    threadgroup k_t    * sk    = (threadgroup k_t    *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
-    threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t
+    threadgroup k_t    * sk    = (threadgroup k_t    *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
+    threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t
 
-    threadgroup v_t    * sv    = (threadgroup v_t    *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load V in shared memory
-    threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in v4x4_t
+    threadgroup v_t    * sv    = (threadgroup v_t    *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load V in shared memory
+    threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in v4x4_t
 
     // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
     o8x8_t lo[D8];
 
     // load heads from Q to shared memory
     for (short j = sgitg; j < Q; j += nsg) {
-        device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
+        device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*args.nb01 + iq2*args.nb02 + iq3*args.nb03));
 
         for (short i = tiisg; i < D4; i += NW) {
-            if (iq1 + j < ne01) {
+            if (iq1 + j < args.ne01) {
                 sq4[j*D4 + i] = (q4_t) q4[i];
             } else {
                 sq4[j*D4 + i] = (q4_t) 0.0f;
@@ -2907,11 +2559,11 @@ kernel void kernel_flash_attn_ext(
         const short ty = tiisg/4;
 
         // broadcast kv
-        //const short rk2 = ne02/ne12;
-        //const short rk3 = ne03/ne13;
+        //const short rk2 = args.ne02/args.ne12;
+        //const short rk3 = args.ne03/args.ne13;
 
-        const short ikv2 = iq2/(ne02/ne_12_2);
-        const short ikv3 = iq3/(ne03/ne_12_3);
+        const short ikv2 = iq2/(args.ne02/args.ne_12_2);
+        const short ikv3 = iq3/(args.ne03/args.ne_12_3);
 
         // load the queries from shared memory into local memory
         q8x8_t mq[D8];
@@ -2925,20 +2577,20 @@ kernel void kernel_flash_attn_ext(
         half slope = 1.0f;
 
         // ALiBi
-        if (max_bias > 0.0f) {
+        if (args.max_bias > 0.0f) {
             const short h = iq2;
 
-            const half  base = h < n_head_log2 ? m0 : m1;
-            const short exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+            const half  base = h < args.n_head_log2 ? args.m0 : args.m1;
+            const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
 
             slope = pow(base, exph);
         }
 
         // loop over the KV cache
         // each simdgroup handles blocks of Q rows and C columns
-        for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
+        for (int ic0 = 0; ic0 < args.ne11; ic0 += C*nsg) {
             const int ic = ic0 + C*sgitg;
-            if (ic >= ne11) {
+            if (ic >= args.ne11) {
                 break;
             }
 
@@ -2949,7 +2601,7 @@ kernel void kernel_flash_attn_ext(
                 // load the mask in shared memory
                 #pragma unroll(Q)
                 for (short j = 0; j < Q; ++j) {
-                    device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*nb31);
+                    device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31);
 
                     const half m = pm[ic + tiisg];
 
@@ -2972,18 +2624,18 @@ kernel void kernel_flash_attn_ext(
                     // this is compile-time check, so it does not have runtime overhead
                     if (is_same<kd4x4_t, k4x4_t>::value) {
                         // we can read directly from global memory
-                        device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
+                        device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
 
                         #pragma unroll(D8)
                         for (short i = 0; i < D8; ++i) {
                             k8x8_t mk;
-                            simdgroup_load(mk, pk + i*8, nb_12_1/sizeof(k_t), 0, true); // transpose // TODO: use ne10
+                            simdgroup_load(mk, pk + i*8, args.nb_12_1/sizeof(k_t), 0, true); // transpose // TODO: use ne10
 
                             simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
                         }
                     } else {
                         for (short ii = 0; ii < D16; ii += 4) {
-                            device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
+                            device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
 
                             if (D16%4 == 0) {
                                 // the head is evenly divisible by 4*16 = 64, so no need for bound checks
@@ -3042,10 +2694,10 @@ kernel void kernel_flash_attn_ext(
                     const half m = M[j];
 
                     // scale and apply the logitcap / mask
-                    half s = ss[j*TS + tiisg]*scale;
+                    half s = ss[j*TS + tiisg]*args.scale;
 
-                    if (logit_softcap != 0.0f) {
-                        s = logit_softcap*precise::tanh(s);
+                    if (args.logit_softcap != 0.0f) {
+                        s = args.logit_softcap*precise::tanh(s);
                     }
 
                     // mqk = mqk + mask*slope
@@ -3087,18 +2739,18 @@ kernel void kernel_flash_attn_ext(
 
                     if (is_same<vd4x4_t, v4x4_t>::value) {
                         // we can read directly from global memory
-                        device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
+                        device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
 
                         #pragma unroll(D8)
                         for (short i = 0; i < D8; ++i) {
                             v8x8_t mv;
-                            simdgroup_load(mv, pv + i*8, nb_12_1/sizeof(v_t), 0, false); // TODO: use ne20
+                            simdgroup_load(mv, pv + i*8, args.nb_12_1/sizeof(v_t), 0, false); // TODO: use ne20
 
                             simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]);
                         }
                     } else {
                         for (short ii = 0; ii < D16; ii += 4) {
-                            device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
+                            device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
 
                             if (D16%4 == 0) {
                                 // no need for bound checks
@@ -3227,11 +2879,11 @@ kernel void kernel_flash_attn_ext(
 
     // final rescale with 1/S and store to global memory
     if (sgitg == 0) {
-        for (short j = 0; j < Q && iq1 + j < ne01; ++j) {
+        for (short j = 0; j < Q && iq1 + j < args.ne01; ++j) {
             const float S = ss[j*TS + 0];
 
             for (short i = tiisg; i < D4; i += NW) {
-                dst4[((int64_t)iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) so4[j*D4 + i]/S;
+                dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*D4 + i] = (float4) so4[j*D4 + i]/S;
             }
         }
     }
@@ -3323,38 +2975,17 @@ template<
     short Q  = 1,    // queries per threadgroup
     short C  = 32>   // cache items per threadgroup
 kernel void kernel_flash_attn_ext_vec(
-        device const  char * q,
-        device const  char * k,
-        device const  char * v,
-        device const  char * mask,
-        device       float * dst,
-        constant   int32_t & ne01,
-        constant   int32_t & ne02,
-        constant   int32_t & ne03,
-        constant  uint32_t & nb01,
-        constant  uint32_t & nb02,
-        constant  uint32_t & nb03,
-        constant   int32_t & ne11,
-        constant   int32_t & ne_12_2, // assume K and V are same shape
-        constant   int32_t & ne_12_3,
-        constant  uint32_t & nb_12_1,
-        constant  uint32_t & nb_12_2,
-        constant  uint32_t & nb_12_3,
-        constant  uint32_t & nb31,
-        constant   int32_t & ne1,
-        constant   int32_t & ne2,
-        constant     float & scale,
-        constant     float & max_bias,
-        constant     float & m0,
-        constant     float & m1,
-        constant  uint16_t & n_head_log2,
-        constant     float & logit_softcap,
-        threadgroup   half * shared [[threadgroup(0)]],
-        ushort3  tgpig[[threadgroup_position_in_grid]],
-        ushort3  tpitg[[thread_position_in_threadgroup]],
-        ushort3    ntg[[threads_per_threadgroup]],
-        ushort   tiisg[[thread_index_in_simdgroup]],
-        ushort   sgitg[[simdgroup_index_in_threadgroup]]) {
+        constant ggml_metal_kargs_flash_attn_ext & args,
+        device const char * q,
+        device const char * k,
+        device const char * v,
+        device const char * mask,
+        device       char * dst,
+        threadgroup  half * shmem_f16 [[threadgroup(0)]],
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3   ntg[[threads_per_threadgroup]],
+        ushort  tiisg[[thread_index_in_simdgroup]],
+        ushort  sgitg[[simdgroup_index_in_threadgroup]]) {
     const short nsg = ntg.y; // number of simdgroups
 
     const int iq3 = tgpig[2];
@@ -3369,22 +3000,22 @@ kernel void kernel_flash_attn_ext_vec(
 
     const short T = D + nsg*SH; // shared memory size per query in (half)
 
-  //threadgroup q_t    * sq    = (threadgroup q_t    *) (shared +                0*D); // holds the query data
-    threadgroup q4_t   * sq4   = (threadgroup q4_t   *) (shared +                0*D); // same as above but in q4_t
-    threadgroup q4x4_t * sq4x4 = (threadgroup q4x4_t *) (shared +                0*D); // same as above but in q4x4_t
-    threadgroup s_t    * ss    = (threadgroup s_t    *) (shared + sgitg*SH     + Q*D); // scratch buffer for attention
-    threadgroup s4_t   * ss4   = (threadgroup s4_t   *) (shared + sgitg*SH     + Q*D); // same as above but in s4_t
-    threadgroup half   * sm    = (threadgroup half   *) (shared + sgitg*SH + C + Q*D); // scratch buffer for mask
-    threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shared + sgitg*D      + Q*T); // scratch buffer for the results
+  //threadgroup q_t    * sq    = (threadgroup q_t    *) (shmem_f16 +                0*D); // holds the query data
+    threadgroup q4_t   * sq4   = (threadgroup q4_t   *) (shmem_f16 +                0*D); // same as above but in q4_t
+    threadgroup q4x4_t * sq4x4 = (threadgroup q4x4_t *) (shmem_f16 +                0*D); // same as above but in q4x4_t
+    threadgroup s_t    * ss    = (threadgroup s_t    *) (shmem_f16 + sgitg*SH     + Q*D); // scratch buffer for attention
+    threadgroup s4_t   * ss4   = (threadgroup s4_t   *) (shmem_f16 + sgitg*SH     + Q*D); // same as above but in s4_t
+    threadgroup half   * sm    = (threadgroup half   *) (shmem_f16 + sgitg*SH + C + Q*D); // scratch buffer for mask
+    threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shmem_f16 + sgitg*D      + Q*T); // scratch buffer for the results
 
     // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
     o4x4_t lo[D16/NL];
 
     // load heads from Q to shared memory
-    device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03));
+    device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03));
 
     for (short i = tiisg; i < D4; i += NW) {
-        if (iq1 < ne01) {
+        if (iq1 < args.ne01) {
             sq4[i] = (q4_t) q4[i];
         } else {
             sq4[i] = (q4_t) 0.0f;
@@ -3412,11 +3043,11 @@ kernel void kernel_flash_attn_ext_vec(
         const short ty = tiisg/NL;
 
         // broadcast kv
-        //const short rk2 = ne02/ne12;
-        //const short rk3 = ne03/ne13;
+        //const short rk2 = args.ne02/args.ne12;
+        //const short rk3 = args.ne03/args.ne13;
 
-        const short ikv2 = iq2/(ne02/ne_12_2);
-        const short ikv3 = iq3/(ne03/ne_12_3);
+        const short ikv2 = iq2/(args.ne02/args.ne_12_2);
+        const short ikv3 = iq3/(args.ne03/args.ne_12_3);
 
         // load the queries from shared memory into local memory
         q4x4_t mq[D16/NL];
@@ -3429,25 +3060,25 @@ kernel void kernel_flash_attn_ext_vec(
         const bool has_mask = mask != q;
 
         // pointer to the mask
-        device const half * pm = (device const half *) (mask + iq1*nb31);
+        device const half * pm = (device const half *) (mask + iq1*args.nb31);
 
         half slope = 1.0f;
 
         // ALiBi
-        if (max_bias > 0.0f) {
+        if (args.max_bias > 0.0f) {
             const short h = iq2;
 
-            const half  base = h < n_head_log2 ? m0 : m1;
-            const short exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+            const half  base = h < args.n_head_log2 ? args.m0 : args.m1;
+            const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
 
             slope = pow(base, exph);
         }
 
         // loop over the KV cache
         // each simdgroup handles blocks of Q rows and C columns
-        for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
+        for (int ic0 = 0; ic0 < args.ne11; ic0 += C*nsg) {
             const int ic = ic0 + C*sgitg;
-            if (ic >= ne11) {
+            if (ic >= args.ne11) {
                 break;
             }
 
@@ -3461,7 +3092,7 @@ kernel void kernel_flash_attn_ext_vec(
                 for (short cc = 0; cc < C/4; ++cc) {
                     qk_t mqka[4] = { 0.0, 0.0, 0.0, 0.0 };
 
-                    device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
+                    device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
 
                     #pragma unroll(D16/NL)
                     for (short ii = 0; ii < D16; ii += NL) {
@@ -3497,10 +3128,10 @@ kernel void kernel_flash_attn_ext_vec(
 
                     // mqk = mqk*scale + mask*slope
                     if (tx == 0) {
-                        mqk *= scale;
+                        mqk *= args.scale;
 
-                        if (logit_softcap != 0.0f) {
-                            mqk = logit_softcap*precise::tanh(mqk);
+                        if (args.logit_softcap != 0.0f) {
+                            mqk = args.logit_softcap*precise::tanh(mqk);
                         }
 
                         mqk += sm[4*cc + ty]*slope;
@@ -3539,7 +3170,7 @@ kernel void kernel_flash_attn_ext_vec(
             // O = O + (Q*K^T)*V
             {
                 for (short cc = 0; cc < C/4; ++cc) {
-                    device const vd4x4_t * pv4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
+                    device const vd4x4_t * pv4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 4*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
 
                     const s4x4_t ms(ss[4*cc + ty]);
 
@@ -3644,7 +3275,7 @@ kernel void kernel_flash_attn_ext_vec(
         const float S = ss[0];
 
         for (short i = tiisg; i < D16; i += NW) {
-            dst44[((int64_t)iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D16 + i] = (float4x4) sr4x4[i]/S;
+            dst44[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)iq1*args.ne1)*D16 + i] = (float4x4) sr4x4[i]/S;
         }
     }
 }
@@ -3686,42 +3317,27 @@ template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_
 
 template<typename T0, typename T1>
 kernel void kernel_cpy(
-        device  const void * src0,
-        device        void * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant   int64_t & ne03,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   int64_t & ne2,
-        constant   int64_t & ne3,
-        constant  uint64_t & nb0,
-        constant  uint64_t & nb1,
-        constant  uint64_t & nb2,
-        constant  uint64_t & nb3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3   ntg[[threads_per_threadgroup]]) {
-    const int64_t i03 = tgpig[2];
-    const int64_t i02 = tgpig[1];
-    const int64_t i01 = tgpig[0];
+        constant ggml_metal_kargs_cpy & args,
+        device  const char * src0,
+        device        char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int i03 = tgpig[2];
+    const int i02 = tgpig[1];
+    const int i01 = tgpig[0];
 
-    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+    const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
 
-    const int64_t i3 = n / (ne2*ne1*ne0);
-    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
-    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
-    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+    const int64_t i3 = n/(args.ne2*args.ne1*args.ne0);
+    const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
+    const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
+    const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
 
-    device T1 * dst_data = (device T1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+    device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
 
-    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
-        device const T0 * src = (device T0 *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+    for (int64_t i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
+        device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
         dst_data[i00] = (T1) src[0];
     }
 }
@@ -3741,42 +3357,27 @@ template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy<bf
 #endif
 
 kernel void kernel_cpy_f32_q8_0(
-        device const float * src0,
-        device        void * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant   int64_t & ne03,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   int64_t & ne2,
-        constant   int64_t & ne3,
-        constant  uint64_t & nb0,
-        constant  uint64_t & nb1,
-        constant  uint64_t & nb2,
-        constant  uint64_t & nb3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3   ntg[[threads_per_threadgroup]]) {
-    const int64_t i03 = tgpig[2];
-    const int64_t i02 = tgpig[1];
-    const int64_t i01 = tgpig[0];
+        constant ggml_metal_kargs_cpy & args,
+        device const char * src0,
+        device       char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int i03 = tgpig[2];
+    const int i02 = tgpig[1];
+    const int i01 = tgpig[0];
 
-    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+    const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
 
-    const int64_t i3 = n / (ne2*ne1*ne0);
-    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
-    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
-    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK8_0;
+    const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
+    const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
+    const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
+    const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK8_0;
 
-    device block_q8_0 * dst_data = (device block_q8_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+    device block_q8_0 * dst_data = (device block_q8_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
 
-    for (int64_t i00 = tpitg.x*QK8_0; i00 < ne00; i00 += ntg.x*QK8_0) {
-        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+    for (int64_t i00 = tpitg.x*QK8_0; i00 < args.ne00; i00 += ntg.x*QK8_0) {
+        device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
 
         float amax = 0.0f; // absolute max
 
@@ -3799,42 +3400,27 @@ kernel void kernel_cpy_f32_q8_0(
 }
 
 kernel void kernel_cpy_f32_q4_0(
-        device const float * src0,
-        device        void * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant   int64_t & ne03,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   int64_t & ne2,
-        constant   int64_t & ne3,
-        constant  uint64_t & nb0,
-        constant  uint64_t & nb1,
-        constant  uint64_t & nb2,
-        constant  uint64_t & nb3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3   ntg[[threads_per_threadgroup]]) {
-    const int64_t i03 = tgpig[2];
-    const int64_t i02 = tgpig[1];
-    const int64_t i01 = tgpig[0];
+        constant ggml_metal_kargs_cpy & args,
+        device const char * src0,
+        device       char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int i03 = tgpig[2];
+    const int i02 = tgpig[1];
+    const int i01 = tgpig[0];
 
-    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+    const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
 
-    const int64_t i3 = n / (ne2*ne1*ne0);
-    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
-    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
-    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_0;
+    const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
+    const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
+    const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
+    const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_0;
 
-    device block_q4_0 * dst_data = (device block_q4_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+    device block_q4_0 * dst_data = (device block_q4_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
 
-    for (int64_t i00 = tpitg.x*QK4_0; i00 < ne00; i00 += ntg.x*QK4_0) {
-        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+    for (int64_t i00 = tpitg.x*QK4_0; i00 < args.ne00; i00 += ntg.x*QK4_0) {
+        device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
 
         float amax = 0.0f; // absolute max
         float max  = 0.0f;
@@ -3866,42 +3452,27 @@ kernel void kernel_cpy_f32_q4_0(
 }
 
 kernel void kernel_cpy_f32_q4_1(
-        device const float * src0,
-        device        void * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant   int64_t & ne03,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   int64_t & ne2,
-        constant   int64_t & ne3,
-        constant  uint64_t & nb0,
-        constant  uint64_t & nb1,
-        constant  uint64_t & nb2,
-        constant  uint64_t & nb3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3   ntg[[threads_per_threadgroup]]) {
-    const int64_t i03 = tgpig[2];
-    const int64_t i02 = tgpig[1];
-    const int64_t i01 = tgpig[0];
+        constant ggml_metal_kargs_cpy & args,
+        device const char * src0,
+        device       char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int i03 = tgpig[2];
+    const int i02 = tgpig[1];
+    const int i01 = tgpig[0];
 
-    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+    const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
 
-    const int64_t i3 = n / (ne2*ne1*ne0);
-    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
-    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
-    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_1;
+    const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
+    const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
+    const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
+    const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_1;
 
-    device block_q4_1 * dst_data = (device block_q4_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+    device block_q4_1 * dst_data = (device block_q4_1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
 
-    for (int64_t i00 = tpitg.x*QK4_1; i00 < ne00; i00 += ntg.x*QK4_1) {
-        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+    for (int64_t i00 = tpitg.x*QK4_1; i00 < args.ne00; i00 += ntg.x*QK4_1) {
+        device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
 
         float min = FLT_MAX;
         float max = -FLT_MAX;
@@ -3932,42 +3503,27 @@ kernel void kernel_cpy_f32_q4_1(
 }
 
 kernel void kernel_cpy_f32_q5_0(
-        device const float * src0,
-        device        void * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant   int64_t & ne03,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   int64_t & ne2,
-        constant   int64_t & ne3,
-        constant  uint64_t & nb0,
-        constant  uint64_t & nb1,
-        constant  uint64_t & nb2,
-        constant  uint64_t & nb3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3   ntg[[threads_per_threadgroup]]) {
-    const int64_t i03 = tgpig[2];
-    const int64_t i02 = tgpig[1];
-    const int64_t i01 = tgpig[0];
+        constant ggml_metal_kargs_cpy & args,
+        device const char * src0,
+        device       char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int i03 = tgpig[2];
+    const int i02 = tgpig[1];
+    const int i01 = tgpig[0];
 
-    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+    const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
 
-    const int64_t i3 = n / (ne2*ne1*ne0);
-    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
-    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
-    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK5_0;
+    const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
+    const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
+    const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
+    const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK5_0;
 
-    device block_q5_0 * dst_data = (device block_q5_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+    device block_q5_0 * dst_data = (device block_q5_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
 
-    for (int64_t i00 = tpitg.x*QK5_0; i00 < ne00; i00 += ntg.x*QK5_0) {
-        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+    for (int64_t i00 = tpitg.x*QK5_0; i00 < args.ne00; i00 += ntg.x*QK5_0) {
+        device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
 
         float amax = 0.0f; // absolute max
         float max  = 0.0f;
@@ -4005,42 +3561,27 @@ kernel void kernel_cpy_f32_q5_0(
 }
 
 kernel void kernel_cpy_f32_q5_1(
-        device const float * src0,
-        device        void * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant   int64_t & ne03,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   int64_t & ne2,
-        constant   int64_t & ne3,
-        constant  uint64_t & nb0,
-        constant  uint64_t & nb1,
-        constant  uint64_t & nb2,
-        constant  uint64_t & nb3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3   ntg[[threads_per_threadgroup]]) {
-    const int64_t i03 = tgpig[2];
-    const int64_t i02 = tgpig[1];
-    const int64_t i01 = tgpig[0];
+        constant ggml_metal_kargs_cpy & args,
+        device const char * src0,
+        device       char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int i03 = tgpig[2];
+    const int i02 = tgpig[1];
+    const int i01 = tgpig[0];
 
-    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+    const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
 
-    const int64_t i3 = n / (ne2*ne1*ne0);
-    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
-    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
-    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK5_1;
+    const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
+    const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
+    const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
+    const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK5_1;
 
-    device block_q5_1 * dst_data = (device block_q5_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+    device block_q5_1 * dst_data = (device block_q5_1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
 
-    for (int64_t i00 = tpitg.x*QK5_1; i00 < ne00; i00 += ntg.x*QK5_1) {
-        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+    for (int64_t i00 = tpitg.x*QK5_1; i00 < args.ne00; i00 += ntg.x*QK5_1) {
+        device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
 
         float max = src[0];
         float min = src[0];
@@ -4088,42 +3629,27 @@ static inline int best_index_int8(int n, constant float * val, float x) {
 }
 
 kernel void kernel_cpy_f32_iq4_nl(
-        device const float * src0,
-        device        void * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant   int64_t & ne03,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   int64_t & ne2,
-        constant   int64_t & ne3,
-        constant  uint64_t & nb0,
-        constant  uint64_t & nb1,
-        constant  uint64_t & nb2,
-        constant  uint64_t & nb3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3   ntg[[threads_per_threadgroup]]) {
-    const int64_t i03 = tgpig[2];
-    const int64_t i02 = tgpig[1];
-    const int64_t i01 = tgpig[0];
+        constant ggml_metal_kargs_cpy & args,
+        device const char * src0,
+        device       char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int i03 = tgpig[2];
+    const int i02 = tgpig[1];
+    const int i01 = tgpig[0];
 
-    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+    const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
 
-    const int64_t i3 = n / (ne2*ne1*ne0);
-    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
-    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
-    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_NL;
+    const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
+    const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
+    const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
+    const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_NL;
 
-    device block_iq4_nl * dst_data = (device block_iq4_nl *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+    device block_iq4_nl * dst_data = (device block_iq4_nl *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
 
-    for (int64_t i00 = tpitg.x*QK4_NL; i00 < ne00; i00 += ntg.x*QK4_NL) {
-        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+    for (int64_t i00 = tpitg.x*QK4_NL; i00 < args.ne00; i00 += ntg.x*QK4_NL) {
+        device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
 
         float amax = 0.0f; // absolute max
         float max  = 0.0f;
@@ -4159,104 +3685,66 @@ kernel void kernel_cpy_f32_iq4_nl(
         }
 
         dst_data[i00/QK4_NL].d = sumq2 > 0 ? sumqx/sumq2 : d;
-
     }
 }
 
 kernel void kernel_concat(
+    constant ggml_metal_kargs_concat & args,
     device  const char * src0,
     device  const char * src1,
     device        char * dst,
-    constant   int64_t & ne00,
-    constant   int64_t & ne01,
-    constant   int64_t & ne02,
-    constant   int64_t & ne03,
-    constant  uint64_t & nb00,
-    constant  uint64_t & nb01,
-    constant  uint64_t & nb02,
-    constant  uint64_t & nb03,
-    constant   int64_t & ne10,
-    constant   int64_t & ne11,
-    constant   int64_t & ne12,
-    constant   int64_t & ne13,
-    constant  uint64_t & nb10,
-    constant  uint64_t & nb11,
-    constant  uint64_t & nb12,
-    constant  uint64_t & nb13,
-    constant   int64_t & ne0,
-    constant   int64_t & ne1,
-    constant   int64_t & ne2,
-    constant   int64_t & ne3,
-    constant  uint64_t & nb0,
-    constant  uint64_t & nb1,
-    constant  uint64_t & nb2,
-    constant  uint64_t & nb3,
-    constant   int32_t & dim,
-    uint3 tgpig[[threadgroup_position_in_grid]],
-    uint3 tpitg[[thread_position_in_threadgroup]],
-    uint3   ntg[[threads_per_threadgroup]]) {
+    uint3   tgpig[[threadgroup_position_in_grid]],
+    ushort3 tpitg[[thread_position_in_threadgroup]],
+    ushort3   ntg[[threads_per_threadgroup]]) {
 
-    const int64_t i3 = tgpig.z;
-    const int64_t i2 = tgpig.y;
-    const int64_t i1 = tgpig.x;
+    const int i3 = tgpig.z;
+    const int i2 = tgpig.y;
+    const int i1 = tgpig.x;
 
-    int64_t o[4] = {0, 0, 0, 0};
-    o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
+    int o[4] = {0, 0, 0, 0};
+    o[args.dim] = args.dim == 0 ? args.ne00 : (args.dim == 1 ? args.ne01 : (args.dim == 2 ? args.ne02 : args.ne03));
 
     device const float * x;
 
-    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
-        if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
-            x = (device const float *)(src0 + (i3       )*nb03 + (i2       )*nb02 + (i1       )*nb01 + (i0       )*nb00);
+    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+        if (i0 < args.ne00 && i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
+            x = (device const float *)(src0 + (i3       )*args.nb03 + (i2       )*args.nb02 + (i1       )*args.nb01 + (i0       )*args.nb00);
         } else {
-            x = (device const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10);
+            x = (device const float *)(src1 + (i3 - o[3])*args.nb13 + (i2 - o[2])*args.nb12 + (i1 - o[1])*args.nb11 + (i0 - o[0])*args.nb10);
         }
 
-        device float * y = (device float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+        device float * y = (device float *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
 
         *y = *x;
     }
 }
 
+template<typename args_t>
 void kernel_mul_mv_q2_K_f32_impl(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-                   int64_t   ne00,
-                   int64_t   ne01,
-                   int64_t   ne02,
-                  uint64_t   nb01,
-                  uint64_t   nb02,
-                  uint64_t   nb03,
-                   int64_t   ne10,
-                   int64_t   ne12,
-                  uint64_t   nb11,
-                  uint64_t   nb12,
-                  uint64_t   nb13,
-                   int64_t   ne0,
-                   int64_t   ne1,
-                   uint      r2,
-                   uint      r3,
-        threadgroup int8_t * shared_values,
-                   uint3     tgpig,
-                   uint      tiisg,
-                   uint      sgitg) {
-
-    const int nb = ne00/QK_K;
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiisg,
+        ushort sgitg) {
+
+    const int nb = args.ne00/QK_K;
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
     const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
 
-    const uint i12 = im%ne12;
-    const uint i13 = im/ne12;
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
 
-    const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
-    const uint offset1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
+    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
 
-    device const block_q2_K * x = (device const block_q2_K *) ((device char *) src0 + offset0);
-    device const float      * y = (device const float      *) ((device char *) src1 + offset1);
+    device const block_q2_K * x = (device const block_q2_K *) (src0 + offset0);
+    device const float      * y = (device const float      *) (src1 + offset1);
 
     float yl[32];
     float sumf[N_DST]={0.f}, all_sum;
@@ -4305,92 +3793,64 @@ void kernel_mul_mv_q2_K_f32_impl(
                                  (acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) -
                          dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0));
 
-            qs += nb01/2;
-            sc += nb01;
-            dh += nb01/2;
+            qs += args.nb01/2;
+            sc += args.nb01;
+            dh += args.nb01/2;
         }
 
         y4 += 4 * QK_K;
     }
 
+    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
     for (int row = 0; row < N_DST; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
+            dst_f32[first_row + row] = all_sum;
         }
     }
 }
 
 [[host_name("kernel_mul_mv_q2_K_f32")]]
 kernel void kernel_mul_mv_q2_K_f32(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne10,
-        constant   int64_t & ne11,
-        constant   int64_t & ne12,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb12,
-        constant  uint64_t & nb13,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint  tiisg[[thread_index_in_simdgroup]],
-        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_q2_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
+template<typename args_t>
 void kernel_mul_mv_q3_K_f32_impl(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-                   int64_t   ne00,
-                   int64_t   ne01,
-                   int64_t   ne02,
-                  uint64_t   nb01,
-                  uint64_t   nb02,
-                  uint64_t   nb03,
-                   int64_t   ne10,
-                   int64_t   ne12,
-                  uint64_t   nb11,
-                  uint64_t   nb12,
-                  uint64_t   nb13,
-                   int64_t   ne0,
-                   int64_t   ne1,
-                   uint      r2,
-                   uint      r3,
-        threadgroup int8_t * shared_values,
-                   uint3     tgpig,
-                   uint      tiisg,
-                   uint      sgitg) {
-
-    const int nb = ne00/QK_K;
-
-    const int64_t r0 = tgpig.x;
-    const int64_t r1 = tgpig.y;
-    const int64_t im = tgpig.z;
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiisg,
+        ushort sgitg) {
+
+    const int nb = args.ne00/QK_K;
+
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
 
     const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
 
-    const uint i12 = im%ne12;
-    const uint i13 = im/ne12;
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
 
-    const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
-    const uint offset1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
+    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
 
-    device const block_q3_K * x = (device const block_q3_K *) ((device char *) src0 + offset0);
-    device const float     * yy = (device const float      *) ((device char *) src1 + offset1);
+    device const block_q3_K * x = (device const block_q3_K *) (src0 + offset0);
+    device const float     * yy = (device const float      *) (src1 + offset1);
 
     float yl[32];
 
@@ -4420,9 +3880,10 @@ void kernel_mul_mv_q3_K_f32_impl(
 
     const ushort4 hm = mm[2*ip + il/2];
 
-    const int shift = 2*il;
-    const float    v1 = il == 0 ? 4.f : 64.f;
-    const float    v2 = 4.f * v1;
+    const short shift = 2*il;
+
+    const float v1 = il == 0 ? 4.f : 64.f;
+    const float v2 = 4.f * v1;
 
     const uint16_t s_shift1 = 4*ip;
     const uint16_t s_shift2 = s_shift1 + il;
@@ -4491,10 +3952,10 @@ void kernel_mul_mv_q3_K_f32_impl(
             sumf1[row] += d1 * (scales[1] - 32);
             sumf2[row] += d2 * (scales[3] - 32);
 
-            q  += nb01/2;
-            h  += nb01/2;
-            a  += nb01/2;
-            dh += nb01/2;
+            q  += args.nb01/2;
+            h  += args.nb01/2;
+            a  += args.nb01/2;
+            dh += args.nb01/2;
         }
 
         y1 += 4 * QK_K;
@@ -4504,66 +3965,39 @@ void kernel_mul_mv_q3_K_f32_impl(
         const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift);
         sumf1[row] = simd_sum(sumf);
     }
+
+    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
     if (tiisg == 0) {
         for (int row = 0; row < 2; ++row) {
-            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = sumf1[row];
+            dst_f32[first_row + row] = sumf1[row];
         }
     }
 }
 
 [[host_name("kernel_mul_mv_q3_K_f32")]]
 kernel void kernel_mul_mv_q3_K_f32(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne10,
-        constant   int64_t & ne11,
-        constant   int64_t & ne12,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb12,
-        constant  uint64_t & nb13,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint  tiisg[[thread_index_in_simdgroup]],
-        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_q3_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
+template<typename args_t>
 void kernel_mul_mv_q4_K_f32_impl(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-                   int64_t   ne00,
-                   int64_t   ne01,
-                   int64_t   ne02,
-                  uint64_t   nb01,
-                  uint64_t   nb02,
-                  uint64_t   nb03,
-                   int64_t   ne10,
-                   int64_t   ne12,
-                  uint64_t   nb11,
-                  uint64_t   nb12,
-                  uint64_t   nb13,
-                   int64_t   ne0,
-                   int64_t   ne1,
-                   uint      r2,
-                   uint      r3,
-        threadgroup int8_t * shared_values,
-                   uint3     tgpig,
-                   uint      tiisg,
-                   uint      sgitg) {
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiisg,
+        ushort sgitg) {
 
     const uint16_t kmask1 = 0x3f3f;
     const uint16_t kmask2 = 0x0f0f;
@@ -4574,21 +4008,21 @@ void kernel_mul_mv_q4_K_f32_impl(
     const int iq = it/4;     // 0 or 1
     const int ir = it%4;     // 0...3
 
-    const int nb = ne00/QK_K;
+    const int nb = args.ne00/QK_K;
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int im = tgpig.z;
     //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
     const int first_row = r0 * N_DST;
 
-    const uint i12 = im%ne12;
-    const uint i13 = im/ne12;
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
 
-    const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
-    const uint offset1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
+    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
 
-    device const block_q4_K * x = (device const block_q4_K *) ((device char *) src0 + offset0);
-    device const float      * y = (device const float      *) ((device char *) src1 + offset1);
+    device const block_q4_K * x = (device const block_q4_K *) (src0 + offset0);
+    device const float      * y = (device const float      *) (src1 + offset1);
 
     float yl[16];
     float yh[16];
@@ -4641,92 +4075,64 @@ void kernel_mul_mv_q4_K_f32_impl(
                                  (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) -
                          dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
 
-            q1 += nb01/2;
-            sc += nb01/2;
-            dh += nb01/2;
+            q1 += args.nb01/2;
+            sc += args.nb01/2;
+            dh += args.nb01/2;
         }
 
         y4 += 4 * QK_K;
     }
 
+    device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1 + (int64_t)r1*args.ne0;
+
     for (int row = 0; row < N_DST; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
+            dst_f32[first_row + row] = all_sum;
         }
     }
 }
 
 [[host_name("kernel_mul_mv_q4_K_f32")]]
 kernel void kernel_mul_mv_q4_K_f32(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne10,
-        constant   int64_t & ne11,
-        constant   int64_t & ne12,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb12,
-        constant  uint64_t & nb13,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint tiisg[[thread_index_in_simdgroup]],
-        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_q4_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
+template<typename args_t>
 void kernel_mul_mv_q5_K_f32_impl(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-                   int64_t   ne00,
-                   int64_t   ne01,
-                   int64_t   ne02,
-                  uint64_t   nb01,
-                  uint64_t   nb02,
-                  uint64_t   nb03,
-                   int64_t   ne10,
-                   int64_t   ne12,
-                  uint64_t   nb11,
-                  uint64_t   nb12,
-                  uint64_t   nb13,
-                   int64_t   ne0,
-                   int64_t   ne1,
-                   uint      r2,
-                   uint      r3,
-        threadgroup int8_t * shared_values,
-                   uint3     tgpig,
-                   uint      tiisg,
-                   uint      sgitg) {
-
-    const int nb = ne00/QK_K;
-
-    const int64_t r0 = tgpig.x;
-    const int64_t r1 = tgpig.y;
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiisg,
+        ushort sgitg) {
+
+    const int nb = args.ne00/QK_K;
+
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
     const int im = tgpig.z;
 
     const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
 
-    const uint i12 = im%ne12;
-    const uint i13 = im/ne12;
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
 
-    const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
-    const uint offset1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
+    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
 
-    device const block_q5_K * x = (device const block_q5_K *) ((device char *) src0 + offset0);
-    device const float     * yy = (device const float      *) ((device char *) src1 + offset1);
+    device const block_q5_K * x = (device const block_q5_K *) (src0 + offset0);
+    device const float     * yy = (device const float      *) (src1 + offset1);
 
     float sumf[2]={0.f};
 
@@ -4800,98 +4206,70 @@ void kernel_mul_mv_q5_K_f32_impl(
                                  sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -
                          dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
 
-            q1 += nb01;
-            qh += nb01;
-            dh += nb01/2;
-            a  += nb01/2;
+            q1 += args.nb01;
+            qh += args.nb01;
+            dh += args.nb01/2;
+            a  += args.nb01/2;
         }
 
         y1 += 4 * QK_K;
     }
 
+    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
     for (int row = 0; row < 2; ++row) {
         const float tot = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
+            dst_f32[first_row + row] = tot;
         }
     }
 }
 
 [[host_name("kernel_mul_mv_q5_K_f32")]]
 kernel void kernel_mul_mv_q5_K_f32(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne10,
-        constant   int64_t & ne11,
-        constant   int64_t & ne12,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb12,
-        constant  uint64_t & nb13,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint  tiisg[[thread_index_in_simdgroup]],
-        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_q5_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
+template <typename args_t>
 void kernel_mul_mv_q6_K_f32_impl(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-                   int64_t   ne00,
-                   int64_t   ne01,
-                   int64_t   ne02,
-                  uint64_t   nb01,
-                  uint64_t   nb02,
-                  uint64_t   nb03,
-                   int64_t   ne10,
-                   int64_t   ne12,
-                  uint64_t   nb11,
-                  uint64_t   nb12,
-                  uint64_t   nb13,
-                   int64_t   ne0,
-                   int64_t   ne1,
-                   uint      r2,
-                   uint      r3,
-        threadgroup int8_t * shared_values,
-                   uint3     tgpig,
-                   uint      tiisg,
-                   uint      sgitg) {
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiisg,
+        ushort sgitg) {
 
     const uint8_t kmask1 = 0x03;
     const uint8_t kmask2 = 0x0C;
     const uint8_t kmask3 = 0x30;
     const uint8_t kmask4 = 0xC0;
 
-    const int nb = ne00/QK_K;
+    const int nb = args.ne00/QK_K;
 
-    const int64_t r0 = tgpig.x;
-    const int64_t r1 = tgpig.y;
-    const int     im = tgpig.z;
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
 
-    const int row = 2 * r0 + sgitg;
+    const int row = 2*r0 + sgitg;
 
-    const uint i12 = im%ne12;
-    const uint i13 = im/ne12;
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
 
-    const uint offset0 = row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
-    const uint offset1 =  r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
+    const uint64_t offset0 = row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 =  r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
 
-    device const block_q6_K * x = (device const block_q6_K *) ((device char *) src0 + offset0);
-    device const float     * yy = (device const float      *) ((device char *) src1 + offset1);
+    device const block_q6_K * x = (device const block_q6_K *) (src0 + offset0);
+    device const float     * yy = (device const float      *) (src1 + offset1);
 
     float sumf = 0;
 
@@ -4908,7 +4286,6 @@ void kernel_mul_mv_q6_K_f32_impl(
     const int q_offset_h = 32*ip + l0;
 
     for (int i = ix; i < nb; i += 2) {
-
         device const uint8_t * q1 = x[i].ql + q_offset_l;
         device const uint8_t * q2 = q1 + 32;
         device const uint8_t * qh = x[i].qh + q_offset_h;
@@ -4930,98 +4307,70 @@ void kernel_mul_mv_q6_K_f32_impl(
 
     }
 
+    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
     const float tot = simd_sum(sumf);
     if (tiisg == 0) {
-        dst[r1*ne0 + im*ne0*ne1 + row] = tot;
+        dst_f32[row] = tot;
     }
 }
 
 [[host_name("kernel_mul_mv_q6_K_f32")]]
 kernel void kernel_mul_mv_q6_K_f32(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne10,
-        constant   int64_t & ne11,
-        constant   int64_t & ne12,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb12,
-        constant  uint64_t & nb13,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint  tiisg[[thread_index_in_simdgroup]],
-        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_q6_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
 // ======================= "True" 2-bit
 
+template<typename args_t>
 void kernel_mul_mv_iq2_xxs_f32_impl(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-                   int64_t   ne00,
-                   int64_t   ne01,
-                   int64_t   ne02,
-                  uint64_t   nb01,
-                  uint64_t   nb02,
-                  uint64_t   nb03,
-                   int64_t   ne10,
-                   int64_t   ne12,
-                  uint64_t   nb11,
-                  uint64_t   nb12,
-                  uint64_t   nb13,
-                   int64_t   ne0,
-                   int64_t   ne1,
-                   uint      r2,
-                   uint      r3,
-        threadgroup int8_t * shared_values,
-                   uint3     tgpig,
-                   uint      tiisg,
-                   uint      sgitg) {
-
-    const int nb = ne00/QK_K;
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiisg,
+        ushort sgitg) {
+
+    const int nb = args.ne00/QK_K;
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
     const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
 
-    const uint i12 = im%ne12;
-    const uint i13 = im/ne12;
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
 
-    const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
-    const uint offset1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
+    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
 
-    device const block_iq2_xxs * x = (device const block_iq2_xxs *) ((device char *) src0 + offset0);
-    device const float         * y = (device const float         *) ((device char *) src1 + offset1);
+    device const block_iq2_xxs * x = (device const block_iq2_xxs *) (src0 + offset0);
+    device const float         * y = (device const float         *) (src1 + offset1);
 
     float yl[32];
     float sumf[N_DST]={0.f}, all_sum;
 
     const int nb32 = nb * (QK_K / 32);
 
-    threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values;
-    threadgroup uint8_t  * shared_signs = (threadgroup uint8_t *)(values + 256);
+    threadgroup uint64_t * svalues = (threadgroup uint64_t *)(shmem);
+    threadgroup uint8_t  * ssigns  = (threadgroup uint8_t  *)(svalues + 256);
     {
         int nval = 4;
         int pos  = (32*sgitg + tiisg)*nval;
-        for (int i = 0; i < nval; ++i) values[pos + i] = iq2xxs_grid[pos + i];
+        for (int i = 0; i < nval; ++i) svalues[pos + i] = iq2xxs_grid[pos + i];
         nval = 2;
         pos  = (32*sgitg + tiisg)*nval;
-        for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i];
+        for (int i = 0; i < nval; ++i) ssigns[pos+i] = ksigns_iq2xs[pos+i];
         threadgroup_barrier(mem_flags::mem_threadgroup);
     }
 
@@ -5051,114 +4400,85 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
 
             float sum = 0;
             for (int l = 0; l < 4; ++l) {
-                const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + aux8[l]);
-                const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127];
+                const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + aux8[l]);
+                const uint8_t signs = ssigns[(aux32 >> 7*l) & 127];
                 for (int j = 0; j < 8; ++j) {
                     sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
                 }
             }
             sumf[row] += d * sum;
 
-            dh += nb01/2;
-            q2 += nb01/2;
+            dh += args.nb01/2;
+            q2 += args.nb01/2;
         }
 
         y4 += 32 * 32;
     }
 
+    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
     for (int row = 0; row < N_DST; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f;
+            dst_f32[first_row + row] = all_sum * 0.25f;
         }
     }
 }
 
 [[host_name("kernel_mul_mv_iq2_xxs_f32")]]
 kernel void kernel_mul_mv_iq2_xxs_f32(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne10,
-        constant   int64_t & ne11,
-        constant   int64_t & ne12,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb12,
-        constant  uint64_t & nb13,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        threadgroup int8_t * shared_values [[threadgroup(0)]],
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint  tiisg[[thread_index_in_simdgroup]],
-        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
-
-    kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem [[threadgroup(0)]],
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+    kernel_mul_mv_iq2_xxs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
+template<typename args_t>
 void kernel_mul_mv_iq2_xs_f32_impl(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-                   int64_t   ne00,
-                   int64_t   ne01,
-                   int64_t   ne02,
-                  uint64_t   nb01,
-                  uint64_t   nb02,
-                  uint64_t   nb03,
-                   int64_t   ne10,
-                   int64_t   ne12,
-                  uint64_t   nb11,
-                  uint64_t   nb12,
-                  uint64_t   nb13,
-                   int64_t   ne0,
-                   int64_t   ne1,
-                   uint      r2,
-                   uint      r3,
-        threadgroup int8_t * shared_values,
-                   uint3     tgpig,
-                   uint      tiisg,
-                   uint      sgitg) {
-
-    const int nb = ne00/QK_K;
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiisg,
+        ushort sgitg) {
+
+    const int nb = args.ne00/QK_K;
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
     const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
 
-    const uint i12 = im%ne12;
-    const uint i13 = im/ne12;
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
 
-    const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
-    const uint offset1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
+    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
 
-    device const block_iq2_xs * x = (device const block_iq2_xs *) ((device char *) src0 + offset0);
-    device const float        * y = (device const float        *) ((device char *) src1 + offset1);
+    device const block_iq2_xs * x = (device const block_iq2_xs *) (src0 + offset0);
+    device const float        * y = (device const float        *) (src1 + offset1);
 
     float yl[32];
     float sumf[N_DST]={0.f}, all_sum;
 
     const int nb32 = nb * (QK_K / 32);
 
-    threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values;
-    threadgroup uint8_t  * shared_signs = (threadgroup uint8_t *)(values + 512);
+    threadgroup uint64_t * svalues = (threadgroup uint64_t *)(shmem);
+    threadgroup uint8_t  * ssigns  = (threadgroup uint8_t  *)(svalues + 512);
     {
         int nval = 8;
         int pos  = (32*sgitg + tiisg)*nval;
-        for (int i = 0; i < nval; ++i) values[pos + i] = iq2xs_grid[pos + i];
+        for (int i = 0; i < nval; ++i) svalues[pos + i] = iq2xs_grid[pos + i];
         nval = 2;
         pos  = (32*sgitg + tiisg)*nval;
-        for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i];
+        for (int i = 0; i < nval; ++i) ssigns[pos+i] = ksigns_iq2xs[pos+i];
         threadgroup_barrier(mem_flags::mem_threadgroup);
     }
 
@@ -5190,122 +4510,94 @@ void kernel_mul_mv_iq2_xs_f32_impl(
 
             float sum1 = 0, sum2 = 0;
             for (int l = 0; l < 2; ++l) {
-                const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511));
-                const uint8_t signs = shared_signs[(q2[l] >> 9)];
+                const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511));
+                const uint8_t signs = ssigns[(q2[l] >> 9)];
                 for (int j = 0; j < 8; ++j) {
                     sum1 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
                 }
             }
             for (int l = 2; l < 4; ++l) {
-                const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511));
-                const uint8_t signs = shared_signs[(q2[l] >> 9)];
+                const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511));
+                const uint8_t signs = ssigns[(q2[l] >> 9)];
                 for (int j = 0; j < 8; ++j) {
                     sum2 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
                 }
             }
             sumf[row] += d1 * sum1 + d2 * sum2;
 
-            dh += nb01/2;
-            q2 += nb01/2;
-            sc += nb01;
+            dh += args.nb01/2;
+            q2 += args.nb01/2;
+            sc += args.nb01;
         }
 
         y4 += 32 * 32;
     }
 
+    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
     for (int row = 0; row < N_DST; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f;
+            dst_f32[first_row + row] = all_sum * 0.25f;
         }
-    }
-}
-
-[[host_name("kernel_mul_mv_iq2_xs_f32")]]
-kernel void kernel_mul_mv_iq2_xs_f32(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne10,
-        constant   int64_t & ne11,
-        constant   int64_t & ne12,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb12,
-        constant  uint64_t & nb13,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        threadgroup int8_t * shared_values [[threadgroup(0)]],
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint  tiisg[[thread_index_in_simdgroup]],
-        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+    }
+}
+
+[[host_name("kernel_mul_mv_iq2_xs_f32")]]
+kernel void kernel_mul_mv_iq2_xs_f32(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem [[threadgroup(0)]],
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq2_xs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
+template <typename args_t>
 void kernel_mul_mv_iq3_xxs_f32_impl(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-                   int64_t   ne00,
-                   int64_t   ne01,
-                   int64_t   ne02,
-                  uint64_t   nb01,
-                  uint64_t   nb02,
-                  uint64_t   nb03,
-                   int64_t   ne10,
-                   int64_t   ne12,
-                  uint64_t   nb11,
-                  uint64_t   nb12,
-                  uint64_t   nb13,
-                   int64_t   ne0,
-                   int64_t   ne1,
-                   uint      r2,
-                   uint      r3,
-        threadgroup int8_t * shared_values,
-                   uint3     tgpig,
-                   uint      tiisg,
-                   uint      sgitg) {
-
-    const int nb = ne00/QK_K;
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiisg,
+        ushort sgitg) {
+
+    const int nb = args.ne00/QK_K;
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
     const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
 
-    const uint i12 = im%ne12;
-    const uint i13 = im/ne12;
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
 
-    const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
-    const uint offset1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
+    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
 
-    device const block_iq3_xxs * x = (device const block_iq3_xxs *) ((device char *) src0 + offset0);
-    device const float         * y = (device const float         *) ((device char *) src1 + offset1);
+    device const block_iq3_xxs * x = (device const block_iq3_xxs *) (src0 + offset0);
+    device const float         * y = (device const float         *) (src1 + offset1);
 
     float yl[32];
     float sumf[N_DST]={0.f}, all_sum;
 
     const int nb32 = nb * (QK_K / 32);
 
-    threadgroup uint32_t * values = (threadgroup uint32_t *)shared_values;
-    threadgroup uint8_t  * shared_signs = (threadgroup uint8_t *)(values + 256);
+    threadgroup uint32_t * svalues = (threadgroup uint32_t *)(shmem);
+    threadgroup uint8_t  * ssigns  = (threadgroup uint8_t  *)(svalues + 256);
     {
         int nval = 4;
         int pos  = (32*sgitg + tiisg)*nval;
-        for (int i = 0; i < nval; ++i) values[pos + i] = iq3xxs_grid[pos + i];
+        for (int i = 0; i < nval; ++i) svalues[pos + i] = iq3xxs_grid[pos + i];
         nval = 2;
         pos  = (32*sgitg + tiisg)*nval;
-        for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i];
+        for (int i = 0; i < nval; ++i) ssigns[pos+i] = ksigns_iq2xs[pos+i];
         threadgroup_barrier(mem_flags::mem_threadgroup);
     }
 
@@ -5314,7 +4606,6 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
     device const float * y4 = y + 32 * ix;
 
     for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
-
         for (int i = 0; i < 32; ++i) {
             yl[i] = y4[i];
         }
@@ -5328,16 +4619,15 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
         device const half * dh = &xr->d;
 
         for (int row = 0; row < N_DST; row++) {
-
             const float db = dh[0];
             const uint32_t aux32 = gas[0] | (gas[1] << 16);
             const float d = db * (0.5f + (aux32 >> 28));
 
             float2 sum = {0};
             for (int l = 0; l < 4; ++l) {
-                const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + q3[2*l+0]);
-                const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + q3[2*l+1]);
-                const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127];
+                const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + q3[2*l+0]);
+                const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + q3[2*l+1]);
+                const uint8_t signs = ssigns[(aux32 >> 7*l) & 127];
                 for (int j = 0; j < 4; ++j) {
                     sum[0] += yl[8*l + j + 0] * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
                     sum[1] += yl[8*l + j + 4] * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
@@ -5345,103 +4635,75 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
             }
             sumf[row] += d * (sum[0] + sum[1]);
 
-            dh  += nb01/2;
-            q3  += nb01;
-            gas += nb01/2;
+            dh  += args.nb01/2;
+            q3  += args.nb01;
+            gas += args.nb01/2;
         }
 
         y4 += 32 * 32;
     }
 
+    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
     for (int row = 0; row < N_DST; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.5f;
+            dst_f32[first_row + row] = all_sum * 0.5f;
         }
     }
 }
 
 [[host_name("kernel_mul_mv_iq3_xxs_f32")]]
 kernel void kernel_mul_mv_iq3_xxs_f32(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne10,
-        constant   int64_t & ne11,
-        constant   int64_t & ne12,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb12,
-        constant  uint64_t & nb13,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        threadgroup int8_t * shared_values [[threadgroup(0)]],
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint  tiisg[[thread_index_in_simdgroup]],
-        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem [[threadgroup(0)]],
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq3_xxs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
+template<typename args_t>
 void kernel_mul_mv_iq3_s_f32_impl(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-                   int64_t   ne00,
-                   int64_t   ne01,
-                   int64_t   ne02,
-                  uint64_t   nb01,
-                  uint64_t   nb02,
-                  uint64_t   nb03,
-                   int64_t   ne10,
-                   int64_t   ne12,
-                  uint64_t   nb11,
-                  uint64_t   nb12,
-                  uint64_t   nb13,
-                   int64_t   ne0,
-                   int64_t   ne1,
-                   uint      r2,
-                   uint      r3,
-        threadgroup int8_t * shared_values,
-                   uint3     tgpig,
-                   uint      tiisg,
-                   uint      sgitg) {
-
-    const int nb = ne00/QK_K;
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiisg,
+        ushort sgitg) {
+
+    const int nb = args.ne00/QK_K;
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
     const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
 
-    const uint i12 = im%ne12;
-    const uint i13 = im/ne12;
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
 
-    const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
-    const uint offset1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
+    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
 
-    device const block_iq3_s * x = (device const block_iq3_s *) ((device char *) src0 + offset0);
-    device const float       * y = (device const float       *) ((device char *) src1 + offset1);
+    device const block_iq3_s * x = (device const block_iq3_s *) (src0 + offset0);
+    device const float       * y = (device const float       *) (src1 + offset1);
 
     float yl[32];
     float sumf[N_DST]={0.f}, all_sum;
 
     const int nb32 = nb * (QK_K / 32);
 
-    threadgroup uint32_t * values = (threadgroup uint32_t *)shared_values;
+    threadgroup uint32_t * svalues = (threadgroup uint32_t *) shmem;
     {
         int nval = 8;
         int pos  = (32*sgitg + tiisg)*nval;
-        for (int i = 0; i < nval; ++i) values[pos + i] = iq3s_grid[pos + i];
+        for (int i = 0; i < nval; ++i) svalues[pos + i] = iq3s_grid[pos + i];
         threadgroup_barrier(mem_flags::mem_threadgroup);
     }
 
@@ -5472,8 +4734,8 @@ void kernel_mul_mv_iq3_s_f32_impl(
 
             float2 sum = {0};
             for (int l = 0; l < 4; ++l) {
-                const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? values + 256 : values;
-                const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? values + 256 : values;
+                const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? svalues + 256 : svalues;
+                const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? svalues + 256 : svalues;
                 const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(table1 + qs[2*l+0]);
                 const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(table2 + qs[2*l+1]);
                 for (int j = 0; j < 4; ++j) {
@@ -5483,105 +4745,77 @@ void kernel_mul_mv_iq3_s_f32_impl(
             }
             sumf[row] += d * (sum[0] + sum[1]);
 
-            dh    += nb01/2;
-            qs    += nb01;
-            qh    += nb01;
-            sc    += nb01;
-            signs += nb01;
+            dh    += args.nb01/2;
+            qs    += args.nb01;
+            qh    += args.nb01;
+            sc    += args.nb01;
+            signs += args.nb01;
         }
 
         y4 += 32 * 32;
     }
 
+    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
     for (int row = 0; row < N_DST; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
+            dst_f32[first_row + row] = all_sum;
         }
     }
 }
 
 [[host_name("kernel_mul_mv_iq3_s_f32")]]
 kernel void kernel_mul_mv_iq3_s_f32(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne10,
-        constant   int64_t & ne11,
-        constant   int64_t & ne12,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb12,
-        constant  uint64_t & nb13,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        threadgroup int8_t * shared_values [[threadgroup(0)]],
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint  tiisg[[thread_index_in_simdgroup]],
-        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem [[threadgroup(0)]],
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_iq3_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq3_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
+template <typename args_t>
 void kernel_mul_mv_iq2_s_f32_impl(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-                   int64_t   ne00,
-                   int64_t   ne01,
-                   int64_t   ne02,
-                  uint64_t   nb01,
-                  uint64_t   nb02,
-                  uint64_t   nb03,
-                   int64_t   ne10,
-                   int64_t   ne12,
-                  uint64_t   nb11,
-                  uint64_t   nb12,
-                  uint64_t   nb13,
-                   int64_t   ne0,
-                   int64_t   ne1,
-                   uint      r2,
-                   uint      r3,
-        threadgroup int8_t * shared_values,
-                   uint3     tgpig,
-                   uint      tiisg,
-                   uint      sgitg) {
-
-    const int nb = ne00/QK_K;
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiisg,
+        ushort sgitg) {
+
+    const int nb = args.ne00/QK_K;
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
     const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
 
-    const uint i12 = im%ne12;
-    const uint i13 = im/ne12;
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
 
-    const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
-    const uint offset1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
+    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
 
-    device const block_iq2_s * x = (device const block_iq2_s *) ((device char *) src0 + offset0);
-    device const float       * y = (device const float       *) ((device char *) src1 + offset1);
+    device const block_iq2_s * x = (device const block_iq2_s *) (src0 + offset0);
+    device const float       * y = (device const float       *) (src1 + offset1);
 
     float yl[32];
     float sumf[N_DST]={0.f}, all_sum;
 
     const int nb32 = nb * (QK_K / 32);
 
-    //threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values;
+    //threadgroup uint64_t * svalues = (threadgroup uint64_t *) shmem;
     //{
     //    int nval = 32;
     //    int pos  = (32*sgitg + tiisg)*nval;
-    //    for (int i = 0; i < nval; ++i) values[pos + i] = iq2s_grid[pos + i];
+    //    for (int i = 0; i < nval; ++i) svalues[pos + i] = iq2s_grid[pos + i];
     //    threadgroup_barrier(mem_flags::mem_threadgroup);
     //}
 
@@ -5613,8 +4847,8 @@ void kernel_mul_mv_iq2_s_f32_impl(
 
             float2 sum = {0};
             for (int l = 0; l < 2; ++l) {
-                //const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
-                //const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
+                //const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
+                //const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
                 constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
                 constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
                 for (int j = 0; j < 8; ++j) {
@@ -5624,94 +4858,66 @@ void kernel_mul_mv_iq2_s_f32_impl(
             }
             sumf[row] += d1 * sum[0] + d2 * sum[1];
 
-            dh    += nb01/2;
-            qs    += nb01;
-            qh    += nb01;
-            sc    += nb01;
-            signs += nb01;
+            dh    += args.nb01/2;
+            qs    += args.nb01;
+            qh    += args.nb01;
+            sc    += args.nb01;
+            signs += args.nb01;
         }
 
         y4 += 32 * 32;
     }
 
+    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
     for (int row = 0; row < N_DST; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f;
+            dst_f32[first_row + row] = all_sum * 0.25f;
         }
     }
 }
 
 [[host_name("kernel_mul_mv_iq2_s_f32")]]
 kernel void kernel_mul_mv_iq2_s_f32(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne10,
-        constant   int64_t & ne11,
-        constant   int64_t & ne12,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb12,
-        constant  uint64_t & nb13,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        threadgroup int8_t * shared_values [[threadgroup(0)]],
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint  tiisg[[thread_index_in_simdgroup]],
-        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem [[threadgroup(0)]],
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_iq2_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq2_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
+template<typename args_t>
 void kernel_mul_mv_iq1_s_f32_impl(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-                   int64_t   ne00,
-                   int64_t   ne01,
-                   int64_t   ne02,
-                  uint64_t   nb01,
-                  uint64_t   nb02,
-                  uint64_t   nb03,
-                   int64_t   ne10,
-                   int64_t   ne12,
-                  uint64_t   nb11,
-                  uint64_t   nb12,
-                  uint64_t   nb13,
-                   int64_t   ne0,
-                   int64_t   ne1,
-                   uint      r2,
-                   uint      r3,
-        threadgroup int8_t * shared_value,
-                   uint3     tgpig,
-                   uint      tiisg,
-                   uint      sgitg) {
-
-    const int nb = ne00/QK_K;
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiisg,
+        ushort sgitg) {
+
+    const int nb = args.ne00/QK_K;
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
     const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
 
-    const uint i12 = im%ne12;
-    const uint i13 = im/ne12;
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
 
-    const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
-    const uint offset1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
+    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
 
-    device const block_iq1_s * x = (device const block_iq1_s *) ((device char *) src0 + offset0);
-    device const float       * y = (device const float       *) ((device char *) src1 + offset1);
+    device const block_iq1_s * x = (device const block_iq1_s *) (src0 + offset0);
+    device const float       * y = (device const float       *) (src1 + offset1);
 
     float yl[32];
     float sumf[N_DST]={0.f}, all_sum;
@@ -5754,61 +4960,50 @@ void kernel_mul_mv_iq1_s_f32_impl(
             }
             sumf[row] += (float)dh[0] * (sum + sumy * (qh[0] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA)) * (2*((qh[0] >> 12) & 7) + 1);
 
-            dh += nb01/2;
-            qs += nb01;
-            qh += nb01/2;
+            dh += args.nb01/2;
+            qs += args.nb01;
+            qh += args.nb01/2;
         }
 
         y4 += 32 * 32;
     }
 
+    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
     for (int row = 0; row < N_DST; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
+            dst_f32[first_row + row] = all_sum;
         }
     }
 }
 
+template <typename args_t>
 void kernel_mul_mv_iq1_m_f32_impl(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-                   int64_t   ne00,
-                   int64_t   ne01,
-                   int64_t   ne02,
-                  uint64_t   nb01,
-                  uint64_t   nb02,
-                  uint64_t   nb03,
-                   int64_t   ne10,
-                   int64_t   ne12,
-                  uint64_t   nb11,
-                  uint64_t   nb12,
-                  uint64_t   nb13,
-                   int64_t   ne0,
-                   int64_t   ne1,
-                   uint      r2,
-                   uint      r3,
-        threadgroup int8_t * shared_value,
-                   uint3     tgpig,
-                   uint      tiisg,
-                   uint      sgitg) {
-
-    const int nb = ne00/QK_K;
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiisg,
+        ushort sgitg) {
+
+    const int nb = args.ne00/QK_K;
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int im = tgpig.z;
 
     const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
 
-    const uint i12 = im%ne12;
-    const uint i13 = im/ne12;
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
 
-    const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
-    const uint offset1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
+    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
 
-    device const block_iq1_m * x = (device const block_iq1_m *) ((device char *) src0 + offset0);
-    device const float       * y = (device const float       *) ((device char *) src1 + offset1);
+    device const block_iq1_m * x = (device const block_iq1_m *) (src0 + offset0);
+    device const float       * y = (device const float       *) (src1 + offset1);
 
     float yl[32];
     float sumf[N_DST]={0.f}, all_sum;
@@ -5860,66 +5055,55 @@ void kernel_mul_mv_iq1_m_f32_impl(
             sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) +
                                              (sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1));
 
-            sc += nb01/2;
-            qs += nb01;
-            qh += nb01;
+            sc += args.nb01/2;
+            qs += args.nb01;
+            qh += args.nb01;
         }
 
         y4 += 32 * 32;
     }
 
+    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
     for (int row = 0; row < N_DST; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
+            dst_f32[first_row + row] = all_sum;
         }
     }
 }
 
+template<typename args_t>
 void kernel_mul_mv_iq4_nl_f32_impl(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-                   int64_t   ne00,
-                   int64_t   ne01,
-                   int64_t   ne02,
-                  uint64_t   nb01,
-                  uint64_t   nb02,
-                  uint64_t   nb03,
-                   int64_t   ne10,
-                   int64_t   ne12,
-                  uint64_t   nb11,
-                  uint64_t   nb12,
-                  uint64_t   nb13,
-                   int64_t   ne0,
-                   int64_t   ne1,
-                   uint      r2,
-                   uint      r3,
-        threadgroup int8_t * shared_values_i8,
-                   uint3     tgpig,
-                   uint      tiisg,
-                   uint      sgitg) {
-
-    threadgroup float * shared_values = (threadgroup float *)shared_values_i8;
-    const int nb = ne00/QK4_NL;
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiisg,
+        ushort sgitg) {
+
+    threadgroup float * shmem_f32 = (threadgroup float *) shmem;
+    const int nb = args.ne00/QK4_NL;
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int im = tgpig.z;
     const int first_row = (r0 * 2 + sgitg) * 2;
 
-    const uint i12 = im%ne12;
-    const uint i13 = im/ne12;
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
 
-    const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
-    const uint offset1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
+    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
 
-    device const block_iq4_nl * x = (device const block_iq4_nl *) ((device char *) src0 + offset0);
-    device const float        * y = (device const float        *) ((device char *) src1 + offset1);
+    device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0);
+    device const float        * y = (device const float        *) (src1 + offset1);
 
     const int ix = tiisg/2;  // 0...15
     const int it = tiisg%2;  // 0 or 1
 
-    shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16];
+    shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16];
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
     float4 yl[4];
@@ -5937,7 +5121,7 @@ void kernel_mul_mv_iq4_nl_f32_impl(
         device const float4 * y4 = (device const float4 *)yb;
         yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
 
-        for (int row = 0; row < 2 && first_row + row < ne01; ++row) {
+        for (int row = 0; row < 2 && first_row + row < args.ne01; ++row) {
 
             device const block_iq4_nl & xb = x[row*nb + ib];
             device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it);
@@ -5947,16 +5131,16 @@ void kernel_mul_mv_iq4_nl_f32_impl(
             aux32[0] = q4[0] | (q4[1] << 16);
             aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
             aux32[0] &= 0x0f0f0f0f;
-            qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
-            qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
+            qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]};
+            qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]};
             acc1 += yl[0] * qf1;
             acc2 += yl[1] * qf2;
 
             aux32[0] = q4[2] | (q4[3] << 16);
             aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
             aux32[0] &= 0x0f0f0f0f;
-            qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
-            qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
+            qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]};
+            qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]};
             acc1 += yl[2] * qf1;
             acc2 += yl[3] * qf2;
 
@@ -5969,60 +5153,49 @@ void kernel_mul_mv_iq4_nl_f32_impl(
         yb += 16 * QK4_NL;
     }
 
-    for (int row = 0; row < 2 && first_row + row < ne01; ++row) {
+    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
+    for (int row = 0; row < 2 && first_row + row < args.ne01; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
+            dst_f32[first_row + row] = all_sum;
         }
     }
 }
 
+template<typename args_t>
 void kernel_mul_mv_iq4_xs_f32_impl(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-                   int64_t   ne00,
-                   int64_t   ne01,
-                   int64_t   ne02,
-                  uint64_t   nb01,
-                  uint64_t   nb02,
-                  uint64_t   nb03,
-                   int64_t   ne10,
-                   int64_t   ne12,
-                  uint64_t   nb11,
-                  uint64_t   nb12,
-                  uint64_t   nb13,
-                   int64_t   ne0,
-                   int64_t   ne1,
-                   uint      r2,
-                   uint      r3,
-        threadgroup int8_t * shared_values_i8,
-                   uint3     tgpig,
-                   uint      tiisg,
-                   uint      sgitg) {
-
-    threadgroup float * shared_values = (threadgroup float *)shared_values_i8;
-    const int nb = ne00/QK_K;
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiisg,
+        ushort sgitg) {
+
+    threadgroup float * shmem_f32 = (threadgroup float *) shmem;
+    const int nb = args.ne00/QK_K;
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int im = tgpig.z;
     const int first_row = (r0 * 2 + sgitg) * 2;
 
-    const uint i12 = im%ne12;
-    const uint i13 = im/ne12;
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
 
-    const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
-    const uint offset1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
+    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
 
-    device const block_iq4_xs * x = (device const block_iq4_xs *) ((device char *) src0 + offset0);
-    device const float        * y = (device const float        *) ((device char *) src1 + offset1);
+    device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0);
+    device const float        * y = (device const float        *) (src1 + offset1);
 
     const int ix = tiisg/16;  // 0 or 1
     const int it = tiisg%16;  // 0...15
     const int ib = it/2;
     const int il = it%2;
 
-    shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16];
+    shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16];
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
     float4 yl[4];
@@ -6036,28 +5209,26 @@ void kernel_mul_mv_iq4_xs_f32_impl(
     float4 qf1, qf2;
 
     for (int ibl = ix; ibl < nb; ibl += 2) {
-
         device const float4 * y4 = (device const float4 *)yb;
         yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
 
         for (int row = 0; row < 2; ++row) {
-
             device const block_iq4_xs & xb = x[row*nb + ibl];
             device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il);
 
             float4 acc1 = {0.f}, acc2 = {0.f};
 
-            aux32[0] = q4[0] & 0x0f0f0f0f;
+            aux32[0] = (q4[0]     ) & 0x0f0f0f0f;
             aux32[1] = (q4[0] >> 4) & 0x0f0f0f0f;
-            qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
-            qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
+            qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]};
+            qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]};
             acc1 += yl[0] * qf1;
             acc2 += yl[1] * qf2;
 
-            aux32[0] = q4[1] & 0x0f0f0f0f;
+            aux32[0] = (q4[1]     ) & 0x0f0f0f0f;
             aux32[1] = (q4[1] >> 4) & 0x0f0f0f0f;
-            qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
-            qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
+            qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]};
+            qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]};
             acc1 += yl[2] * qf1;
             acc2 += yl[3] * qf2;
 
@@ -6071,134 +5242,68 @@ void kernel_mul_mv_iq4_xs_f32_impl(
         yb += 2 * QK_K;
     }
 
+    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
     for (int row = 0; row < 2; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
-            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
+            dst_f32[first_row + row] = all_sum;
         }
     }
 }
 
 [[host_name("kernel_mul_mv_iq1_s_f32")]]
 kernel void kernel_mul_mv_iq1_s_f32(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne10,
-        constant   int64_t & ne11,
-        constant   int64_t & ne12,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb12,
-        constant  uint64_t & nb13,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint  tiisg[[thread_index_in_simdgroup]],
-        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq1_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
 [[host_name("kernel_mul_mv_iq1_m_f32")]]
 kernel void kernel_mul_mv_iq1_m_f32(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne10,
-        constant   int64_t & ne11,
-        constant   int64_t & ne12,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb12,
-        constant  uint64_t & nb13,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint  tiisg[[thread_index_in_simdgroup]],
-        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq1_m_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
 }
 
 [[host_name("kernel_mul_mv_iq4_nl_f32")]]
 kernel void kernel_mul_mv_iq4_nl_f32(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne10,
-        constant   int64_t & ne11,
-        constant   int64_t & ne12,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb12,
-        constant  uint64_t & nb13,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        threadgroup int8_t * shared_values [[threadgroup(0)]],
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint tiisg[[thread_index_in_simdgroup]],
-        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem [[threadgroup(0)]],
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq4_nl_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
 [[host_name("kernel_mul_mv_iq4_xs_f32")]]
 kernel void kernel_mul_mv_iq4_xs_f32(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne10,
-        constant   int64_t & ne11,
-        constant   int64_t & ne12,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb12,
-        constant  uint64_t & nb13,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        threadgroup int8_t * shared_values [[threadgroup(0)]],
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint tiisg[[thread_index_in_simdgroup]],
-        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem [[threadgroup(0)]],
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq4_xs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
@@ -6302,38 +5407,26 @@ kernel void kernel_get_rows_i32(
 
 // each block_q contains 16*nl weights
 template<typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
-kernel void kernel_mul_mm(device const  uchar * src0,
-                          device const  uchar * src1,
-                          device        float * dst,
-                          constant    int64_t & ne00,
-                          constant    int64_t & ne02,
-                          constant   uint64_t & nb01,
-                          constant   uint64_t & nb02,
-                          constant   uint64_t & nb03,
-                          constant    int64_t & ne12,
-                          constant   uint64_t & nb10,
-                          constant   uint64_t & nb11,
-                          constant   uint64_t & nb12,
-                          constant   uint64_t & nb13,
-                          constant    int64_t & ne0,
-                          constant    int64_t & ne1,
-                          constant       uint & r2,
-                          constant       uint & r3,
-                          threadgroup   uchar * shared_memory [[threadgroup(0)]],
-                          uint3                 tgpig[[threadgroup_position_in_grid]],
-                          uint                  tiitg[[thread_index_in_threadgroup]],
-                          uint                  sgitg[[simdgroup_index_in_threadgroup]]) {
-
-    threadgroup T     * sa = (threadgroup T     *)(shared_memory);
-    threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
-
-    const uint r0 = tgpig.y;
-    const uint r1 = tgpig.x;
-    const uint im = tgpig.z;
+kernel void kernel_mul_mm(
+        constant ggml_metal_kargs_mul_mm & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem [[threadgroup(0)]],
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiitg[[thread_index_in_threadgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    threadgroup T     * sa = (threadgroup T     *)(shmem);
+    threadgroup float * sb = (threadgroup float *)(shmem + 4096);
+
+    const int r0 = tgpig.y;
+    const int r1 = tgpig.x;
+    const int im = tgpig.z;
 
     // if this block is of 64x32 shape or smaller
-    short n_rows = (ne0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M;
-    short n_cols = (ne1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N;
+    short n_rows = (args.ne0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.ne0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M;
+    short n_cols = (args.ne1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? (args.ne1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N;
 
     // a thread shouldn't load data outside of the matrix
     short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
@@ -6349,20 +5442,20 @@ kernel void kernel_mul_mm(device const  uchar * src0,
 
     short il = (tiitg % THREAD_PER_ROW);
 
-    const uint i12 = im%ne12;
-    const uint i13 = im/ne12;
+    const int i12 = im%args.ne12;
+    const int i13 = im/args.ne12;
 
-    uint   offset0 = (i12/r2)*nb02 + (i13/r3)*nb03;
-    ushort offset1 = il/nl;
+    uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    short    offset1 = il/nl;
 
-    device const block_q * x = (device const block_q *)(src0 + (r0*BLOCK_SIZE_M + thread_row)*nb01 + offset0) + offset1;
+    device const block_q * x = (device const block_q *)(src0 + (r0*BLOCK_SIZE_M + thread_row)*args.nb01 + offset0) + offset1;
     device const float   * y = (device const float   *)(src1
-        + nb13 * i13
-        + nb12 * i12
-        + nb11 * (r1 * BLOCK_SIZE_N + thread_col)
-        + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
+        + args.nb13*i13
+        + args.nb12*i12
+        + args.nb11*(r1 * BLOCK_SIZE_N + thread_col)
+        + args.nb10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
 
-    for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
+    for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) {
         // load data and store to threadgroup memory
         T4x4 temp_a;
         dequantize_func(x, il, temp_a);
@@ -6409,16 +5502,18 @@ kernel void kernel_mul_mm(device const  uchar * src0,
         }
     }
 
-    if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
-        device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg &  1)) \
-                               + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0;
+    if ((r0 + 1) * BLOCK_SIZE_M <= args.ne0 && (r1 + 1) * BLOCK_SIZE_N <= args.ne1) {
+        device float * C = (device float *) dst +
+            (BLOCK_SIZE_M * r0 + 32 * (sgitg &  1)) + \
+            (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0;
+
         for (short i = 0; i < 8; i++) {
-            simdgroup_store(mc[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
+            simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.ne0 * (i/4), args.ne0);
         }
     } else {
         // block is smaller than 64x32, we should avoid writing data outside of the matrix
         threadgroup_barrier(mem_flags::mem_threadgroup);
-        threadgroup float * temp_str = ((threadgroup float *) shared_memory) \
+        threadgroup float * temp_str = ((threadgroup float *) shmem) \
                                       + 32 * (sgitg&1) + (16 * (sgitg>>1))*BLOCK_SIZE_M;
         for (short i = 0; i < 8; i++) {
             simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M);
@@ -6428,7 +5523,7 @@ kernel void kernel_mul_mm(device const  uchar * src0,
 
         if (sgitg == 0) {
             for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
-                device float  * D  = dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*ne0 + im*ne1*ne0;
+                device float  * D  = (device float  *) dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*args.ne0 + im*args.ne1*args.ne0;
                 device float4 * D4 = (device float4 *) D;
 
                 threadgroup float  * C  = temp_str + (j*BLOCK_SIZE_M);
@@ -6449,36 +5544,37 @@ kernel void kernel_mul_mm(device const  uchar * src0,
 }
 
 // same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in rowids
+// TODO: this kernel needs to be reimplemented from scratch for better performance
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
 void kernel_mul_mm_id_impl(
-        device const  uchar * src0,
-        device const  uchar * src1,
+        int32_t  ne00,
+        int32_t  ne02,
+        uint64_t nb01,
+        uint64_t nb02,
+        int32_t  ne11,
+        int32_t  ne12,
+        uint64_t nb10,
+        uint64_t nb11,
+        uint64_t nb12,
+        int32_t  ne0,
+        int32_t  ne1,
+        int64_t  ne0ne1,
+        device   const char * src0,
+        device   const char * src1,
         threadgroup ushort2 * rowids,
-        device        float * dst,
-        constant    int64_t & ne00,
-        constant    int64_t & ne02,
-        constant   uint64_t & nb01,
-        constant   uint64_t & nb02,
-        constant    int64_t & ne11,
-        constant    int64_t & ne12,
-        constant   uint64_t & nb10,
-        constant   uint64_t & nb11,
-        constant   uint64_t & nb12,
-        constant    int64_t & ne0,
-                    int64_t   ne1,
-                    int64_t   ne0ne1,
-        threadgroup   uchar * shared_memory,
-        uint3                 tgpig[[threadgroup_position_in_grid]],
-        uint                  tiitg[[thread_index_in_threadgroup]],
-        uint                  sgitg[[simdgroup_index_in_threadgroup]]) {
-
-    threadgroup half  * sa = (threadgroup half  *)(shared_memory);
-    threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
-
-    const uint r0 = tgpig.y;
-    const uint r1 = tgpig.x;
-
-    if (r1 * BLOCK_SIZE_N >= ne1) return;
+        device         char * dst,
+        threadgroup    char * shmem,
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiitg[[thread_index_in_threadgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    threadgroup half  * sa = (threadgroup half  *)(shmem);
+    threadgroup float * sb = (threadgroup float *)(shmem + 4096);
+
+    const int r0 = tgpig.y;
+    const int r1 = tgpig.x;
+
+    if (r1*BLOCK_SIZE_N >= ne1) return;
 
     // if this block is of 64x32 shape or smaller
     short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
@@ -6490,9 +5586,9 @@ void kernel_mul_mm_id_impl(
 
     simdgroup_half8x8  ma[4];
     simdgroup_float8x8 mb[2];
-    simdgroup_float8x8 c_res[8];
+    simdgroup_float8x8 mc[8];
     for (int i = 0; i < 8; i++){
-        c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
+        mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
     }
     short il = (tiitg % THREAD_PER_ROW);
 
@@ -6530,11 +5626,14 @@ void kernel_mul_mm_id_impl(
         threadgroup half  * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
         threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
 
+        #pragma unroll(BLOCK_SIZE_K/8)
         for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
+            #pragma unroll(4)
             for (int i = 0; i < 4; i++) {
                 simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i);
             }
             simdgroup_barrier(mem_flags::mem_none);
+            #pragma unroll(2)
             for (int i = 0; i < 2; i++) {
                 simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i);
             }
@@ -6542,29 +5641,42 @@ void kernel_mul_mm_id_impl(
             lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
             lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
 
+            #pragma unroll(8)
             for (int i = 0; i < 8; i++){
-                simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
+                simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
             }
         }
     }
 
     {
         threadgroup_barrier(mem_flags::mem_threadgroup);
-        threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
+        threadgroup float * temp_str = ((threadgroup float *) shmem) \
                                       + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
         for (int i = 0; i < 8; i++) {
-            simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
+            simdgroup_store(mc[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
         }
 
         threadgroup_barrier(mem_flags::mem_threadgroup);
 
-        device float * C = dst + (BLOCK_SIZE_M * r0);
         if (sgitg == 0) {
             for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
                 threadgroup const auto & jid = rowids[r1 * BLOCK_SIZE_N + j];
-                int joff =  jid[0] * ne0 + jid[1] * ne0ne1;
-                for (int i = 0; i < n_rows; i++) {
-                    *(C + i + joff) = *(temp_str + i + j * BLOCK_SIZE_M);
+                int64_t joff = jid[0]*ne0 + jid[1]*ne0ne1;
+
+                device float  * D  = (device float  *) dst + (r0*BLOCK_SIZE_M) + joff;
+                device float4 * D4 = (device float4 *) D;
+
+                threadgroup float  * C  = temp_str + (j*BLOCK_SIZE_M);
+                threadgroup float4 * C4 = (threadgroup float4 *) C;
+
+                int i = 0;
+                for (; i < n_rows/4; i++) {
+                    *(D4 + i) = *(C4 + i);
+                }
+
+                i *= 4;
+                for (; i < n_rows; i++) {
+                    *(D + i) = *(C + i);
                 }
             }
         }
@@ -6573,48 +5685,34 @@ void kernel_mul_mm_id_impl(
 
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
 kernel void kernel_mul_mm_id(
-        device const   uchar * src0s,
-        device const   uchar * src1,
-        device         float * dst,
-        device const   uchar * ids,
-        constant     int64_t & nei0,
-        constant     int64_t & nei1,
-        constant    uint64_t & nbi1,
-        constant     int64_t & ne00,
-        constant     int64_t & ne02,
-        constant    uint64_t & nb01,
-        constant    uint64_t & nb02,
-        constant     int64_t & ne11,
-        constant     int64_t & ne12,
-        constant     int64_t & ne13,
-        constant    uint64_t & nb10,
-        constant    uint64_t & nb11,
-        constant    uint64_t & nb12,
-        constant     int64_t & ne0,
-        constant     int64_t & ne1,
-        constant    uint64_t & nb1,
-        threadgroup    uchar * shared_memory [[threadgroup(0)]],
-        uint3                  tgpig[[threadgroup_position_in_grid]],
-        uint                   tiitg[[thread_index_in_threadgroup]],
-        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
+        constant ggml_metal_kargs_mul_mm_id & args,
+        device const char * src0s,
+        device const char * src1,
+        device       char * dst,
+        device const char * ids,
+        threadgroup  char * shmem [[threadgroup(0)]],
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiitg[[thread_index_in_threadgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
 
     const int32_t i02 = tgpig.z;
+
     tgpig.z = 0;
 
-    device const uchar * src0 = src0s + i02*nb02;
+    device const char * src0 = src0s + i02*args.nb02;
 
     // row indices
-    threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192);
+    threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shmem + 8192);
 
     // TODO: parallelize this loop
-    int64_t _ne1 = 0;
-    for (ushort ii1 = 0; ii1 < nei1; ii1++) {
-        for (ushort ii0 = 0; ii0 < nei0; ii0++) {
-            int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0];
+    int32_t _ne1 = 0;
+    for (ushort ii1 = 0; ii1 < args.nei1; ii1++) {
+        for (ushort ii0 = 0; ii0 < args.nei0; ii0++) {
+            int32_t id = ((device int32_t *) (ids + ii1*args.nbi1))[ii0];
             if (id == i02) {
-                //if (tiitg == 0) {
+                if (tiitg == 0) {
                     rowids[_ne1] = ushort2(ii0, ii1);
-                //}
+                }
                 _ne1++;
             }
         }
@@ -6623,23 +5721,23 @@ kernel void kernel_mul_mm_id(
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
     kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
+        args.ne00,
+        args.ne02,
+        args.nb01,
+        args.nb02,
+        args.ne11,
+        args.ne12,
+        args.nb10,
+        args.nb11,
+        args.nb12,
+        args.ne0,
+        _ne1,
+        (int64_t)args.ne0*args.ne1,
         src0,
         src1,
         rowids,
         dst,
-        ne00,
-        ne02,
-        nb01,
-        nb02,
-        ne11,
-        ne12,
-        nb10,
-        nb11,
-        nb12,
-        ne0,
-        _ne1,
-        ne0*ne1,
-        shared_memory,
+        shmem,
         tgpig,
         tiitg,
         sgitg);
@@ -6748,194 +5846,110 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]]  kernel mat_mm_id_t kernel
 //
 
 typedef void (kernel_mul_mv_impl_t)(
-        device const  char * src0,
-        device const  char * src1,
-        device       float * dst,
-                   int64_t   ne00,
-                   int64_t   ne01,
-                   int64_t   ne02,
-                  uint64_t   nb00,
-                  uint64_t   nb01,
-                  uint64_t   nb02,
-                  uint64_t   nb03,
-                   int64_t   ne10,
-                   int64_t   ne11,
-                   int64_t   ne12,
-                  uint64_t   nb10,
-                  uint64_t   nb11,
-                  uint64_t   nb12,
-                  uint64_t   nb13,
-                   int64_t   ne0,
-                   int64_t   ne1,
-                   uint      r2,
-                   uint      r3,
-                   uint3     tgpig,
-                   uint      tiisg);
+        ggml_metal_kargs_mul_mv args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        uint3  tgpig,
+        ushort tiisg);
 
 typedef void (kernel_mul_mv2_impl_t)(
-        device const  void * src0,
-        device const float * src1,
-        device       float * dst,
-                   int64_t   ne00,
-                   int64_t   ne01,
-                   int64_t   ne02,
-                  uint64_t   nb01,
-                  uint64_t   nb02,
-                  uint64_t   nb03,
-                   int64_t   ne10,
-                   int64_t   ne12,
-                  uint64_t   nb11,
-                  uint64_t   nb12,
-                  uint64_t   nb13,
-                   int64_t   ne0,
-                   int64_t   ne1,
-                   uint      r2,
-                   uint      r3,
-        threadgroup int8_t * shared_values,
-                   uint3     tgpig,
-                   uint      tiisg,
-                   uint      sgitg);
+        ggml_metal_kargs_mul_mv args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiisg,
+        ushort sgitg);
 
 template<kernel_mul_mv_impl_t impl_fn>
 void mmv_fn(
-        device const    char * src0,
-        device const    char * src1,
-        device         float * dst,
-                     int64_t   ne00,
-                     int64_t   ne01,
-                     int64_t   ne02,
-                    uint64_t   nb00,
-                    uint64_t   nb01,
-                    uint64_t   nb02,
-                    uint64_t   nb03,
-                     int64_t   ne10,
-                     int64_t   ne11,
-                     int64_t   ne12,
-                     int64_t   ne13,
-                    uint64_t   nb10,
-                    uint64_t   nb11,
-                    uint64_t   nb12,
-                    uint64_t   nb13,
-                     int64_t   ne0,
-                     int64_t   ne1,
-                    uint64_t   nb1,
-                        uint   r2,
-                        uint   r3,
-        threadgroup int8_t   * shared_values,
-        uint3                  tgpig,
-        uint                   tiitg,
-        uint                   tiisg,
-        uint                   sgitg) {
-    impl_fn(src0,src1,dst,ne00,ne01,ne02,nb00,nb01,nb02,nb03,ne10,ne11,ne12,nb10,nb11,nb12,nb13,ne0,ne1,r2,r3,tgpig,tiisg);
+        ggml_metal_kargs_mul_mv args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiitg,
+        ushort tiisg,
+        ushort sgitg) {
+    impl_fn(args, src0, src1, dst, tgpig, tiisg);
 }
 
 template<kernel_mul_mv2_impl_t impl_fn>
 void mmv_fn(
-        device const    char * src0,
-        device const    char * src1,
-        device         float * dst,
-                     int64_t   ne00,
-                     int64_t   ne01,
-                     int64_t   ne02,
-                    uint64_t   nb00,
-                    uint64_t   nb01,
-                    uint64_t   nb02,
-                    uint64_t   nb03,
-                     int64_t   ne10,
-                     int64_t   ne11,
-                     int64_t   ne12,
-                     int64_t   ne13,
-                    uint64_t   nb10,
-                    uint64_t   nb11,
-                    uint64_t   nb12,
-                    uint64_t   nb13,
-                     int64_t   ne0,
-                     int64_t   ne1,
-                    uint64_t   nb1,
-                        uint   r2,
-                        uint   r3,
-        threadgroup int8_t   * shared_values,
-        uint3                  tgpig,
-        uint                   tiitg,
-        uint                   tiisg,
-        uint                   sgitg) {
-    impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg);
-}
-
-typedef decltype(mmv_fn<kernel_mul_mv_impl<half, half4, half, half4>>) mul_mv_impl_fn_t;
+        ggml_metal_kargs_mul_mv args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiitg,
+        ushort tiisg,
+        ushort sgitg) {
+    impl_fn(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+}
+
+typedef decltype(mmv_fn<kernel_mul_mv_impl<half, half4, half, half4, ggml_metal_kargs_mul_mv>>) mul_mv_impl_fn_t;
 
 template<mul_mv_impl_fn_t impl_fn>
 kernel void kernel_mul_mv_id(
-        device const    char * src0s,
-        device const    char * src1,
-        device         float * dst,
-        device const    char * ids,
-        constant     int64_t & nei0,
-        constant     int64_t & nei1,
-        constant    uint64_t & nbi1,
-        constant     int64_t & ne00,
-        constant     int64_t & ne01,
-        constant     int64_t & ne02,
-        constant    uint64_t & nb00,
-        constant    uint64_t & nb01,
-        constant    uint64_t & nb02,
-        constant     int64_t & ne10,
-        constant     int64_t & ne11,
-        constant     int64_t & ne12,
-        constant     int64_t & ne13,
-        constant    uint64_t & nb10,
-        constant    uint64_t & nb11,
-        constant    uint64_t & nb12,
-        constant     int64_t & ne0,
-        constant     int64_t & ne1,
-        constant    uint64_t & nb1,
-        threadgroup int8_t   * shared_values [[threadgroup(0)]],
-        uint3                  tgpig[[threadgroup_position_in_grid]],
-        uint                   tiitg[[thread_index_in_threadgroup]],
-        uint                   tiisg[[thread_index_in_simdgroup]],
-        uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    const int iid1 = tgpig.z/nei0;
-    const int idx = tgpig.z%nei0;
+        constant ggml_metal_kargs_mul_mv_id & args,
+        device const char * src0s,
+        device const char * src1,
+        device       char * dst,
+        device const char * ids,
+        threadgroup  char * shmem [[threadgroup(0)]],
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiitg[[thread_index_in_threadgroup]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+    const int iid1 = tgpig.z/args.nei0;
+    const int idx  = tgpig.z%args.nei0;
 
     tgpig.z = 0;
 
-    const int32_t i02 = ((device const int32_t *) (ids + iid1*nbi1))[idx];
+    const int32_t i02 = ((device const int32_t *) (ids + iid1*args.nbi1))[idx];
 
-    const int64_t i11 = idx % ne11;
+    const int64_t i11 = idx % args.ne11;
     const int64_t i12 = iid1;
 
     const int64_t i1 = idx;
     const int64_t i2 = i12;
 
-    device const char * src0_cur = src0s + i02*nb02;
-    device const char * src1_cur = src1  + i11*nb11 + i12*nb12;
-    device      float *  dst_cur = dst   + i1*ne0   + i2*ne1*ne0;
+    device const char * src0_cur = src0s + i02*args.nb02;
+    device const char * src1_cur = src1  + i11*args.nb11 + i12*args.nb12;
+
+    device char * dst_cur = dst + (i1*args.ne0 + i2*args.ne1*args.ne0)*sizeof(float);
+
+    ggml_metal_kargs_mul_mv args0 = {
+        /*.ne00 =*/ args.ne00,
+        /*.ne01 =*/ args.ne01,
+        /*.ne02 =*/ 1, // args.ne02,
+        /*.nb00 =*/ args.nb00,
+        /*.nb01 =*/ args.nb01,
+        /*.nb02 =*/ args.nb02,
+        /*.nb03 =*/ args.nb02, // args.ne02 == 1
+        /*.ne10 =*/ args.ne10,
+        /*.ne11 =*/ 1, // args.ne11,
+        /*.ne12 =*/ 1, // args.ne12,
+        /*.nb10 =*/ args.nb10,
+        /*.nb11 =*/ args.nb11,
+        /*.nb12 =*/ args.nb12,
+        /*.nb13 =*/ args.nb12, // ne12 == 1
+        /*.ne0  =*/ args.ne0,
+        /*.ne1  =*/ 1, // args.ne1,
+        /*.r2   =*/ 1,
+        /*.r3   =*/ 1,
+    };
 
     impl_fn(
+        args0,
         /* src0 */ src0_cur,
         /* src1 */ src1_cur,
         /* dst  */ dst_cur,
-        /* ne00 */ ne00,
-        /* ne01 */ ne01,
-        /* ne02 */ 1, // ne02,
-        /* nb00 */ nb00,
-        /* nb01 */ nb01,
-        /* nb02 */ nb02,
-        /* nb03 */ nb02, // ne02 == 1
-        /* ne10 */ ne10,
-        /* ne11 */ 1, // ne11,
-        /* ne12 */ 1, // ne12,
-        /* ne13 */ 1, // ne13,
-        /* nb10 */ nb10,
-        /* nb11 */ nb11,
-        /* nb12 */ nb12,
-        /* ne13 */ nb12, // ne12 == 1
-        /* ne0  */ ne0,
-        /* ne1  */ 1, // ne1,
-        /* nb1  */ nb1,
-        /* r2   */ 1,
-        /* r3   */ 1,
-        shared_values,
+        shmem,
         tgpig,
         tiitg,
         tiisg,