]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
metal : simplify kernel arguments using a struct (#3229) (#12194)
authorBB-fat <redacted>
Fri, 7 Mar 2025 07:35:57 +0000 (15:35 +0800)
committerGitHub <redacted>
Fri, 7 Mar 2025 07:35:57 +0000 (08:35 +0100)
* metal : refactor im2col parameters into a struct

* metal: Change im2col offset types from int32_t to uint64_t to support larger memory offsets

* metal : refactor sum_rows parameters into a struct

* metal : refactor soft_max parameters into a struct

* metal : refactor diag_mask_inf parameters into a struct

* metal : refactor ssm_conv parameters into a struct

* metal : refactor ssm_scan parameters into a struct

* metal : refactor get_rows parameters into a struct

* metal : refactor group_norm parameters into a struct

* metal : refactor conv_transpose_1d parameters into a struct

* metal : refactor upscale parameters into a struct

* metal : refactor pad parameters into a struct

* metal : refactor pad_reflect_1d parameters into a struct

* metal : refactor arange parameters into a struct

* metal : refactor timestep_embedding parameters into a struct

* metal : refactor argsort parameters into a struct

* metal : refactor leaky_relu parameters into a struct

* metal : refactor pool_2d parameters into a struct

* metal : fix trailing whitespace

---------

Co-authored-by: alexju <redacted>
ggml/src/ggml-metal/ggml-metal-impl.h
ggml/src/ggml-metal/ggml-metal.m
ggml/src/ggml-metal/ggml-metal.metal

index e3dc25f1686fbf352c2b72a4e84aadaae1a2d987..a58c474eb007e4ccea031a40a003808e19570d97 100644 (file)
@@ -285,4 +285,239 @@ typedef struct {
     float    eps;
 } ggml_metal_kargs_rms_norm;
 
+typedef struct {
+    int64_t  ne00;
+    int64_t  ne01;
+    int64_t  ne02;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    int32_t  n_groups;
+    float    eps;
+} ggml_metal_kargs_group_norm;
+
+typedef struct {
+    int32_t  IC;
+    int32_t  IL;
+    int32_t  K;
+    int32_t  s0;
+    uint64_t nb0;
+    uint64_t nb1;
+} ggml_metal_kargs_conv_transpose_1d;
+
+typedef struct {
+    uint64_t  ofs0;
+    uint64_t  ofs1;
+    int32_t  IW;
+    int32_t  IH;
+    int32_t  CHW;
+    int32_t  s0;
+    int32_t  s1;
+    int32_t  p0;
+    int32_t  p1;
+    int32_t  d0;
+    int32_t  d1;
+    int32_t  N;
+    int32_t  KH;
+    int32_t  KW;
+    int32_t  KHW; // KH * KW, pre-computed on CPU to save GPU resources
+} ggml_metal_kargs_im2col;
+
+typedef struct {
+    int64_t  ne00;
+    int64_t  ne01;
+    int64_t  ne02;
+    int64_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int64_t  ne10;
+    int64_t  ne11;
+    int64_t  ne12;
+    int64_t  ne13;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    int64_t  ne0;
+    int64_t  ne1;
+    int64_t  ne2;
+    int64_t  ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+} ggml_metal_kargs_sum_rows;
+
+typedef struct {
+    int64_t  ne00;
+    int64_t  ne01;
+    int64_t  ne02;
+    float    scale;
+    float    max_bias;
+    float    m0;
+    float    m1;
+    uint32_t n_head_log2;
+} ggml_metal_kargs_soft_max;
+
+typedef struct {
+    int64_t  ne00;
+    int64_t  ne01;
+    int      n_past;
+} ggml_metal_kargs_diag_mask_inf;
+
+typedef struct {
+    int64_t  ne00;
+    int64_t  ne01;
+    int64_t  ne02;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    int64_t  ne10;
+    int64_t  ne11;
+    uint64_t nb10;
+    uint64_t nb11;
+    int64_t  ne0;
+    int64_t  ne1;
+    int64_t  ne2;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+} ggml_metal_kargs_ssm_conv;
+
+typedef struct {
+    int64_t  d_state;
+    int64_t  d_inner;
+    int64_t  n_seq_tokens;
+    int64_t  n_seqs;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    uint64_t nb20;
+    uint64_t nb21;
+    uint64_t nb22;
+    uint64_t nb30;
+    uint64_t nb31;
+    uint64_t nb40;
+    uint64_t nb41;
+    uint64_t nb42;
+    uint64_t nb50;
+    uint64_t nb51;
+    uint64_t nb52;
+} ggml_metal_kargs_ssm_scan;
+
+typedef struct {
+    int64_t  ne00;
+    uint64_t nb01;
+    uint64_t nb02;
+    int64_t  ne10;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb1;
+    uint64_t nb2;
+} ggml_metal_kargs_get_rows;
+
+typedef struct {
+    int64_t  ne00;
+    int64_t  ne01;
+    int64_t  ne02;
+    int64_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int64_t  ne0;
+    int64_t  ne1;
+    int64_t  ne2;
+    int64_t  ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+    float    sf0;
+    float    sf1;
+    float    sf2;
+    float    sf3;
+} ggml_metal_kargs_upscale;
+
+typedef struct {
+    int64_t  ne00;
+    int64_t  ne01;
+    int64_t  ne02;
+    int64_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int64_t  ne0;
+    int64_t  ne1;
+    int64_t  ne2;
+    int64_t  ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+} ggml_metal_kargs_pad;
+
+typedef struct {
+    int64_t  ne00;
+    int64_t  ne01;
+    int64_t  ne02;
+    int64_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int64_t  ne0;
+    int64_t  ne1;
+    int64_t  ne2;
+    int64_t  ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+    int32_t  p0;
+    int32_t  p1;
+} ggml_metal_kargs_pad_reflect_1d;
+
+typedef struct {
+    uint64_t nb1;
+    int      dim;
+    int      max_period;
+} ggml_metal_kargs_timestep_embedding;
+
+typedef struct {
+    float    slope;
+} ggml_metal_kargs_leaky_relu;
+
+typedef struct {
+    int64_t  ncols;
+    int64_t  ncols_pad;
+} ggml_metal_kargs_argsort;
+
+typedef struct {
+    int64_t  ne0;
+    float    start;
+    float    step;
+} ggml_metal_kargs_arange;
+
+typedef struct {
+    int32_t  k0;
+    int32_t  k1;
+    int32_t  s0;
+    int32_t  s1;
+    int32_t  p0;
+    int32_t  p1;
+    int64_t  IH;
+    int64_t  IW;
+    int64_t  OH;
+    int64_t  OW;
+    int64_t  parallel_elements;
+} ggml_metal_kargs_pool_2d;
+
 #endif // GGML_METAL_IMPL
index 1f45ebad146df57ea77f7e93c6824c0fb039df7f..1158b285c19bca554d7eadf7d6d7f3c5bfc8d151 100644 (file)
@@ -1945,34 +1945,38 @@ static void ggml_metal_encode_node(
 
                 id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
 
-                // TODO: add ggml_metal_kargs struct
+
+                ggml_metal_kargs_sum_rows args = {
+                   /*.ne00 =*/ ne00,
+                   /*.ne01 =*/ ne01,
+                   /*.ne02 =*/ ne02,
+                   /*.ne03 =*/ ne03,
+                   /*.nb00 =*/ nb00,
+                   /*.nb01 =*/ nb01,
+                   /*.nb02 =*/ nb02,
+                   /*.nb03 =*/ nb03,
+                   /*.ne10 =*/ ne10,
+                   /*.ne11 =*/ ne11,
+                   /*.ne12 =*/ ne12,
+                   /*.ne13 =*/ ne13,
+                   /*.nb10 =*/ nb10,
+                   /*.nb11 =*/ nb11,
+                   /*.nb12 =*/ nb12,
+                   /*.nb13 =*/ nb13,
+                   /*.ne0  =*/ ne0,
+                   /*.ne1  =*/ ne1,
+                   /*.ne2  =*/ ne2,
+                   /*.ne3  =*/ ne3,
+                   /*.nb0  =*/ nb0,
+                   /*.nb1  =*/ nb1,
+                   /*.nb2  =*/ nb2,
+                   /*.nb3  =*/ nb3,
+                };
+
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
-                [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
-                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
-                [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
-                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
-                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
-                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
-                [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
-                [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
-                [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
-                [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
-                [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
-                [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
-                [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
-                [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
-                [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
-                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:18];
-                [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:19];
-                [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:20];
-                [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:21];
-                [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:22];
-                [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:23];
-                [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:24];
-                [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:25];
+                [encoder setBytes:&args length:sizeof(args) atIndex:2];
 
                 [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
             } break;
@@ -2021,8 +2025,17 @@ static void ggml_metal_encode_node(
                 const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
                 const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
-                // TODO: add ggml_metal_kargs struct
-                // TODO: optimize (see https://github.com/ggml-org/llama.cpp/pull/10238/commits/7941b6b9ec29a2866fec6fa6c51612515ca509f6)
+                ggml_metal_kargs_soft_max args = {
+                    /*.ne00 =*/ ne00,
+                    /*.ne01 =*/ ne01,
+                    /*.ne02 =*/ ne02,
+                    /*.scale =*/ scale,
+                    /*.max_bias =*/ max_bias,
+                    /*.m0 =*/ m0,
+                    /*.m1 =*/ m1,
+                    /*.n_head_log2 =*/ n_head_log2,
+                };
+
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0 offset:offs_src0   atIndex:0];
                 if (id_src1) {
@@ -2031,14 +2044,7 @@ static void ggml_metal_encode_node(
                     [encoder setBuffer:id_src0 offset:offs_src0   atIndex:1];
                 }
                 [encoder setBuffer:id_dst      offset:offs_dst            atIndex:2];
-                [encoder setBytes:&ne00        length:sizeof(ne00)        atIndex:3];
-                [encoder setBytes:&ne01        length:sizeof(ne01)        atIndex:4];
-                [encoder setBytes:&ne02        length:sizeof(ne02)        atIndex:5];
-                [encoder setBytes:&scale       length:sizeof(scale)       atIndex:6];
-                [encoder setBytes:&max_bias    length:sizeof(max_bias)    atIndex:7];
-                [encoder setBytes:&m0          length:sizeof(m0)          atIndex:8];
-                [encoder setBytes:&m1          length:sizeof(m1)          atIndex:9];
-                [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10];
+                [encoder setBytes:&args        length:sizeof(args)        atIndex:3];
 
                 [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
 
@@ -2056,13 +2062,16 @@ static void ggml_metal_encode_node(
                     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline;
                 }
 
-                // TODO: add ggml_metal_kargs struct
+                ggml_metal_kargs_diag_mask_inf args = {
+                    /*.ne00 =*/ ne00,
+                    /*.ne01 =*/ ne01,
+                    /*.n_past =*/ n_past,
+                };
+
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                [encoder setBytes:&ne00   length:sizeof(ne00) atIndex:2];
-                [encoder setBytes:&ne01   length:sizeof(ne01) atIndex:3];
-                [encoder setBytes:&n_past length:sizeof(int)  atIndex:4];
+                [encoder setBytes:&args  length:sizeof(args) atIndex:2];
 
                 if (ne00%8 == 0) {
                     [encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
@@ -2081,27 +2090,30 @@ static void ggml_metal_encode_node(
 
                 id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline;
 
-                // TODO: add ggml_metal_kargs struct
+                ggml_metal_kargs_ssm_conv args = {
+                    /*.ne00 =*/ ne00,
+                    /*.ne01 =*/ ne01,
+                    /*.ne02 =*/ ne02,
+                    /*.nb00 =*/ nb00,
+                    /*.nb01 =*/ nb01,
+                    /*.nb02 =*/ nb02,
+                    /*.ne10 =*/ ne10,
+                    /*.ne11 =*/ ne11,
+                    /*.nb10 =*/ nb10,
+                    /*.nb11 =*/ nb11,
+                    /*.ne0  =*/ ne0,
+                    /*.ne1  =*/ ne1,
+                    /*.ne2  =*/ ne2,
+                    /*.nb0  =*/ nb0,
+                    /*.nb1  =*/ nb1,
+                    /*.nb2  =*/ nb2,
+                };
+
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0 offset:offs_src0    atIndex:0];
                 [encoder setBuffer:id_src1 offset:offs_src1    atIndex:1];
                 [encoder setBuffer:id_dst  offset:offs_dst     atIndex:2];
-                [encoder setBytes:&ne00    length:sizeof(ne00) atIndex:3];
-                [encoder setBytes:&ne01    length:sizeof(ne01) atIndex:4];
-                [encoder setBytes:&ne02    length:sizeof(ne02) atIndex:5];
-                [encoder setBytes:&nb00    length:sizeof(nb00) atIndex:6];
-                [encoder setBytes:&nb01    length:sizeof(nb01) atIndex:7];
-                [encoder setBytes:&nb02    length:sizeof(nb02) atIndex:8];
-                [encoder setBytes:&ne10    length:sizeof(ne10) atIndex:9];
-                [encoder setBytes:&ne11    length:sizeof(ne11) atIndex:10];
-                [encoder setBytes:&nb10    length:sizeof(nb10) atIndex:11];
-                [encoder setBytes:&nb11    length:sizeof(nb11) atIndex:12];
-                [encoder setBytes:&ne0     length:sizeof(ne0)  atIndex:13];
-                [encoder setBytes:&ne1     length:sizeof(ne1)  atIndex:14];
-                [encoder setBytes:&ne2     length:sizeof(ne2)  atIndex:15];
-                [encoder setBytes:&nb0     length:sizeof(nb0)  atIndex:16];
-                [encoder setBytes:&nb1     length:sizeof(nb1)  atIndex:17];
-                [encoder setBytes:&nb2     length:sizeof(nb2)  atIndex:18];
+                [encoder setBytes:&args    length:sizeof(args) atIndex:3];
 
                 [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne1, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
             } break;
@@ -2152,7 +2164,31 @@ static void ggml_metal_encode_node(
 
                 id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
 
-                // TODO: add ggml_metal_kargs struct
+                ggml_metal_kargs_ssm_scan args = {
+                    /*.d_state =*/ d_state,
+                    /*.d_inner =*/ d_inner,
+                    /*.n_seq_tokens =*/ n_seq_tokens,
+                    /*.n_seqs =*/ n_seqs,
+                    /*.nb00 =*/ nb00,
+                    /*.nb01 =*/ nb01,
+                    /*.nb02 =*/ nb02,
+                    /*.nb10 =*/ nb10,
+                    /*.nb11 =*/ nb11,
+                    /*.nb12 =*/ nb12,
+                    /*.nb13 =*/ nb13,
+                    /*.nb20 =*/ nb20,
+                    /*.nb21 =*/ nb21,
+                    /*.nb22 =*/ nb22,
+                    /*.nb30 =*/ nb30,
+                    /*.nb31 =*/ nb31,
+                    /*.nb40 =*/ nb40,
+                    /*.nb41 =*/ nb41,
+                    /*.nb42 =*/ nb42,
+                    /*.nb50 =*/ nb50,
+                    /*.nb51 =*/ nb51,
+                    /*.nb52 =*/ nb52,
+                };
+
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                 [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
@@ -2161,30 +2197,7 @@ static void ggml_metal_encode_node(
                 [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
                 [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:6];
-
-                [encoder setBytes:&d_state      length:sizeof(d_state)      atIndex:7];
-                [encoder setBytes:&d_inner      length:sizeof(d_inner)      atIndex:8];
-                [encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:9];
-                [encoder setBytes:&n_seqs       length:sizeof(n_seqs)       atIndex:10];
-
-                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:11];
-                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:12];
-                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:13];
-                [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
-                [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
-                [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
-                [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
-                [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:18];
-                [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:19];
-                [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:20];
-                [encoder setBytes:&nb30 length:sizeof(nb30) atIndex:21];
-                [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:22];
-                [encoder setBytes:&nb40 length:sizeof(nb40) atIndex:23];
-                [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:24];
-                [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:25];
-                [encoder setBytes:&nb50 length:sizeof(nb50) atIndex:26];
-                [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:27];
-                [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:28];
+                [encoder setBytes:&args    length:sizeof(args) atIndex:7];
 
                 [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
             } break;
@@ -3041,19 +3054,22 @@ static void ggml_metal_encode_node(
                     default: GGML_ABORT("not implemented");
                 }
 
-                // TODO: add ggml_metal_kargs struct
+                ggml_metal_kargs_get_rows args = {
+                    /*.ne00 =*/ ne00,
+                    /*.nb01 =*/ nb01,
+                    /*.nb02 =*/ nb02,
+                    /*.ne10 =*/ ne10,
+                    /*.nb10 =*/ nb10,
+                    /*.nb11 =*/ nb11,
+                    /*.nb1 =*/ nb1,
+                    /*.nb2 =*/ nb2,
+                };
+
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0     offset:offs_src0 atIndex:0];
                 [encoder setBuffer:id_src1     offset:offs_src1 atIndex:1];
                 [encoder setBuffer:id_dst      offset:offs_dst  atIndex:2];
-                [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
-                [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
-                [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
-                [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
-                [encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
-                [encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
-                [encoder setBytes:&nb1  length:sizeof(uint64_t) atIndex:9];
-                [encoder setBytes:&nb2  length:sizeof(uint64_t) atIndex:10];
+                [encoder setBytes:&args length:sizeof(args) atIndex:3];
 
                 [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
             } break;
@@ -3110,18 +3126,21 @@ static void ggml_metal_encode_node(
 
                 id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline;
 
-                // TODO: add ggml_metal_kargs struct
+                ggml_metal_kargs_group_norm args = {
+                    /*.ne00 =*/ ne00,
+                    /*.ne01 =*/ ne01,
+                    /*.ne02 =*/ ne02,
+                    /*.nb00 =*/ nb00,
+                    /*.nb01 =*/ nb01,
+                    /*.nb02 =*/ nb02,
+                    /*.n_groups =*/ n_groups,
+                    /*.eps =*/ eps,
+                };
+
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0  offset:offs_src0        atIndex:0];
                 [encoder setBuffer:id_dst   offset:offs_dst         atIndex:1];
-                [encoder setBytes:&ne00     length:sizeof( int64_t) atIndex:2];
-                [encoder setBytes:&ne01     length:sizeof( int64_t) atIndex:3];
-                [encoder setBytes:&ne02     length:sizeof( int64_t) atIndex:4];
-                [encoder setBytes:&nb00     length:sizeof(uint64_t) atIndex:5];
-                [encoder setBytes:&nb01     length:sizeof(uint64_t) atIndex:6];
-                [encoder setBytes:&nb02     length:sizeof(uint64_t) atIndex:7];
-                [encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8];
-                [encoder setBytes:&eps      length:sizeof(   float) atIndex:9];
+                [encoder setBytes:&args     length:sizeof(args)     atIndex:2];
                 [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
 
                 [encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
@@ -3279,8 +3298,8 @@ static void ggml_metal_encode_node(
 
                 const int32_t CHW = IC * KH * KW;
 
-                const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
-                const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
+                const uint64_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
+                const uint64_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
 
                 id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline;
 
@@ -3302,27 +3321,30 @@ static void ggml_metal_encode_node(
                     default: GGML_ABORT("fatal error");
                 };
 
-                // TODO: add ggml_metal_kargs struct
+                ggml_metal_kargs_im2col args = {
+                    /*.ofs0 =*/ ofs0,
+                    /*.ofs1 =*/ ofs1,
+                    /*.IW   =*/ IW,
+                    /*.IH   =*/ IH,
+                    /*.CHW  =*/ CHW,
+                    /*.s0   =*/ s0,
+                    /*.s1   =*/ s1,
+                    /*.p0   =*/ p0,
+                    /*.p1   =*/ p1,
+                    /*.d0   =*/ d0,
+                    /*.d1   =*/ d1,
+                    /*.N    =*/ N,
+                    /*.KH   =*/ KH,
+                    /*.KW   =*/ KW,
+                    /*.KHW  =*/ KH * KW,
+                };
+
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src1 offset:offs_src1       atIndex:0];
                 [encoder setBuffer:id_dst  offset:offs_dst        atIndex:1];
-                [encoder setBytes:&ofs0    length:sizeof(int32_t) atIndex:2];
-                [encoder setBytes:&ofs1    length:sizeof(int32_t) atIndex:3];
-                [encoder setBytes:&IW      length:sizeof(int32_t) atIndex:4];
-                [encoder setBytes:&IH      length:sizeof(int32_t) atIndex:5];
-                [encoder setBytes:&CHW     length:sizeof(int32_t) atIndex:6];
-                [encoder setBytes:&s0      length:sizeof(int32_t) atIndex:7];
-                [encoder setBytes:&s1      length:sizeof(int32_t) atIndex:8];
-                [encoder setBytes:&p0      length:sizeof(int32_t) atIndex:9];
-                [encoder setBytes:&p1      length:sizeof(int32_t) atIndex:10];
-                [encoder setBytes:&d0      length:sizeof(int32_t) atIndex:11];
-                [encoder setBytes:&d1      length:sizeof(int32_t) atIndex:12];
+                [encoder setBytes:&args length:sizeof(args)       atIndex:2];
 
                 if (is_gt_mttpt) {
-                    [encoder setBytes:&N   length:sizeof(int32_t) atIndex:13];
-                    [encoder setBytes:&KH  length:sizeof(int32_t) atIndex:14];
-                    [encoder setBytes:&KW  length:sizeof(int32_t) atIndex:15];
-
                     const uint64_t n_threads = MIN(pipeline.maxTotalThreadsPerThreadgroup, (uint64_t)N);
 
                     const int64_t  quotient  = N / n_threads + (N % n_threads > 0 ? 1 : 0);
@@ -3362,16 +3384,20 @@ static void ggml_metal_encode_node(
                     default: GGML_ABORT("fatal error");
                 };
 
+                ggml_metal_kargs_conv_transpose_1d args = {
+                    /*.IC =*/ IC,
+                    /*.IL =*/ IL,
+                    /*.K  =*/ K,
+                    /*.s0 =*/ s0,
+                    /*.nb0 =*/ nb0,
+                    /*.nb1 =*/ nb1,
+                };
+
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0 offset:offs_src0         atIndex:0];
                 [encoder setBuffer:id_src1 offset:offs_src1         atIndex:1];
                 [encoder setBuffer:id_dst  offset:offs_dst          atIndex:2];
-                [encoder setBytes:&IC      length:sizeof( int32_t)  atIndex:3];
-                [encoder setBytes:&IL      length:sizeof( int32_t)  atIndex:4];
-                [encoder setBytes:&K       length:sizeof( int32_t)  atIndex:5];
-                [encoder setBytes:&s0      length:sizeof( int32_t)  atIndex:6];
-                [encoder setBytes:&nb0     length:sizeof(uint64_t)  atIndex:7];
-                [encoder setBytes:&nb1     length:sizeof(uint64_t)  atIndex:8];
+                [encoder setBytes:&args    length:sizeof(args)       atIndex:3];
 
                 [encoder dispatchThreadgroups:MTLSizeMake(OL, OC, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
             } break;
@@ -3386,30 +3412,33 @@ static void ggml_metal_encode_node(
 
                 const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline;
 
-                // TODO: add ggml_metal_kargs struct
+                ggml_metal_kargs_upscale args = {
+                    /*.ne00 =*/ ne00,
+                    /*.ne01 =*/ ne01,
+                    /*.ne02 =*/ ne02,
+                    /*.ne03 =*/ ne03,
+                    /*.nb00 =*/ nb00,
+                    /*.nb01 =*/ nb01,
+                    /*.nb02 =*/ nb02,
+                    /*.nb03 =*/ nb03,
+                    /*.ne0 =*/ ne0,
+                    /*.ne1 =*/ ne1,
+                    /*.ne2 =*/ ne2,
+                    /*.ne3 =*/ ne3,
+                    /*.nb0 =*/ nb0,
+                    /*.nb1 =*/ nb1,
+                    /*.nb2 =*/ nb2,
+                    /*.nb3 =*/ nb3,
+                    /*.sf0 =*/ sf0,
+                    /*.sf1 =*/ sf1,
+                    /*.sf2 =*/ sf2,
+                    /*.sf3 =*/ sf3
+                };
+
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
-                [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
-                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
-                [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
-                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
-                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
-                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
-                [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
-                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:10];
-                [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:11];
-                [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:12];
-                [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:13];
-                [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:14];
-                [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:15];
-                [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:16];
-                [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:17];
-                [encoder setBytes:&sf0  length:sizeof(sf0)  atIndex:18];
-                [encoder setBytes:&sf1  length:sizeof(sf1)  atIndex:19];
-                [encoder setBytes:&sf2  length:sizeof(sf2)  atIndex:20];
-                [encoder setBytes:&sf3  length:sizeof(sf3)  atIndex:21];
+                [encoder setBytes:&args length:sizeof(args) atIndex:2];
 
                 const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
 
@@ -3421,26 +3450,29 @@ static void ggml_metal_encode_node(
 
                 id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline;
 
-                // TODO: add ggml_metal_kargs struct
+                ggml_metal_kargs_pad args = {
+                    /*.ne00 =*/ ne00,
+                    /*.ne01 =*/ ne01,
+                    /*.ne02 =*/ ne02,
+                    /*.ne03 =*/ ne03,
+                    /*.nb00 =*/ nb00,
+                    /*.nb01 =*/ nb01,
+                    /*.nb02 =*/ nb02,
+                    /*.nb03 =*/ nb03,
+                    /*.ne0 =*/ ne0,
+                    /*.ne1 =*/ ne1,
+                    /*.ne2 =*/ ne2,
+                    /*.ne3 =*/ ne3,
+                    /*.nb0 =*/ nb0,
+                    /*.nb1 =*/ nb1,
+                    /*.nb2 =*/ nb2,
+                    /*.nb3 =*/ nb3
+                };
+
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
-                [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
-                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
-                [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
-                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
-                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
-                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
-                [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
-                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:10];
-                [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:11];
-                [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:12];
-                [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:13];
-                [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:14];
-                [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:15];
-                [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:16];
-                [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:17];
+                [encoder setBytes:&args length:sizeof(args) atIndex:2];
 
                 const int nth = MIN(1024, ne0);
 
@@ -3455,24 +3487,31 @@ static void ggml_metal_encode_node(
 
                 id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_REFLECT_1D_F32].pipeline;
 
+                ggml_metal_kargs_pad_reflect_1d args = {
+                    /*.ne00 =*/ ne00,
+                    /*.ne01 =*/ ne01,
+                    /*.ne02 =*/ ne02,
+                    /*.ne03 =*/ ne03,
+                    /*.nb00 =*/ nb00,
+                    /*.nb01 =*/ nb01,
+                    /*.nb02 =*/ nb02,
+                    /*.nb03 =*/ nb03,
+                    /*.ne0 =*/ ne0,
+                    /*.ne1 =*/ ne1,
+                    /*.ne2 =*/ ne2,
+                    /*.ne3 =*/ ne3,
+                    /*.nb0 =*/ nb0,
+                    /*.nb1 =*/ nb1,
+                    /*.nb2 =*/ nb2,
+                    /*.nb3 =*/ nb3,
+                    /*.p0 =*/ p0,
+                    /*.p1 =*/ p1
+                };
+
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
-                [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
-                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
-                [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
-                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:6];
-                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
-                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
-                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
-                [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
-                [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:11];
-                [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:12];
-                [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:13];
-                [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:14];
-                [encoder setBytes:&p0   length:sizeof(p0)   atIndex:15];
-                [encoder setBytes:&p1   length:sizeof(p1)   atIndex:16];
+                [encoder setBytes:&args length:sizeof(args) atIndex:2];
 
                 const int nth = MIN(1024, ne0);
 
@@ -3490,12 +3529,15 @@ static void ggml_metal_encode_node(
 
                 id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline;
 
-                // TODO: add ggml_metal_kargs struct
+                ggml_metal_kargs_arange args = {
+                    /*.ne0 =*/ ne0,
+                    /*.start =*/ start,
+                    /*.step =*/ step
+                };
+
                 [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_dst  offset:offs_dst    atIndex:0];
-                [encoder setBytes:&ne0   length:sizeof(ne0)   atIndex:1];
-                [encoder setBytes:&start length:sizeof(start) atIndex:2];
-                [encoder setBytes:&step  length:sizeof(step)  atIndex:3];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:0];
+                [encoder setBytes:&args length:sizeof(args) atIndex:1];
 
                 const int nth = MIN(1024, ne0);
 
@@ -3512,13 +3554,16 @@ static void ggml_metal_encode_node(
 
                 id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline;
 
-                // TODO: add ggml_metal_kargs struct
+                ggml_metal_kargs_timestep_embedding args = {
+                    /*.nb1 =*/ nb1,
+                    /*.dim =*/ dim,
+                    /*.max_period =*/ max_period
+                };
+
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                [encoder setBytes:&nb1   length:sizeof(nb1) atIndex:2];
-                [encoder setBytes:&dim   length:sizeof(dim) atIndex:3];
-                [encoder setBytes:&max_period length:sizeof(max_period) atIndex:4];
+                [encoder setBytes:&args length:sizeof(args) atIndex:2];
 
                 const int nth = MIN(1024, half);
 
@@ -3551,12 +3596,15 @@ static void ggml_metal_encode_node(
                     default: GGML_ABORT("fatal error");
                 };
 
-                // TODO: add ggml_metal_kargs struct
+                ggml_metal_kargs_argsort args = {
+                    /*.ncols =*/ ne00,
+                    /*.ncols_pad =*/ ne00_padded
+                };
+
                 [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0     offset:offs_src0        atIndex:0];
-                [encoder setBuffer:id_dst      offset:offs_dst         atIndex:1];
-                [encoder setBytes:&ne00        length:sizeof( int64_t) atIndex:2];
-                [encoder setBytes:&ne00_padded length:sizeof( int64_t) atIndex:3];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                [encoder setBytes:&args length:sizeof(args) atIndex:2];
                 [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
 
                 [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00_padded, 1, 1)];
@@ -3570,11 +3618,14 @@ static void ggml_metal_encode_node(
 
                 id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline;
 
-                // TODO: add ggml_metal_kargs struct
+                ggml_metal_kargs_leaky_relu args = {
+                    /*.slope =*/ slope
+                };
+
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0 offset:offs_src0   atIndex:0];
                 [encoder setBuffer:id_dst  offset:offs_dst    atIndex:1];
-                [encoder setBytes:&slope length:sizeof(slope) atIndex:2];
+                [encoder setBytes:&args length:sizeof(args)   atIndex:2];
 
                 const int64_t n = ggml_nelements(dst);
 
@@ -4150,21 +4201,24 @@ static void ggml_metal_encode_node(
                 const int64_t n_threads = MIN((int64_t)[pipeline maxTotalThreadsPerThreadgroup], parallel_elements);
                 const int64_t n_tg = (parallel_elements + n_threads - 1) / n_threads;
 
-                // TODO: add ggml_metal_kargs struct
+                ggml_metal_kargs_pool_2d args_pool_2d = {
+                    /* .k0 = */ k0,
+                    /* .k1 = */ k1,
+                    /* .s0 = */ s0,
+                    /* .s1 = */ s1,
+                    /* .p0 = */ p0,
+                    /* .p1 = */ p1,
+                    /* .IH = */ IH,
+                    /* .IW = */ IW,
+                    /* .OH = */ OH,
+                    /* .OW = */ OW,
+                    /* .parallel_elements = */ parallel_elements
+                };
+
                 [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0 offset:offs_src0       atIndex:0];
-                [encoder setBuffer:id_dst  offset:offs_dst        atIndex:1];
-                [encoder setBytes:&k0      length:sizeof(int32_t) atIndex:2];
-                [encoder setBytes:&k1      length:sizeof(int32_t) atIndex:3];
-                [encoder setBytes:&s0      length:sizeof(int32_t) atIndex:4];
-                [encoder setBytes:&s1      length:sizeof(int32_t) atIndex:5];
-                [encoder setBytes:&p0      length:sizeof(int32_t) atIndex:6];
-                [encoder setBytes:&p1      length:sizeof(int32_t) atIndex:7];
-                [encoder setBytes:&IH      length:sizeof(int64_t) atIndex:8];
-                [encoder setBytes:&IW      length:sizeof(int64_t) atIndex:9];
-                [encoder setBytes:&OH      length:sizeof(int64_t) atIndex:10];
-                [encoder setBytes:&OW      length:sizeof(int64_t) atIndex:11];
-                [encoder setBytes:&parallel_elements length:sizeof(int64_t) atIndex:12];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                [encoder setBytes:&args_pool_2d length:sizeof(args_pool_2d) atIndex:2];
 
                 [encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
             } break;
index c46a13050891fd3fbf3ce7015e7d3f1d54f4b8cf..ad9d42a3eaa9e75497b95e03814b95efd7a3f0b7 100644 (file)
@@ -947,45 +947,22 @@ kernel void kernel_cos(
 kernel void kernel_sum_rows(
         device const float * src0,
         device       float * dst,
-        constant  int64_t & ne00,
-        constant  int64_t & ne01,
-        constant  int64_t & ne02,
-        constant  int64_t & ne03,
-        constant uint64_t & nb00,
-        constant uint64_t & nb01,
-        constant uint64_t & nb02,
-        constant uint64_t & nb03,
-        constant  int64_t & ne10,
-        constant  int64_t & ne11,
-        constant  int64_t & ne12,
-        constant  int64_t & ne13,
-        constant uint64_t & nb10,
-        constant uint64_t & nb11,
-        constant uint64_t & nb12,
-        constant uint64_t & nb13,
-        constant  int64_t & ne0,
-        constant  int64_t & ne1,
-        constant  int64_t & ne2,
-        constant  int64_t & ne3,
-        constant uint64_t & nb0,
-        constant uint64_t & nb1,
-        constant uint64_t & nb2,
-        constant uint64_t & nb3,
+        constant ggml_metal_kargs_sum_rows & args,
         uint3 tpig[[thread_position_in_grid]]) {
     int64_t i3 = tpig.z;
     int64_t i2 = tpig.y;
     int64_t i1 = tpig.x;
 
-    if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
+    if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
         return;
     }
 
-    device const float * src_row = (device const float *) ((device const char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
-    device       float * dst_row = (device       float *) ((device       char *) dst  + i1*nb1  + i2*nb2  + i3*nb3);
+    device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03);
+    device       float * dst_row = (device       float *) ((device       char *) dst  + i1*args.nb1  + i2*args.nb2  + i3*args.nb3);
 
     float row_sum = 0;
 
-    for (int64_t i0 = 0; i0 < ne00; i0++) {
+    for (int64_t i0 = 0; i0 < args.ne00; i0++) {
         row_sum += src_row[i0];
     }
 
@@ -997,36 +974,29 @@ kernel void kernel_soft_max(
         device const  char * src0,
         device const  char * src1,
         device        char * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant     float & scale,
-        constant     float & max_bias,
-        constant     float & m0,
-        constant     float & m1,
-        constant  uint32_t & n_head_log2,
+        constant ggml_metal_kargs_soft_max & args,
         threadgroup  float * buf [[threadgroup(0)]],
         uint  tgpig[[threadgroup_position_in_grid]],
         uint  tpitg[[thread_position_in_threadgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint    ntg[[threads_per_threadgroup]]) {
-    const int64_t i03 = (tgpig) / (ne02*ne01);
-    const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
-    const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
+    const int64_t i03 = (tgpig) / (args.ne02*args.ne01);
+    const int64_t i02 = (tgpig - i03*args.ne02*args.ne01) / args.ne01;
+    const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01);
 
-    device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
-    device const     T * pmask = src1 != src0 ? (device const    T *) src1         + i01*ne00 : nullptr;
-    device       float * pdst  = (device       float *) dst  + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+    device const float * psrc0 = (device const float *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00);
+    device const     T * pmask = src1 != src0 ? (device const    T *) src1         + i01*args.ne00 : nullptr;
+    device       float * pdst  = (device       float *) dst  + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00);
 
     float slope = 1.0f;
 
     // ALiBi
-    if (max_bias > 0.0f) {
+    if (args.max_bias > 0.0f) {
         const int64_t h = i02;
 
-        const float base = h < n_head_log2 ? m0 : m1;
-        const int   exp  = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+        const float base = h < args.n_head_log2 ? args.m0 : args.m1;
+        const int   exp  = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
 
         slope = pow(base, exp);
     }
@@ -1034,8 +1004,8 @@ kernel void kernel_soft_max(
     // parallel max
     float lmax = -INFINITY;
 
-    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
-        lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
+    for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
+        lmax = MAX(lmax, psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f));
     }
 
     // find the max value in the block
@@ -1059,8 +1029,8 @@ kernel void kernel_soft_max(
 
     // parallel sum
     float lsum = 0.0f;
-    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
-        const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
+    for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
+        const float exp_psrc0 = exp((psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
         lsum += exp_psrc0;
         pdst[i00] = exp_psrc0;
     }
@@ -1090,7 +1060,7 @@ kernel void kernel_soft_max(
 
     const float inv_sum = 1.0f/sum;
 
-    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+    for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
         pdst[i00] *= inv_sum;
     }
 }
@@ -1100,35 +1070,28 @@ kernel void kernel_soft_max_4(
         device const  char * src0,
         device const  char * src1,
         device        char * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant     float & scale,
-        constant     float & max_bias,
-        constant     float & m0,
-        constant     float & m1,
-        constant  uint32_t & n_head_log2,
+        constant ggml_metal_kargs_soft_max & args,
         threadgroup  float * buf [[threadgroup(0)]],
         uint  tgpig[[threadgroup_position_in_grid]],
         uint  tpitg[[thread_position_in_threadgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint    ntg[[threads_per_threadgroup]]) {
-    const int64_t i03 = (tgpig) / (ne02*ne01);
-    const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
-    const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
+    const int64_t i03 = (tgpig) / (args.ne02*args.ne01);
+    const int64_t i02 = (tgpig - i03*args.ne02*args.ne01) / args.ne01;
+    const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01);
 
-    device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
-    device const      T * pmask = src1 != src0 ? (device const     T *) src1         + i01*ne00/4 : nullptr;
-    device       float4 * pdst4 = (device       float4 *) dst  + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
+    device const float4 * psrc4 = (device const float4 *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4;
+    device const      T * pmask = src1 != src0 ? (device const     T *) src1         + i01*args.ne00/4 : nullptr;
+    device       float4 * pdst4 = (device       float4 *) dst  + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4;
 
     float slope = 1.0f;
 
-    if (max_bias > 0.0f) {
+    if (args.max_bias > 0.0f) {
         const int64_t h = i02;
 
-        const float base = h < n_head_log2 ? m0 : m1;
-        const int   exp  = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+        const float base = h < args.n_head_log2 ? args.m0 : args.m1;
+        const int   exp  = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
 
         slope = pow(base, exp);
     }
@@ -1136,8 +1099,8 @@ kernel void kernel_soft_max_4(
     // parallel max
     float4 lmax4 = -INFINITY;
 
-    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
-        lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
+    for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) {
+        lmax4 = fmax(lmax4, psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
     }
 
     const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
@@ -1162,8 +1125,8 @@ kernel void kernel_soft_max_4(
 
     // parallel sum
     float4 lsum4 = 0.0f;
-    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
-        const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
+    for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) {
+        const float4 exp_psrc4 = exp((psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
         lsum4 += exp_psrc4;
         pdst4[i00] = exp_psrc4;
     }
@@ -1195,7 +1158,7 @@ kernel void kernel_soft_max_4(
 
     const float inv_sum = 1.0f/sum;
 
-    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
+    for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) {
         pdst4[i00] *= inv_sum;
     }
 }
@@ -1211,27 +1174,23 @@ template [[host_name("kernel_soft_max_f32_4")]] kernel kernel_soft_max_4_t kerne
 kernel void kernel_diag_mask_inf(
         device const float * src0,
         device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant       int & n_past,
+        constant ggml_metal_kargs_diag_mask_inf & args,
         uint3 tpig[[thread_position_in_grid]]) {
     const int64_t i02 = tpig[2];
     const int64_t i01 = tpig[1];
     const int64_t i00 = tpig[0];
 
-    if (i00 > n_past + i01) {
-        dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
+    if (i00 > args.n_past + i01) {
+        dst[i02*args.ne01*args.ne00 + i01*args.ne00 + i00] = -INFINITY;
     } else {
-        dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
+        dst[i02*args.ne01*args.ne00 + i01*args.ne00 + i00] = src0[i02*args.ne01*args.ne00 + i01*args.ne00 + i00];
     }
 }
 
 kernel void kernel_diag_mask_inf_8(
         device const float4 * src0,
         device       float4 * dst,
-        constant    int64_t & ne00,
-        constant    int64_t & ne01,
-        constant        int & n_past,
+        constant ggml_metal_kargs_diag_mask_inf & args,
         uint3 tpig[[thread_position_in_grid]]) {
 
     const int64_t i = 2*tpig[0];
@@ -1239,42 +1198,26 @@ kernel void kernel_diag_mask_inf_8(
     dst[i+0] = src0[i+0];
     dst[i+1] = src0[i+1];
     int64_t i4 = 4*i;
-    const int64_t i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01;
-    const int64_t i01 = i4/(ne00);      i4 -= i01*ne00;
+    const int64_t i02 = i4/(args.ne00*args.ne01); i4 -= i02*args.ne00*args.ne01;
+    const int64_t i01 = i4/(args.ne00);      i4 -= i01*args.ne00;
     const int64_t i00 = i4;
     for (int k = 3; k >= 0; --k) {
-        if (i00 + 4 + k <= n_past + i01) {
+        if (i00 + 4 + k <= args.n_past + i01) {
             break;
         }
         dst[i+1][k] = -INFINITY;
-        if (i00 + k > n_past + i01) {
+        if (i00 + k > args.n_past + i01) {
             dst[i][k] = -INFINITY;
         }
     }
 }
 
 // ref: ggml.c:ggml_compute_forward_ssm_conv_f32
-// TODO: optimize
 kernel void kernel_ssm_conv_f32(
         device const  void * src0,
         device const  void * src1,
         device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant   int64_t & ne10,
-        constant   int64_t & ne11,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   int64_t & ne2,
-        constant  uint64_t & nb0,
-        constant  uint64_t & nb1,
-        constant  uint64_t & nb2,
+        constant ggml_metal_kargs_ssm_conv & args,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]],
         uint3   ntg[[threads_per_threadgroup]]) {
@@ -1282,15 +1225,15 @@ kernel void kernel_ssm_conv_f32(
     const int64_t i2 = tgpig.y;
     const int64_t i3 = tgpig.z;
 
-    const int64_t nc  = ne10;
-  //const int64_t ncs = ne00;
-  //const int64_t nr  = ne01;
-  //const int64_t n_t = ne1;
-  //const int64_t n_s = ne2;
+    const int64_t nc  = args.ne10;
+  //const int64_t ncs = args.ne00;
+  //const int64_t nr  = args.ne01;
+  //const int64_t n_t = args.ne1;
+  //const int64_t n_s = args.ne2;
 
-    device const float * s = (device const float *) ((device const char *) src0 + ir*nb01 + i2*nb00 + i3*nb02);
-    device const float * c = (device const float *) ((device const char *) src1 + ir*nb11);
-    device       float * x = (device       float *) ((device       char *) dst  + ir*nb0  + i2*nb1  + i3*nb2);
+    device const float * s = (device const float *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);
+    device const float * c = (device const float *) ((device const char *) src1 + ir*args.nb11);
+    device       float * x = (device       float *) ((device       char *) dst  + ir*args.nb0  + i2*args.nb1  + i3*args.nb2);
 
     float sumf = 0.0f;
 
@@ -1302,7 +1245,6 @@ kernel void kernel_ssm_conv_f32(
 }
 
 // ref: ggml.c:ggml_compute_forward_ssm_scan_f32
-// TODO: optimize
 kernel void kernel_ssm_scan_f32(
         device const void * src0,
         device const void * src1,
@@ -1311,48 +1253,27 @@ kernel void kernel_ssm_scan_f32(
         device const void * src4,
         device const void * src5,
         device      float * dst,
-        constant  int64_t & d_state,
-        constant  int64_t & d_inner,
-        constant  int64_t & n_seq_tokens,
-        constant  int64_t & n_seqs,
-        constant uint64_t & nb00,
-        constant uint64_t & nb01,
-        constant uint64_t & nb02,
-        constant uint64_t & nb10,
-        constant uint64_t & nb11,
-        constant uint64_t & nb12,
-        constant uint64_t & nb13,
-        constant uint64_t & nb20,
-        constant uint64_t & nb21,
-        constant uint64_t & nb22,
-        constant uint64_t & nb30,
-        constant uint64_t & nb31,
-        constant uint64_t & nb40,
-        constant uint64_t & nb41,
-        constant uint64_t & nb42,
-        constant uint64_t & nb50,
-        constant uint64_t & nb51,
-        constant uint64_t & nb52,
+        constant ggml_metal_kargs_ssm_scan & args,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]],
         uint3   ntg[[threads_per_threadgroup]]) {
     const int64_t ir = tgpig.x;
     const int64_t i3 = tgpig.y;
 
-    const int64_t nc  = d_state;
-  //const int64_t nr  = d_inner;
-    const int64_t n_t = n_seq_tokens;
-  //const int64_t n_s = n_seqs;
+    const int64_t nc  = args.d_state;
+    // const int64_t nr  = args.d_inner;
+    const int64_t n_t = args.n_seq_tokens;
+    // const int64_t n_s = args.n_seqs;
 
     for (int64_t i2 = 0; i2 < n_t; ++i2) {
-        device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + i3*nb02);
-        device const float * x  = (device const float *) ((device const char *) src1 + ir*nb10 + i2*nb11 + i3*nb12);
-        device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*nb21 + i3*nb22);
-        device const float * A  = (device const float *) ((device const char *) src3 + ir*nb31);
-        device const float * B  = (device const float *) ((device const char *) src4 + i2*nb41 + i3*nb42);
-        device const float * C  = (device const float *) ((device const char *) src5 + i2*nb51 + i3*nb52);
-        device       float * y  = (device       float *) ((device       char *) dst  + ir*nb10 + i2*nb11 + i3*nb12); // TODO: do not use src1 strides
-        device       float * s  = (device       float *) ((device       char *) dst  + ir*nb01 + i3*nb02 +    nb13);
+        device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb01 + i3*args.nb02);
+        device const float * x  = (device const float *) ((device const char *) src1 + ir*args.nb10 + i2*args.nb11 + i3*args.nb12);
+        device const float * dt = (device const float *) ((device const char *) src2 + ir*args.nb20 + i2*args.nb21 + i3*args.nb22);
+        device const float * A  = (device const float *) ((device const char *) src3 + ir*args.nb31);
+        device const float * B  = (device const float *) ((device const char *) src4 + i2*args.nb41 + i3*args.nb42);
+        device const float * C  = (device const float *) ((device const char *) src5 + i2*args.nb51 + i3*args.nb52);
+        device       float * y  = (device       float *) ((device       char *) dst  + ir*args.nb10 + i2*args.nb11 + i3*args.nb12); // TODO: do not use src1 strides
+        device       float * s  = (device       float *) ((device       char *) dst  + ir*args.nb01 + i3*args.nb02 +    args.nb13);
 
         if (i2 > 0) {
             s0 = s;
@@ -1545,22 +1466,15 @@ kernel void kernel_rms_norm(
 kernel void kernel_group_norm(
         device const float * src0,
         device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant   int32_t & n_groups,
-        constant     float & eps,
+        constant ggml_metal_kargs_group_norm & args,
         threadgroup float  * buf [[threadgroup(0)]],
         uint tgpig[[threadgroup_position_in_grid]],
         uint tpitg[[thread_position_in_threadgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint   ntg[[threads_per_threadgroup]]) {
-    const int64_t ne = ne00*ne01*ne02;
-    const int64_t gs = ne00*ne01*((ne02 + n_groups - 1) / n_groups);
+    const int64_t ne = args.ne00*args.ne01*args.ne02;
+    const int64_t gs = args.ne00*args.ne01*((args.ne02 + args.n_groups - 1) / args.n_groups);
 
     int start = tgpig * gs;
     int end   = start + gs;
@@ -1624,7 +1538,7 @@ kernel void kernel_group_norm(
     }
 
     const float variance = tmp / gs;
-    const float scale = 1.0f/sqrt(variance + eps);
+    const float scale = 1.0f/sqrt(variance + args.eps);
     for (int j = start; j < end; j += ntg) {
         dst[j] *= scale;
     }
@@ -2588,17 +2502,7 @@ template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_
 typedef void (im2col_t)(
         device const float * x,
         device        char * dst,
-        constant   int32_t & ofs0,
-        constant   int32_t & ofs1,
-        constant   int32_t & IW,
-        constant   int32_t & IH,
-        constant   int32_t & CHW,
-        constant   int32_t & s0,
-        constant   int32_t & s1,
-        constant   int32_t & p0,
-        constant   int32_t & p1,
-        constant   int32_t & d0,
-        constant   int32_t & d1,
+        constant ggml_metal_kargs_im2col & args,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3  tgpg[[threadgroups_per_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]],
@@ -2608,17 +2512,7 @@ template <typename T>
 kernel void kernel_im2col(
         device const float * x,
         device        char * dst,
-        constant   int32_t & ofs0,
-        constant   int32_t & ofs1,
-        constant   int32_t & IW,
-        constant   int32_t & IH,
-        constant   int32_t & CHW,
-        constant   int32_t & s0,
-        constant   int32_t & s1,
-        constant   int32_t & p0,
-        constant   int32_t & p1,
-        constant   int32_t & d0,
-        constant   int32_t & d1,
+        constant ggml_metal_kargs_im2col & args,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3  tgpg[[threadgroups_per_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]],
@@ -2639,17 +2533,17 @@ kernel void kernel_im2col(
     const int64_t ioh = tgpig[1];
     const int64_t iow = tgpig[2];
 
-    const int64_t iiw = iow*s0 + ikw*d0 - p0;
-    const int64_t iih = ioh*s1 + ikh*d1 - p1;
+    const int64_t iiw = iow*args.s0 + ikw*args.d0 - args.p0;
+    const int64_t iih = ioh*args.s1 + ikh*args.d1 - args.p1;
 
-    const int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*CHW + (iic*(KH*KW) + ikh*KW + ikw);
+    const int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*args.CHW + (iic*(KH*KW) + ikh*KW + ikw);
 
     device T * pdst = (device T *) (dst);
 
-    if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+    if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
         pdst[offset_dst] = 0.0f;
     } else {
-        const int64_t offset_src = in*ofs0 + iic*ofs1 + iih*IW + iiw;
+        const int64_t offset_src = in*args.ofs0 + iic*args.ofs1 + iih*args.IW + iiw;
         pdst[offset_dst] = x[offset_src];
     }
 }
@@ -2660,20 +2554,7 @@ template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
 typedef void (im2col_ext_t)(
         device const float * x,
         device        char * dst,
-        constant   int32_t & ofs0,
-        constant   int32_t & ofs1,
-        constant   int32_t & IW,
-        constant   int32_t & IH,
-        constant   int32_t & CHW,
-        constant   int32_t & s0,
-        constant   int32_t & s1,
-        constant   int32_t & p0,
-        constant   int32_t & p1,
-        constant   int32_t & d0,
-        constant   int32_t & d1,
-        constant   int32_t & N,
-        constant   int32_t & KH,
-        constant   int32_t & KW,
+        constant ggml_metal_kargs_im2col & args,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3  tgpg[[threadgroups_per_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]],
@@ -2683,53 +2564,40 @@ template <typename T>
 kernel void kernel_im2col_ext(
         device const float * x,
         device        char * dst,
-        constant   int32_t & ofs0,
-        constant   int32_t & ofs1,
-        constant   int32_t & IW,
-        constant   int32_t & IH,
-        constant   int32_t & CHW,
-        constant   int32_t & s0,
-        constant   int32_t & s1,
-        constant   int32_t & p0,
-        constant   int32_t & p1,
-        constant   int32_t & d0,
-        constant   int32_t & d1,
-        constant   int32_t & N,
-        constant   int32_t & KH,
-        constant   int32_t & KW,
+        constant ggml_metal_kargs_im2col & args,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3  tgpg[[threadgroups_per_grid]],      // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
         uint3 tpitg[[thread_position_in_threadgroup]],
         uint3   ntg[[threads_per_threadgroup]]) {  // [M, 1, 1]
-    const int64_t KHW = KH * KW;             // KHW == ntg[1] * ntg[2], KW == ntg[2]
+    const int64_t KHW = (int64_t)args.KHW;
 
-    const int64_t d = tgpig[0] / CHW;
-    const int64_t chw = tgpig[0] % CHW;
+    const int64_t d = tgpig[0] / args.CHW;
+    const int64_t chw = tgpig[0] % args.CHW;
     const int64_t tgpig_0 = chw / KHW;  // 0 ~ (IC - 1)
     const int64_t HW = tgpig[0] % KHW;
 
     const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0];
-    if (tpitg_0 >= N) {
+    if (tpitg_0 >= args.N) {
         return;
     }
 
-    const int64_t tpitg_1 = HW / KW;
-    const int64_t tpitg_2 = HW % KW;
+    const int64_t tpitg_1 = HW / args.KW;
+    const int64_t tpitg_2 = HW % args.KW;
 
-    const int64_t iiw = tgpig[2] * s0 + tpitg_2 * d0 - p0;
-    const int64_t iih = tgpig[1] * s1 + tpitg_1 * d1 - p1;
+    const int64_t iiw = tgpig[2] * args.s0 + tpitg_2 * args.d0 - args.p0;
+    const int64_t iih = tgpig[1] * args.s1 + tpitg_1 * args.d1 - args.p1;
 
     const int64_t offset_dst =
-        (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
-        (tgpig_0 * KHW + tpitg_1 * KW + tpitg_2);
+        (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * args.CHW +
+        (tgpig_0 * KHW + tpitg_1 * args.KW + tpitg_2);
 
     device T * pdst = (device T *) (dst);
 
-    if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+    if (iih < 0 || iih >= args.IH || iiw < 0 || iiw >= args.IW) {
         pdst[offset_dst] = 0.0f;
     } else {
-        const int64_t offset_src = tpitg_0 * ofs0 + tgpig_0 * ofs1;
-        pdst[offset_dst] = x[offset_src + iih * IW + iiw];
+        const int64_t offset_src = tpitg_0 * args.ofs0 + tgpig_0 * args.ofs1;
+        pdst[offset_dst] = x[offset_src + iih * args.IW + iiw];
     }
 }
 
@@ -2740,12 +2608,7 @@ typedef void (conv_transpose_1d_t)(
         device const float * src0,
         device const float * src1,
         device        char * dst,
-        constant   int32_t & IC,
-        constant   int32_t & IL,
-        constant   int32_t & K,
-        constant   int32_t & s0,
-        constant  uint64_t & nb0,
-        constant  uint64_t & nb1,
+        constant ggml_metal_kargs_conv_transpose_1d & args,
         uint3   tgpig[[threadgroup_position_in_grid]],
         uint3    tgpg[[threadgroups_per_grid]]);
 
@@ -2754,29 +2617,24 @@ kernel void kernel_conv_transpose_1d(
         device const     T * src0,
         device const float * src1,
         device        char * dst,
-        constant   int32_t & IC,
-        constant   int32_t & IL,
-        constant   int32_t & K,
-        constant   int32_t & s0,
-        constant  uint64_t & nb0,
-        constant  uint64_t & nb1,
+        constant ggml_metal_kargs_conv_transpose_1d & args,
         uint3   tgpig[[threadgroup_position_in_grid]],
         uint3   tgpg[[threadgroups_per_grid]]) {
 
     float v = 0.0f;
 
-    for (int64_t c = 0; c < IC; c++) {
-        const int32_t kernel_offset = c * tgpg[1] * K + K * tgpig[1];
-        const int32_t input_offset = c * IL;
+    for (int64_t c = 0; c < args.IC; c++) {
+        const int32_t kernel_offset = c * tgpg[1] * args.K + args.K * tgpig[1];
+        const int32_t input_offset = c * args.IL;
 
-        for (int64_t i = 0; i < IL; i++) {
-            if (tgpig[0] >= i * s0 && tgpig[0] < i * s0 + K) {
-                v += src0[kernel_offset + tgpig[0] - i * s0] * src1[input_offset + i];
+        for (int64_t i = 0; i < args.IL; i++) {
+            if (tgpig[0] >= i * args.s0 && tgpig[0] < i * args.s0 + args.K) {
+                v += src0[kernel_offset + tgpig[0] - i * args.s0] * src1[input_offset + i];
             }
         }
     }
 
-    device float * dst_ptr = (device float *) (dst + tgpig[0] * nb0 + tgpig[1] * nb1);
+    device float * dst_ptr = (device float *) (dst + tgpig[0] * args.nb0 + tgpig[1] * args.nb1);
 
     dst_ptr[0] = v;
 }
@@ -2786,12 +2644,7 @@ kernel void kernel_conv_transpose_1d<float>(
     device const float * src0,
     device const float * src1,
     device        char * dst,
-    constant   int32_t & IC,
-    constant   int32_t & IL,
-    constant   int32_t & K,
-    constant   int32_t & s0,
-    constant  uint64_t & nb0,
-    constant  uint64_t & nb1,
+    constant ggml_metal_kargs_conv_transpose_1d & args,
     uint3   tgpig[[threadgroup_position_in_grid]],
     uint3    tgpg[[threadgroups_per_grid]]);
 
@@ -2800,38 +2653,14 @@ kernel void kernel_conv_transpose_1d<half>(
     device const half  * src0,
     device const float * src1,
     device        char * dst,
-    constant   int32_t & IC,
-    constant   int32_t & IL,
-    constant   int32_t & K,
-    constant   int32_t & s0,
-    constant  uint64_t & nb0,
-    constant  uint64_t & nb1,
+    constant ggml_metal_kargs_conv_transpose_1d & args,
     uint3   tgpig[[threadgroup_position_in_grid]],
     uint3    tgpg[[threadgroups_per_grid]]);
 
 kernel void kernel_upscale_f32(
     device  const char * src0,
     device        char * dst,
-    constant   int64_t & ne00,
-    constant   int64_t & ne01,
-    constant   int64_t & ne02,
-    constant   int64_t & ne03,
-    constant  uint64_t & nb00,
-    constant  uint64_t & nb01,
-    constant  uint64_t & nb02,
-    constant  uint64_t & nb03,
-    constant   int64_t & ne0,
-    constant   int64_t & ne1,
-    constant   int64_t & ne2,
-    constant   int64_t & ne3,
-    constant  uint64_t & nb0,
-    constant  uint64_t & nb1,
-    constant  uint64_t & nb2,
-    constant  uint64_t & nb3,
-    constant     float & sf0,
-    constant     float & sf1,
-    constant     float & sf2,
-    constant     float & sf3,
+    constant ggml_metal_kargs_upscale & args,
     uint3 tgpig[[threadgroup_position_in_grid]],
     uint3 tpitg[[thread_position_in_threadgroup]],
     uint3   ntg[[threads_per_threadgroup]]) {
@@ -2840,15 +2669,15 @@ kernel void kernel_upscale_f32(
     const int64_t i2 = tgpig.y;
     const int64_t i1 = tgpig.x;
 
-    const int64_t i03 = i3/sf3;
-    const int64_t i02 = i2/sf2;
-    const int64_t i01 = i1/sf1;
+    const int64_t i03 = i3/args.sf3;
+    const int64_t i02 = i2/args.sf2;
+    const int64_t i01 = i1/args.sf1;
 
-    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
-        const int64_t i00 = i0/sf0;
+    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+        const int64_t i00 = i0/args.sf0;
 
-        device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
-        device       float * dst_ptr  = (device       float *) (dst  +  i3*nb3  +  i2*nb2  +  i1*nb1  +  i0*nb0);
+        device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
+        device       float * dst_ptr  = (device       float *) (dst  +  i3*args.nb3  +  i2*args.nb2  +  i1*args.nb1  +  i0*args.nb0);
 
         dst_ptr[0] = src0_ptr[0];
     }
@@ -2857,22 +2686,7 @@ kernel void kernel_upscale_f32(
 kernel void kernel_pad_f32(
     device  const char * src0,
     device        char * dst,
-    constant   int64_t & ne00,
-    constant   int64_t & ne01,
-    constant   int64_t & ne02,
-    constant   int64_t & ne03,
-    constant  uint64_t & nb00,
-    constant  uint64_t & nb01,
-    constant  uint64_t & nb02,
-    constant  uint64_t & nb03,
-    constant   int64_t & ne0,
-    constant   int64_t & ne1,
-    constant   int64_t & ne2,
-    constant   int64_t & ne3,
-    constant  uint64_t & nb0,
-    constant  uint64_t & nb1,
-    constant  uint64_t & nb2,
-    constant  uint64_t & nb3,
+    constant ggml_metal_kargs_pad & args,
     uint3 tgpig[[threadgroup_position_in_grid]],
     uint3 tpitg[[thread_position_in_threadgroup]],
     uint3   ntg[[threads_per_threadgroup]]) {
@@ -2885,12 +2699,12 @@ kernel void kernel_pad_f32(
     const int64_t i02 = i2;
     const int64_t i01 = i1;
 
-    device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
-    device       float * dst_ptr  = (device       float *) (dst  +  i3*nb3  +  i2*nb2  +  i1*nb1);
+    device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
+    device       float * dst_ptr  = (device       float *) (dst  +  i3*args.nb3  +  i2*args.nb2  +  i1*args.nb1);
 
-    if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
-        for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
-            if (i0 < ne00) {
+    if (i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
+        for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+            if (i0 < args.ne00) {
                 dst_ptr[i0] = src0_ptr[i0];
             } else {
                 dst_ptr[i0] = 0.0f;
@@ -2900,7 +2714,7 @@ kernel void kernel_pad_f32(
         return;
     }
 
-    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
         dst_ptr[i0] = 0.0f;
     }
 }
@@ -2908,21 +2722,7 @@ kernel void kernel_pad_f32(
 kernel void kernel_pad_reflect_1d_f32(
     device  const char * src0,
     device        char * dst,
-    constant   int64_t & ne00,
-    constant   int64_t & ne01,
-    constant   int64_t & ne02,
-    constant   int64_t & ne03,
-    constant   int64_t & ne0,
-    constant  uint64_t & nb00,
-    constant  uint64_t & nb01,
-    constant  uint64_t & nb02,
-    constant  uint64_t & nb03,
-    constant  uint64_t & nb0,
-    constant  uint64_t & nb1,
-    constant  uint64_t & nb2,
-    constant  uint64_t & nb3,
-    constant   int32_t & p0,
-    constant   int32_t & p1,
+    constant   ggml_metal_kargs_pad_reflect_1d & args,
     uint3 tgpig[[threadgroup_position_in_grid]],
     uint3  tgpg[[threadgroups_per_grid]],
     uint3 tpitg[[thread_position_in_threadgroup]],
@@ -2936,17 +2736,17 @@ kernel void kernel_pad_reflect_1d_f32(
     const int64_t i02 = i2;
     const int64_t i01 = i1;
 
-    device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
-    device       float * dst_ptr  = (device       float *) (dst  +  i3*nb3  +  i2*nb2  +  i1*nb1);
+    device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
+    device       float * dst_ptr  = (device       float *) (dst  +  i3*args.nb3  +  i2*args.nb2  +  i1*args.nb1);
 
-    if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
-        for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
-            if (i0 < p0) {
-                dst_ptr[i0] = src0_ptr[p0 - i0];
-            } else if (i0 < ne0 - p1) {
-                dst_ptr[i0] = src0_ptr[i0 - p0];
+    if (i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
+        for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+            if (i0 < args.p0) {
+                dst_ptr[i0] = src0_ptr[args.p0 - i0];
+            } else if (i0 < args.ne0 - args.p1) {
+                dst_ptr[i0] = src0_ptr[i0 - args.p0];
             } else {
-                dst_ptr[i0] = src0_ptr[(ne0 - p1 - p0) - (p1 + 1 - (ne0 - i0)) - 1];
+                dst_ptr[i0] = src0_ptr[(args.ne0 - args.p1 - args.p0) - (args.p1 + 1 - (args.ne0 - i0)) - 1];
             }
         }
     }
@@ -2954,44 +2754,40 @@ kernel void kernel_pad_reflect_1d_f32(
 
 kernel void kernel_arange_f32(
     device        char * dst,
-    constant   int64_t & ne0,
-    constant   float   & start,
-    constant   float   & step,
+    constant   ggml_metal_kargs_arange & args,
     uint3 tgpig[[threadgroup_position_in_grid]],
     uint3 tpitg[[thread_position_in_threadgroup]],
     uint3   ntg[[threads_per_threadgroup]]) {
 
     device float * dst_ptr = (device float *) dst;
 
-    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
-        dst_ptr[i0] = start + step * i0;
+    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+        dst_ptr[i0] = args.start + args.step * i0;
     }
 }
 
 kernel void kernel_timestep_embedding_f32(
     device  const char * src0,
     device        char * dst,
-    constant  uint64_t & nb1,
-    constant  int      & dim,
-    constant  int      & max_period,
+    constant  ggml_metal_kargs_timestep_embedding & args,
     uint3 tgpig[[threadgroup_position_in_grid]],
     uint3 tpitg[[thread_position_in_threadgroup]],
     uint3   ntg[[threads_per_threadgroup]]) {
 
     int i = tgpig.x;
-    device float * embed_data = (device float *)(dst +  i*nb1);
+    device float * embed_data = (device float *)(dst + i*args.nb1);
 
-    int half_ = dim / 2;
+    int half_ = args.dim / 2;
     for (int j = tpitg.x; j < half_; j += ntg.x) {
         float timestep = ((device float *)src0)[i];
-        float freq = (float)exp(-log((float)max_period) * j / half_);
+        float freq = (float)exp(-log((float)args.max_period) * j / half_);
         float arg = timestep * freq;
         embed_data[j        ] = cos(arg);
         embed_data[j + half_] = sin(arg);
     }
 
-    if (dim % 2 != 0 && tpitg.x == 0) {
-        embed_data[dim] = 0.f;
+    if (args.dim % 2 != 0 && tpitg.x == 0) {
+        embed_data[args.dim] = 0.f;
     }
 }
 
@@ -2999,8 +2795,7 @@ kernel void kernel_timestep_embedding_f32(
 typedef void (argsort_t)(
         device const float  * x,
         device     int32_t  * dst,
-        constant   int64_t  & ncols,
-        constant   int64_t  & ncols_pad,
+        constant   ggml_metal_kargs_argsort & args,
         threadgroup int32_t * shared_values [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]]);
@@ -3009,8 +2804,7 @@ template<ggml_sort_order order>
 kernel void kernel_argsort_f32_i32(
         device const float   * x,
         device       int32_t * dst,
-        constant     int64_t & ncols,
-        constant     int64_t & ncols_pad,
+        constant   ggml_metal_kargs_argsort & args,
         threadgroup int32_t  * shared_values [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]]) {
@@ -3018,9 +2812,9 @@ kernel void kernel_argsort_f32_i32(
     int col = tpitg[0];
     int row = tgpig[1];
 
-    if (col >= ncols_pad) return;
+    if (col >= args.ncols_pad) return;
 
-    device const float   * x_row   = x + row * ncols;
+    device const float   * x_row   = x + row * args.ncols;
     threadgroup int32_t  * dst_row = shared_values;
 
     // initialize indices
@@ -3028,21 +2822,21 @@ kernel void kernel_argsort_f32_i32(
 
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
-    for (int k = 2; k <= ncols_pad; k *= 2) {
+    for (int k = 2; k <= args.ncols_pad; k *= 2) {
         for (int j = k / 2; j > 0; j /= 2) {
             int ixj = col ^ j;
             if (ixj > col) {
                 if ((col & k) == 0) {
-                    if (dst_row[col] >= ncols ||
-                        (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
+                    if (dst_row[col] >= args.ncols ||
+                        (dst_row[ixj] < args.ncols && (order == GGML_SORT_ORDER_ASC ?
                             x_row[dst_row[col]] > x_row[dst_row[ixj]] :
                             x_row[dst_row[col]] < x_row[dst_row[ixj]]))
                     ) {
                         SWAP(dst_row[col], dst_row[ixj]);
                     }
                 } else {
-                    if (dst_row[ixj] >= ncols ||
-                        (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
+                    if (dst_row[ixj] >= args.ncols ||
+                        (dst_row[col] < args.ncols && (order == GGML_SORT_ORDER_ASC ?
                             x_row[dst_row[col]] < x_row[dst_row[ixj]] :
                             x_row[dst_row[col]] > x_row[dst_row[ixj]]))
                     ) {
@@ -3055,8 +2849,8 @@ kernel void kernel_argsort_f32_i32(
     }
 
     // copy the result to dst without the padding
-    if (col < ncols) {
-        dst[row * ncols + col] = dst_row[col];
+    if (col < args.ncols) {
+        dst[row * args.ncols + col] = dst_row[col];
     }
 }
 
@@ -3066,9 +2860,9 @@ template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_ar
 kernel void kernel_leaky_relu_f32(
         device const float * src0,
         device       float * dst,
-        constant     float & slope,
+        constant     ggml_metal_kargs_leaky_relu & args,
         uint tpig[[thread_position_in_grid]]) {
-    dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
+    dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * args.slope;
 }
 
 // ref: https://arxiv.org/pdf/2307.08691.pdf
@@ -6009,28 +5803,21 @@ kernel void kernel_get_rows_q(
         device const  void * src0,
         device const  void * src1,
         device       float * dst,
-        constant   int64_t & ne00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant   int64_t & ne10,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb1,
-        constant  uint64_t & nb2,
+        constant ggml_metal_kargs_get_rows & args,
         uint3                tgpig[[threadgroup_position_in_grid]],
         uint                 tiitg[[thread_index_in_threadgroup]],
         uint3                tptg [[threads_per_threadgroup]]) {
     const int64_t i10 = tgpig.x;
     const int64_t i11 = tgpig.y;
 
-    const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
+    const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0];
 
     const int64_t i02 = i11;
 
-    for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
+    for (int64_t ind = tiitg; ind < args.ne00/16; ind += tptg.x) {
         float4x4 temp;
-        dequantize_func(((device const block_q *) ((const device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
-        *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
+        dequantize_func(((device const block_q *) ((const device char *) src0 + r*args.nb01 + i02*args.nb02)) + ind/nl, ind%nl, temp);
+        *(((device float4x4 *) ((device char *) dst + i11*args.nb2 + i10*args.nb1)) + ind) = temp;
     }
 }
 
@@ -6039,27 +5826,20 @@ kernel void kernel_get_rows_f(
         device const  void * src0,
         device const  void * src1,
         device       float * dst,
-        constant   int64_t & ne00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant   int64_t & ne10,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb1,
-        constant  uint64_t & nb2,
+        constant ggml_metal_kargs_get_rows & args,
         uint3                tgpig[[threadgroup_position_in_grid]],
         uint                 tiitg[[thread_index_in_threadgroup]],
         uint3                tptg [[threads_per_threadgroup]]) {
     const int64_t i10 = tgpig.x;
     const int64_t i11 = tgpig.y;
 
-    const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
+    const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0];
 
     const int64_t i02 = i11;
 
-    for (int ind = tiitg; ind < ne00; ind += tptg.x) {
-        ((      device float *) ((      device char *)  dst + i11*nb2  + i10*nb1))[ind] =
-        ((const device T     *) ((const device char *) src0 + i02*nb02 +  r*nb01))[ind];
+    for (int ind = tiitg; ind < args.ne00; ind += tptg.x) {
+        ((      device float *) ((      device char *)  dst + i11*args.nb2  + i10*args.nb1))[ind] =
+        ((const device T     *) ((const device char *) src0 + i02*args.nb02 +  r*args.nb01))[ind];
     }
 }
 
@@ -6067,27 +5847,20 @@ kernel void kernel_get_rows_i32(
         device const  void * src0,
         device const  void * src1,
         device     int32_t * dst,
-        constant   int64_t & ne00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant   int64_t & ne10,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb1,
-        constant  uint64_t & nb2,
+        constant ggml_metal_kargs_get_rows & args,
         uint3                tgpig[[threadgroup_position_in_grid]],
         uint                 tiitg[[thread_index_in_threadgroup]],
         uint3                tptg [[threads_per_threadgroup]]) {
     const int64_t i10 = tgpig.x;
     const int64_t i11 = tgpig.y;
 
-    const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];
+    const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0];
 
     const int64_t i02 = i11;
 
-    for (int ind = tiitg; ind < ne00; ind += tptg.x) {
-        ((      device int32_t *) ((      device char *) dst  + i11*nb2 + i10*nb1))[ind] =
-        ((const device int32_t *) ((const device char *) src0 + i02*nb02 + r*nb01))[ind];
+    for (int ind = tiitg; ind < args.ne00; ind += tptg.x) {
+        ((      device int32_t *) ((      device char *) dst  + i11*args.nb2 + i10*args.nb1))[ind] =
+        ((const device int32_t *) ((const device char *) src0 + i02*args.nb02 + r*args.nb01))[ind];
     }
 }
 
@@ -6689,98 +6462,78 @@ template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]]  kernel kernel_mul_mv_id_t
 kernel void kernel_pool_2d_max_f32(
         device  const float * src0,
         device        float * dst,
-        constant    int32_t & k0,
-        constant    int32_t & k1,
-        constant    int32_t & s0,
-        constant    int32_t & s1,
-        constant    int32_t & p0,
-        constant    int32_t & p1,
-        constant    int64_t & IH,
-        constant    int64_t & IW,
-        constant    int64_t & OH,
-        constant    int64_t & OW,
-        constant    int64_t & parallel_elements,
+        constant    ggml_metal_kargs_pool_2d & args,
         uint        gid[[thread_position_in_grid]]) {
 
-    if (gid >= parallel_elements) {
+    if (gid >= args.parallel_elements) {
         return;
     }
 
     const int idx = gid;
-    const int I_HW = IH * IW;
-    const int O_HW = OH * OW;
+    const int I_HW = args.IH * args.IW;
+    const int O_HW = args.OH * args.OW;
     const int nc = idx / O_HW;
-    const int cur_oh = idx % O_HW / OW;
-    const int cur_ow = idx % O_HW % OW;
+    const int cur_oh = idx % O_HW / args.OW;
+    const int cur_ow = idx % O_HW % args.OW;
 
     device const float * i_ptr = src0 + nc * I_HW;
     device       float * o_ptr = dst  + nc * O_HW;
 
-    const int start_h = cur_oh * s1 - p1;
+    const int start_h = cur_oh * args.s1 - args.p1;
     const int bh = MAX(0,  start_h);
-    const int eh = MIN(IH, start_h + k1);
-    const int start_w = cur_ow * s0 - p0;
+    const int eh = MIN(args.IH, start_h + args.k1);
+    const int start_w = cur_ow * args.s0 - args.p0;
     const int bw = MAX(0,  start_w);
-    const int ew = MIN(IW, start_w + k0);
+    const int ew = MIN(args.IW, start_w + args.k0);
 
     float res = -INFINITY;
 
     for (int i = bh; i < eh; i += 1) {
         for (int j = bw; j < ew; j += 1) {
-            res = MAX(res, i_ptr[i * IW + j]);
+            res = MAX(res, i_ptr[i * args.IW + j]);
         }
     }
 
-    o_ptr[cur_oh * OW + cur_ow] = res;
+    o_ptr[cur_oh * args.OW + cur_ow] = res;
 }
 
 kernel void kernel_pool_2d_avg_f32(
         device  const float * src0,
         device        float * dst,
-        constant    int32_t & k0,
-        constant    int32_t & k1,
-        constant    int32_t & s0,
-        constant    int32_t & s1,
-        constant    int32_t & p0,
-        constant    int32_t & p1,
-        constant    int64_t & IH,
-        constant    int64_t & IW,
-        constant    int64_t & OH,
-        constant    int64_t & OW,
-        constant    int64_t & parallel_elements,
+        constant    ggml_metal_kargs_pool_2d & args,
         uint        gid[[thread_position_in_grid]]) {
 
-    if (gid >= parallel_elements) {
+    if (gid >= args.parallel_elements) {
         return;
     }
 
     const int idx = gid;
-    const int I_HW = IH * IW;
-    const int O_HW = OH * OW;
+    const int I_HW = args.IH * args.IW;
+    const int O_HW = args.OH * args.OW;
     const int nc = idx / O_HW;
-    const int cur_oh = idx % O_HW / OW;
-    const int cur_ow = idx % O_HW % OW;
+    const int cur_oh = idx % O_HW / args.OW;
+    const int cur_ow = idx % O_HW % args.OW;
 
     device const float * i_ptr = src0 + nc * I_HW;
     device       float * o_ptr = dst  + nc * O_HW;
 
-    const int start_h = cur_oh * s1 - p1;
+    const int start_h = cur_oh * args.s1 - args.p1;
     const int bh = MAX(0,  start_h);
-    const int eh = MIN(IH, start_h + k1);
-    const int start_w = cur_ow * s0 - p0;
+    const int eh = MIN(args.IH, start_h + args.k1);
+    const int start_w = cur_ow * args.s0 - args.p0;
     const int bw = MAX(0,  start_w);
-    const int ew = MIN(IW, start_w + k0);
+    const int ew = MIN(args.IW, start_w + args.k0);
     // const float scale = 1. / ((eh - bh) * (ew - bw));
-    const float scale = 1. / (k0 * k1);
+    const float scale = 1. / (args.k0 * args.k1);
 
     float res = 0;
 
     for (int i = bh; i < eh; i += 1) {
         for (int j = bw; j < ew; j += 1) {
-            float cur = i_ptr[i * IW + j];
+            float cur = i_ptr[i * args.IW + j];
             res += cur * scale;
         }
     }
 
-    o_ptr[cur_oh * OW + cur_ow] = res;
+    o_ptr[cur_oh * args.OW + cur_ow] = res;
 }