]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
metal : use params per pipeline instance (llama/17739)
authorGeorgi Gerganov <redacted>
Thu, 4 Dec 2025 08:34:11 +0000 (10:34 +0200)
committerGeorgi Gerganov <redacted>
Fri, 12 Dec 2025 15:53:16 +0000 (17:53 +0200)
ggml/src/ggml-metal/ggml-metal-device.cpp
ggml/src/ggml-metal/ggml-metal-device.h
ggml/src/ggml-metal/ggml-metal-device.m
ggml/src/ggml-metal/ggml-metal-ops.cpp

index c647baef8780785e6a96042ab51d18eeff737a9d..33ab43d58f50c0ad9f0f90d0de670e8eb340b845 100644 (file)
@@ -57,7 +57,7 @@ ggml_metal_pipeline_t ggml_metal_pipelines_get(ggml_metal_pipelines_t ppls, cons
     return ppls->data[name];
 }
 
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_base(ggml_metal_library_t lib, ggml_op op) {
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_base(ggml_metal_library_t lib, ggml_op op) {
     char base[256];
     char name[256];
 
@@ -71,34 +71,30 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_base(ggml_metal_library_t
     snprintf(base, 256, "kernel_%s", op_str);
     snprintf(name, 256, "%s", base);
 
-    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
-    if (res) {
-        return res;
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
     }
 
-    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
-
     return res;
 }
 
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cpy(ggml_metal_library_t lib, ggml_type tsrc, ggml_type tdst) {
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cpy(ggml_metal_library_t lib, ggml_type tsrc, ggml_type tdst) {
     char base[256];
     char name[256];
 
     snprintf(base, 256, "kernel_cpy_%s_%s", ggml_type_name(tsrc), ggml_type_name(tdst));
     snprintf(name, 256, "%s", base);
 
-    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
-    if (res) {
-        return res;
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
     }
 
-    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
-
     return res;
 }
 
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pool_2d(ggml_metal_library_t lib, const ggml_tensor * op, ggml_op_pool op_pool) {
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_2d(ggml_metal_library_t lib, const ggml_tensor * op, ggml_op_pool op_pool) {
     GGML_ASSERT(ggml_is_contiguous(op->src[0]));
     GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32 && op->src[0]->type == op->type);
 
@@ -115,68 +111,60 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pool_2d(ggml_metal_library
     snprintf(base, 256, "kernel_pool_2d_%s_%s", pool_str, ggml_type_name(op->src[0]->type));
     snprintf(name, 256, "%s", base);
 
-    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
-    if (res) {
-        return res;
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
     }
 
-    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
-
     return res;
 }
 
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_get_rows(ggml_metal_library_t lib, ggml_type tsrc) {
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_get_rows(ggml_metal_library_t lib, ggml_type tsrc) {
     char base[256];
     char name[256];
 
     snprintf(base, 256, "kernel_get_rows_%s", ggml_type_name(tsrc));
     snprintf(name, 256, "%s", base);
 
-    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
-    if (res) {
-        return res;
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
     }
 
-    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
-
     return res;
 }
 
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_set_rows(ggml_metal_library_t lib, ggml_type tidx, ggml_type tdst) {
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_set_rows(ggml_metal_library_t lib, ggml_type tidx, ggml_type tdst) {
     char base[256];
     char name[256];
 
     snprintf(base, 256, "kernel_set_rows_%s_%s", ggml_type_name(tdst), ggml_type_name(tidx));
     snprintf(name, 256, "%s", base);
 
-    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
-    if (res) {
-        return res;
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
     }
 
-    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
-
     return res;
 }
 
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_repeat(ggml_metal_library_t lib, ggml_type tsrc) {
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat(ggml_metal_library_t lib, ggml_type tsrc) {
     char base[256];
     char name[256];
 
     snprintf(base, 256, "kernel_repeat_%s", ggml_type_name(tsrc));
     snprintf(name, 256, "%s", base);
 
-    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
-    if (res) {
-        return res;
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
     }
 
-    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
-
     return res;
 }
 
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary(ggml_metal_library_t lib, const ggml_tensor * op) {
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary(ggml_metal_library_t lib, const ggml_tensor * op) {
     GGML_ASSERT(ggml_is_contiguous(op->src[0]));
 
     char base[256];
@@ -224,17 +212,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary(ggml_metal_library_t
     snprintf(base, 256, "kernel_%s_%s%s", op_str, ggml_type_name(op->src[0]->type), suffix);
     snprintf(name, 256, "%s", base);
 
-    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
-    if (res) {
-        return res;
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
     }
 
-    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
-
     return res;
 }
 
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu(ggml_metal_library_t lib, const ggml_tensor * op) {
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_glu(ggml_metal_library_t lib, const ggml_tensor * op) {
     GGML_ASSERT(ggml_is_contiguous_1(op->src[0]));
 
     char base[256];
@@ -258,17 +244,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu(ggml_metal_library_t l
     snprintf(base, 256, "kernel_%s_%s", op_str, ggml_type_name(op->src[0]->type));
     snprintf(name, 256, "%s", base);
 
-    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
-    if (res) {
-        return res;
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
     }
 
-    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
-
     return res;
 }
 
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum(ggml_metal_library_t lib, const ggml_tensor * op) {
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum(ggml_metal_library_t lib, const ggml_tensor * op) {
     assert(op->op == GGML_OP_SUM);
 
     char base[256];
@@ -277,17 +261,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum(ggml_metal_library_t l
     snprintf(base, 256, "kernel_op_sum_%s", ggml_type_name(op->src[0]->type));
     snprintf(name, 256, "%s", base);
 
-    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
-    if (res) {
-        return res;
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
     }
 
-    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
-
     return res;
 }
 
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows(ggml_metal_library_t lib, const ggml_tensor * op) {
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum_rows(ggml_metal_library_t lib, const ggml_tensor * op) {
     GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type));
 
     char base[256];
@@ -306,19 +288,17 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows(ggml_metal_librar
 
     snprintf(name, 256, "%s", base);
 
-    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
-    if (res) {
-        return res;
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
     }
 
-    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
-
-    ggml_metal_pipeline_set_smem(res, 32*sizeof(float));
+    res.smem = 32*sizeof(float);
 
     return res;
 }
 
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_blk(ggml_metal_library_t lib, const ggml_tensor * op) {
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_blk(ggml_metal_library_t lib, const ggml_tensor * op) {
     GGML_ASSERT(op->op == GGML_OP_CUMSUM);
 
     char base[256];
@@ -327,17 +307,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_blk(ggml_metal_libr
     snprintf(base, 256, "kernel_cumsum_blk_%s", ggml_type_name(op->src[0]->type));
     snprintf(name, 256, "%s", base);
 
-    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
-    if (res) {
-        return res;
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
     }
 
-    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
-
     return res;
 }
 
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_add(ggml_metal_library_t lib, const ggml_tensor * op) {
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_add(ggml_metal_library_t lib, const ggml_tensor * op) {
     GGML_ASSERT(op->op == GGML_OP_CUMSUM);
 
     char base[256];
@@ -346,17 +324,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_add(ggml_metal_libr
     snprintf(base, 256, "kernel_cumsum_add_%s", ggml_type_name(op->src[0]->type));
     snprintf(name, 256, "%s", base);
 
-    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
-    if (res) {
-        return res;
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
     }
 
-    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
-
     return res;
 }
 
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max(ggml_metal_library_t lib, const ggml_tensor * op) {
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max(ggml_metal_library_t lib, const ggml_tensor * op) {
     GGML_ASSERT(!op->src[1] || op->src[1]->type == GGML_TYPE_F16 || op->src[1]->type == GGML_TYPE_F32);
 
     char base[256];
@@ -373,19 +349,17 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max(ggml_metal_librar
     snprintf(base, 256, "kernel_soft_max_%s%s", ggml_type_name(tsrc1), suffix);
     snprintf(name, 256, "%s", base);
 
-    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
-    if (res) {
-        return res;
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
     }
 
-    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
-
-    ggml_metal_pipeline_set_smem(res, 32*sizeof(float));
+    res.smem = 32*sizeof(float);
 
     return res;
 }
 
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_library_t lib, const ggml_tensor * op) {
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_library_t lib, const ggml_tensor * op) {
     GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
     GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
 
@@ -404,17 +378,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_librar
     snprintf(base, 256, "kernel_ssm_conv_%s_%s%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type), suffix);
     snprintf(name, 256, "%s", base);
 
-    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
-    if (res) {
-        return res;
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
     }
 
-    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
-
     return res;
 }
 
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_library_t lib, const ggml_tensor * op)  {
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_library_t lib, const ggml_tensor * op)  {
     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
 
     char base[256];
@@ -425,19 +397,17 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_librar
     snprintf(base, 256, "kernel_ssm_scan_%s", ggml_type_name(op->src[0]->type));
     snprintf(name, 256, "%s_nsg=%d", base, nsg);
 
-    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
-    if (res) {
-        return res;
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
     }
 
-    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
-
-    ggml_metal_pipeline_set_smem(res, 32*sizeof(float)*nsg);
+    res.smem = 32*sizeof(float)*nsg;
 
     return res;
 }
 
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rwkv(ggml_metal_library_t lib, const ggml_tensor * op) {
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv(ggml_metal_library_t lib, const ggml_tensor * op) {
     char base[256];
     char name[256];
 
@@ -467,41 +437,37 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rwkv(ggml_metal_library_t
 
     snprintf(name, 256, "%s", base);
 
-    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
-    if (res) {
-        return res;
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
     }
 
-    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
-
     return res;
 }
 
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1, int nsg, int nxpsg, int r1ptg) {
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1, int nsg, int nxpsg, int r1ptg) {
     char base[256];
     char name[256];
 
     snprintf(base, 256, "kernel_mul_mv_ext_%s_%s_r1_%d", ggml_type_name(tsrc0), ggml_type_name(tsrc1), r1ptg);
     snprintf(name, 256, "%s_nsg=%d_nxpsg=%d", base, nsg, nxpsg);
 
-    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
-    if (res) {
-        return res;
-    }
-
-    ggml_metal_cv_t cv = ggml_metal_cv_init();
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        ggml_metal_cv_t cv = ggml_metal_cv_init();
 
-    ggml_metal_cv_set_int16(cv, nsg,   FC_MUL_MV + 0);
-    ggml_metal_cv_set_int16(cv, nxpsg, FC_MUL_MV + 1);
+        ggml_metal_cv_set_int16(cv, nsg,   FC_MUL_MV + 0);
+        ggml_metal_cv_set_int16(cv, nxpsg, FC_MUL_MV + 1);
 
-    res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
+        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
 
-    ggml_metal_cv_free(cv);
+        ggml_metal_cv_free(cv);
+    }
 
     return res;
 }
 
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm(ggml_metal_library_t lib, const ggml_tensor * op) {
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm(ggml_metal_library_t lib, const ggml_tensor * op) {
     char base[256];
     char name[256];
 
@@ -514,27 +480,25 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm(ggml_metal_library_
     snprintf(base, 256, "kernel_mul_mm_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1));
     snprintf(name, 256, "%s_bci=%d_bco=%d", base, bc_inp, bc_out);
 
-    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
-    if (res) {
-        return res;
-    }
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        ggml_metal_cv_t cv = ggml_metal_cv_init();
 
-    ggml_metal_cv_t cv = ggml_metal_cv_init();
+        ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0);
+        ggml_metal_cv_set_bool(cv, bc_out, FC_MUL_MM + 1);
 
-    ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0);
-    ggml_metal_cv_set_bool(cv, bc_out, FC_MUL_MM + 1);
+        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
 
-    res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
-
-    ggml_metal_cv_free(cv);
+        ggml_metal_cv_free(cv);
+    }
 
     // when the output size is not multiple of 64x32, we need extra smem to prevent out-of-bounds writes
-    ggml_metal_pipeline_set_smem(res, bc_out ? 8192 : 4096 + 2048);
+    res.smem = bc_out ? 8192 : 4096 + 2048;
 
     return res;
 }
 
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_t lib, const ggml_tensor * op) {
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_t lib, const ggml_tensor * op) {
     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
 
@@ -689,49 +653,43 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_
     snprintf(base, 256, "kernel_mul_mv_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix);
     snprintf(name, 256, "%s_nsg=%d", base, nsg);
 
-    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
-    if (res) {
-        return res;
-    }
-
-    ggml_metal_cv_t cv = ggml_metal_cv_init();
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        ggml_metal_cv_t cv = ggml_metal_cv_init();
 
-    ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
+        ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
 
-    res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
+        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
 
-    ggml_metal_cv_free(cv);
+        ggml_metal_cv_free(cv);
+    }
 
-    ggml_metal_pipeline_set_nr0 (res, nr0);
-    ggml_metal_pipeline_set_nr1 (res, nr1);
-    ggml_metal_pipeline_set_nsg (res, nsg);
-    ggml_metal_pipeline_set_smem(res, smem);
+    res.nr0  = nr0;
+    res.nr1  = nr1;
+    res.nsg  = nsg;
+    res.smem = smem;
 
     return res;
 }
 
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0(ggml_metal_library_t lib, int ne02, int ne20) {
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id_map0(ggml_metal_library_t lib, int ne02, int ne20) {
     char base[256];
     char name[256];
 
     snprintf(base, 256, "kernel_mul_mm_id_map0_ne20_%d", ne20);
     snprintf(name, 256, "%s_ne02=%d", base, ne02);
 
-    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
-    if (res) {
-        return res;
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
     }
 
-    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
-
-    const size_t smem = (size_t) ne02*ne20*sizeof(uint16_t);
-
-    ggml_metal_pipeline_set_smem(res, smem);
+    res.smem = (size_t) ne02*ne20*sizeof(uint16_t);
 
     return res;
 }
 
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id(ggml_metal_library_t lib, const ggml_tensor * op) {
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id(ggml_metal_library_t lib, const ggml_tensor * op) {
     char base[256];
     char name[256];
 
@@ -743,25 +701,23 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id(ggml_metal_libra
     snprintf(base, 256, "kernel_mul_mm_id_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1));
     snprintf(name, 256, "%s_bci=%d", base, bc_inp);
 
-    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
-    if (res) {
-        return res;
-    }
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        ggml_metal_cv_t cv = ggml_metal_cv_init();
 
-    ggml_metal_cv_t cv = ggml_metal_cv_init();
+        ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0);
 
-    ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0);
+        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
 
-    res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
-
-    ggml_metal_cv_free(cv);
+        ggml_metal_cv_free(cv);
+    }
 
-    ggml_metal_pipeline_set_smem(res, 8192);
+    res.smem = 8192;
 
     return res;
 }
 
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_library_t lib, const ggml_tensor * op) {
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_library_t lib, const ggml_tensor * op) {
     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
 
@@ -909,28 +865,26 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_libra
     snprintf(base, 256, "kernel_mul_mv_id_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix);
     snprintf(name, 256, "%s_nsg=%d", base, nsg);
 
-    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
-    if (res) {
-        return res;
-    }
-
-    ggml_metal_cv_t cv = ggml_metal_cv_init();
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        ggml_metal_cv_t cv = ggml_metal_cv_init();
 
-    ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
+        ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
 
-    res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
+        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
 
-    ggml_metal_cv_free(cv);
+        ggml_metal_cv_free(cv);
+    }
 
-    ggml_metal_pipeline_set_nr0 (res, nr0);
-    ggml_metal_pipeline_set_nr1 (res, nr1);
-    ggml_metal_pipeline_set_nsg (res, nsg);
-    ggml_metal_pipeline_set_smem(res, smem);
+    res.nr0  = nr0;
+    res.nr1  = nr1;
+    res.nsg  = nsg;
+    res.smem = smem;
 
     return res;
 }
 
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax(ggml_metal_library_t lib, const ggml_tensor * op) {
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argmax(ggml_metal_library_t lib, const ggml_tensor * op) {
     GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
     GGML_ASSERT(ggml_is_contiguous_1(op->src[0]));
     GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type));
@@ -941,19 +895,17 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax(ggml_metal_library_
     snprintf(base, 256, "kernel_argmax_%s", ggml_type_name(op->src[0]->type));
     snprintf(name, 256, "%s", base);
 
-    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
-    if (res) {
-        return res;
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
     }
 
-    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
-
-    ggml_metal_pipeline_set_smem(res, 32*(sizeof(float) + sizeof(int32_t)));
+    res.smem = 32*(sizeof(float) + sizeof(int32_t));
 
     return res;
 }
 
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort(ggml_metal_library_t lib, const ggml_tensor * op) {
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort(ggml_metal_library_t lib, const ggml_tensor * op) {
     assert(op->op == GGML_OP_ARGSORT);
 
     char base[256];
@@ -971,17 +923,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort(ggml_metal_library
     snprintf(base, 256, "kernel_argsort_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);
     snprintf(name, 256, "%s", base);
 
-    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
-    if (res) {
-        return res;
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
     }
 
-    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
-
     return res;
 }
 
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort_merge(ggml_metal_library_t lib, const ggml_tensor * op) {
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort_merge(ggml_metal_library_t lib, const ggml_tensor * op) {
     assert(op->op == GGML_OP_ARGSORT);
 
     char base[256];
@@ -999,18 +949,16 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort_merge(ggml_metal_l
     snprintf(base, 256, "kernel_argsort_merge_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);
     snprintf(name, 256, "%s", base);
 
-    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
-    if (res) {
-        return res;
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
     }
 
-    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
-
     return res;
 }
 
 // note: reuse the argsort kernel for top_k
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k(ggml_metal_library_t lib, const ggml_tensor * op) {
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k(ggml_metal_library_t lib, const ggml_tensor * op) {
     assert(op->op == GGML_OP_TOP_K);
 
     char base[256];
@@ -1029,17 +977,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k(ggml_metal_library_t
     snprintf(base, 256, "kernel_argsort_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);
     snprintf(name, 256, "%s", base);
 
-    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
-    if (res) {
-        return res;
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
     }
 
-    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
-
     return res;
 }
 
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k_merge(ggml_metal_library_t lib, const ggml_tensor * op) {
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge(ggml_metal_library_t lib, const ggml_tensor * op) {
     assert(op->op == GGML_OP_TOP_K);
 
     char base[256];
@@ -1057,17 +1003,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k_merge(ggml_metal_lib
     snprintf(base, 256, "kernel_argsort_merge_%s_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->type), order_str);
     snprintf(name, 256, "%s", base);
 
-    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
-    if (res) {
-        return res;
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
     }
 
-    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
-
     return res;
 }
 
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_pad(
         ggml_metal_library_t lib,
         const struct ggml_tensor * op,
         bool    has_mask,
@@ -1086,33 +1030,31 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
             has_mask,
             ncpsg);
 
-    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
-    if (res) {
-        return res;
-    }
-
-    ggml_metal_cv_t cv = ggml_metal_cv_init();
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        ggml_metal_cv_t cv = ggml_metal_cv_init();
 
-    ggml_metal_cv_set_bool(cv, has_mask,  FC_FLASH_ATTN_EXT_PAD + 0);
-  //ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_PAD + 1);
-  //ggml_metal_cv_set_bool(cv, has_bias,  FC_FLASH_ATTN_EXT_PAD + 2);
-  //ggml_metal_cv_set_bool(cv, has_scap,  FC_FLASH_ATTN_EXT_PAD + 3);
+        ggml_metal_cv_set_bool(cv, has_mask,  FC_FLASH_ATTN_EXT_PAD + 0);
+        //ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_PAD + 1);
+        //ggml_metal_cv_set_bool(cv, has_bias,  FC_FLASH_ATTN_EXT_PAD + 2);
+        //ggml_metal_cv_set_bool(cv, has_scap,  FC_FLASH_ATTN_EXT_PAD + 3);
 
-  //ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_PAD + 20);
-  //ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_PAD + 21);
-  //ggml_metal_cv_set_int32(cv, nsg,  FC_FLASH_ATTN_EXT_PAD + 22);
-  //ggml_metal_cv_set_int32(cv, nwg,  FC_FLASH_ATTN_EXT_PAD + 23);
-  //ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_PAD + 24);
-    ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 25);
+        //ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_PAD + 20);
+        //ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_PAD + 21);
+        //ggml_metal_cv_set_int32(cv, nsg,  FC_FLASH_ATTN_EXT_PAD + 22);
+        //ggml_metal_cv_set_int32(cv, nwg,  FC_FLASH_ATTN_EXT_PAD + 23);
+        //ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_PAD + 24);
+        ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 25);
 
-    res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
+        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
 
-    ggml_metal_cv_free(cv);
+        ggml_metal_cv_free(cv);
+    }
 
     return res;
 }
 
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_blk(
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_blk(
         ggml_metal_library_t lib,
         const struct ggml_tensor * op,
         int32_t nqptg,
@@ -1131,33 +1073,31 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_blk(
             nqptg,
             ncpsg);
 
-    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
-    if (res) {
-        return res;
-    }
-
-    ggml_metal_cv_t cv = ggml_metal_cv_init();
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        ggml_metal_cv_t cv = ggml_metal_cv_init();
 
-  //ggml_metal_cv_set_bool(cv, has_mask,  FC_FLASH_ATTN_EXT_BLK + 0);
-  //ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_BLK + 1);
-  //ggml_metal_cv_set_bool(cv, has_bias,  FC_FLASH_ATTN_EXT_BLK + 2);
-  //ggml_metal_cv_set_bool(cv, has_scap,  FC_FLASH_ATTN_EXT_BLK + 3);
+        //ggml_metal_cv_set_bool(cv, has_mask,  FC_FLASH_ATTN_EXT_BLK + 0);
+        //ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_BLK + 1);
+        //ggml_metal_cv_set_bool(cv, has_bias,  FC_FLASH_ATTN_EXT_BLK + 2);
+        //ggml_metal_cv_set_bool(cv, has_scap,  FC_FLASH_ATTN_EXT_BLK + 3);
 
-  //ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_BLK + 20);
-  //ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_BLK + 21);
-  //ggml_metal_cv_set_int32(cv, nsg,  FC_FLASH_ATTN_EXT_BLK + 22);
-  //ggml_metal_cv_set_int32(cv, nwg,  FC_FLASH_ATTN_EXT_BLK + 23);
-    ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_BLK + 24);
-    ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_BLK + 25);
+        //ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_BLK + 20);
+        //ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_BLK + 21);
+        //ggml_metal_cv_set_int32(cv, nsg,  FC_FLASH_ATTN_EXT_BLK + 22);
+        //ggml_metal_cv_set_int32(cv, nwg,  FC_FLASH_ATTN_EXT_BLK + 23);
+        ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_BLK + 24);
+        ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_BLK + 25);
 
-    res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
+        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
 
-    ggml_metal_cv_free(cv);
+        ggml_metal_cv_free(cv);
+    }
 
     return res;
 }
 
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext(
         ggml_metal_library_t lib,
         const ggml_tensor * op,
         bool    has_mask,
@@ -1198,33 +1138,31 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
             ns20,
             nsg);
 
-    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
-    if (res) {
-        return res;
-    }
-
-    ggml_metal_cv_t cv = ggml_metal_cv_init();
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        ggml_metal_cv_t cv = ggml_metal_cv_init();
 
-    ggml_metal_cv_set_bool(cv, has_mask,  FC_FLASH_ATTN_EXT + 0);
-    ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT + 1);
-    ggml_metal_cv_set_bool(cv, has_bias,  FC_FLASH_ATTN_EXT + 2);
-    ggml_metal_cv_set_bool(cv, has_scap,  FC_FLASH_ATTN_EXT + 3);
-    ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT + 4);
+        ggml_metal_cv_set_bool(cv, has_mask,  FC_FLASH_ATTN_EXT + 0);
+        ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT + 1);
+        ggml_metal_cv_set_bool(cv, has_bias,  FC_FLASH_ATTN_EXT + 2);
+        ggml_metal_cv_set_bool(cv, has_scap,  FC_FLASH_ATTN_EXT + 3);
+        ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT + 4);
 
-    ggml_metal_cv_set_bool(cv, bc_mask, FC_FLASH_ATTN_EXT + 10);
+        ggml_metal_cv_set_bool(cv, bc_mask, FC_FLASH_ATTN_EXT + 10);
 
-    ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT + 20);
-    ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT + 21);
-    ggml_metal_cv_set_int32(cv, nsg,  FC_FLASH_ATTN_EXT + 22);
+        ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT + 20);
+        ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT + 21);
+        ggml_metal_cv_set_int32(cv, nsg,  FC_FLASH_ATTN_EXT + 22);
 
-    res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
+        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
 
-    ggml_metal_cv_free(cv);
+        ggml_metal_cv_free(cv);
+    }
 
     return res;
 }
 
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_vec(
         ggml_metal_library_t lib,
         const ggml_tensor * op,
         bool    has_mask,
@@ -1262,32 +1200,30 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
             ns20,
             nsg, nwg);
 
-    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
-    if (res) {
-        return res;
-    }
-
-    ggml_metal_cv_t cv = ggml_metal_cv_init();
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        ggml_metal_cv_t cv = ggml_metal_cv_init();
 
-    ggml_metal_cv_set_bool(cv, has_mask,  FC_FLASH_ATTN_EXT_VEC + 0);
-    ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_VEC + 1);
-    ggml_metal_cv_set_bool(cv, has_bias,  FC_FLASH_ATTN_EXT_VEC + 2);
-    ggml_metal_cv_set_bool(cv, has_scap,  FC_FLASH_ATTN_EXT_VEC + 3);
-    ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT_VEC + 4);
+        ggml_metal_cv_set_bool(cv, has_mask,  FC_FLASH_ATTN_EXT_VEC + 0);
+        ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_VEC + 1);
+        ggml_metal_cv_set_bool(cv, has_bias,  FC_FLASH_ATTN_EXT_VEC + 2);
+        ggml_metal_cv_set_bool(cv, has_scap,  FC_FLASH_ATTN_EXT_VEC + 3);
+        ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT_VEC + 4);
 
-    ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_VEC + 20);
-    ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_VEC + 21);
-    ggml_metal_cv_set_int32(cv, nsg,  FC_FLASH_ATTN_EXT_VEC + 22);
-    ggml_metal_cv_set_int32(cv, nwg,  FC_FLASH_ATTN_EXT_VEC + 23);
+        ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_VEC + 20);
+        ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_VEC + 21);
+        ggml_metal_cv_set_int32(cv, nsg,  FC_FLASH_ATTN_EXT_VEC + 22);
+        ggml_metal_cv_set_int32(cv, nwg,  FC_FLASH_ATTN_EXT_VEC + 23);
 
-    res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
+        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
 
-    ggml_metal_cv_free(cv);
+        ggml_metal_cv_free(cv);
+    }
 
     return res;
 }
 
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(
         ggml_metal_library_t lib,
         const ggml_tensor * op,
         int32_t dv,
@@ -1300,26 +1236,24 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(
     snprintf(base, 256, "kernel_flash_attn_ext_vec_reduce");
     snprintf(name, 256, "%s_dv=%d_nwg=%d", base, dv, nwg);
 
-    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
-    if (res) {
-        return res;
-    }
-
-    ggml_metal_cv_t cv = ggml_metal_cv_init();
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        ggml_metal_cv_t cv = ggml_metal_cv_init();
 
-    ggml_metal_cv_set_int32(cv, dv,  FC_FLASH_ATTN_EXT_VEC_REDUCE + 0);
-    ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_VEC_REDUCE + 1);
+        ggml_metal_cv_set_int32(cv, dv,  FC_FLASH_ATTN_EXT_VEC_REDUCE + 0);
+        ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_VEC_REDUCE + 1);
 
-    res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
+        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
 
-    ggml_metal_cv_free(cv);
+        ggml_metal_cv_free(cv);
+    }
 
     return res;
 
     GGML_UNUSED(op);
 }
 
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin(
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin(
         ggml_metal_library_t lib,
         ggml_op op,
         int32_t n_fuse,
@@ -1344,17 +1278,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin(
 
     snprintf(name, 256, "%s", base);
 
-    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
-    if (res) {
-        return res;
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
     }
 
-    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
-
     return res;
 }
 
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
     assert(op->op == GGML_OP_L2_NORM);
 
     GGML_ASSERT(op->src[0]->ne[0] % 4 == 0);
@@ -1366,19 +1298,17 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm(ggml_metal_library
     snprintf(base, 256, "kernel_l2_norm_f32");
     snprintf(name, 256, "%s", base);
 
-    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
-    if (res) {
-        return res;
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
     }
 
-    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
-
-    ggml_metal_pipeline_set_smem(res, 32*sizeof(float));
+    res.smem = 32*sizeof(float);
 
     return res;
 }
 
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
     assert(op->op == GGML_OP_GROUP_NORM);
 
     GGML_ASSERT(ggml_is_contiguous(op->src[0]));
@@ -1389,19 +1319,17 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm(ggml_metal_libr
     snprintf(base, 256, "kernel_group_norm_f32");
     snprintf(name, 256, "%s", base);
 
-    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
-    if (res) {
-        return res;
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
     }
 
-    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
-
-    ggml_metal_pipeline_set_smem(res, 32*sizeof(float));
+    res.smem = 32*sizeof(float);
 
     return res;
 }
 
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm(ggml_metal_library_t lib, const ggml_tensor * op, int n_fuse) {
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm(ggml_metal_library_t lib, const ggml_tensor * op, int n_fuse) {
     assert(op->op == GGML_OP_NORM || op->op == GGML_OP_RMS_NORM);
 
     GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
@@ -1434,19 +1362,17 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm(ggml_metal_library_t
 
     snprintf(name, 256, "%s", base);
 
-    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
-    if (res) {
-        return res;
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
     }
 
-    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
-
-    ggml_metal_pipeline_set_smem(res, 32*sizeof(float));
+    res.smem = 32*sizeof(float);
 
     return res;
 }
 
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope(ggml_metal_library_t lib, const ggml_tensor * op) {
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope(ggml_metal_library_t lib, const ggml_tensor * op) {
     assert(op->op == GGML_OP_ROPE);
 
     char base[256];
@@ -1473,23 +1399,21 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope(ggml_metal_library_t
 
     snprintf(name, 256, "%s_imrope=%d", base, is_imrope ? 1 : 0);
 
-    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
-    if (res) {
-        return res;
-    }
-
-    ggml_metal_cv_t cv = ggml_metal_cv_init();
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        ggml_metal_cv_t cv = ggml_metal_cv_init();
 
-    ggml_metal_cv_set_bool(cv, is_imrope, FC_ROPE + 0);
+        ggml_metal_cv_set_bool(cv, is_imrope, FC_ROPE + 0);
 
-    res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
+        res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
 
-    ggml_metal_cv_free(cv);
+        ggml_metal_cv_free(cv);
+    }
 
     return res;
 }
 
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_im2col(ggml_metal_library_t lib, const ggml_tensor * op) {
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_im2col(ggml_metal_library_t lib, const ggml_tensor * op) {
     assert(op->op == GGML_OP_IM2COL);
 
     GGML_ASSERT(ggml_is_contiguous(op->src[1]));
@@ -1502,17 +1426,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_im2col(ggml_metal_library_
     snprintf(base, 256, "kernel_im2col_%s", ggml_type_name(op->type));
     snprintf(name, 256, "%s", base);
 
-    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
-    if (res) {
-        return res;
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
     }
 
-    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
-
     return res;
 }
 
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d(ggml_metal_library_t lib, const ggml_tensor * op) {
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_1d(ggml_metal_library_t lib, const ggml_tensor * op) {
     assert(op->op == GGML_OP_CONV_TRANSPOSE_1D);
 
     GGML_ASSERT(ggml_is_contiguous(op->src[0]));
@@ -1527,17 +1449,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d(ggml_met
     snprintf(base, 256, "kernel_conv_transpose_1d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type));
     snprintf(name, 256, "%s", base);
 
-    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
-    if (res) {
-        return res;
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
     }
 
-    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
-
     return res;
 }
 
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_2d(ggml_metal_library_t lib, const ggml_tensor * op) {
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_2d(ggml_metal_library_t lib, const ggml_tensor * op) {
     assert(op->op == GGML_OP_CONV_TRANSPOSE_2D);
 
     GGML_ASSERT(ggml_is_contiguous(op->src[0]));
@@ -1552,17 +1472,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_2d(ggml_met
     snprintf(base, 256, "kernel_conv_transpose_2d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type));
     snprintf(name, 256, "%s", base);
 
-    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
-    if (res) {
-        return res;
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
     }
 
-    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
-
     return res;
 }
 
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_2d(ggml_metal_library_t lib, const ggml_tensor * op) {
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_2d(ggml_metal_library_t lib, const ggml_tensor * op) {
     assert(op->op == GGML_OP_CONV_2D);
 
     GGML_ASSERT(ggml_is_contiguous(op->src[0]));
@@ -1576,17 +1494,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_2d(ggml_metal_library
     snprintf(base, 256, "kernel_conv_2d_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type));
     snprintf(name, 256, "%s", base);
 
-    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
-    if (res) {
-        return res;
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
     }
 
-    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
-
     return res;
 }
 
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_upscale(ggml_metal_library_t lib, const ggml_tensor * op) {
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale(ggml_metal_library_t lib, const ggml_tensor * op) {
     assert(op->op == GGML_OP_UPSCALE);
 
     char base[256];
@@ -1595,17 +1511,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_upscale(ggml_metal_library
     snprintf(base, 256, "kernel_upscale_%s", ggml_type_name(op->src[0]->type));
     snprintf(name, 256, "%s", base);
 
-    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
-    if (res) {
-        return res;
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
     }
 
-    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
-
     return res;
 }
 
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad(ggml_metal_library_t lib, const ggml_tensor * op) {
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad(ggml_metal_library_t lib, const ggml_tensor * op) {
     assert(op->op == GGML_OP_PAD);
 
     char base[256];
@@ -1614,8 +1528,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad(ggml_metal_library_t l
     snprintf(base, 256, "kernel_pad_%s", ggml_type_name(op->src[0]->type));
     snprintf(name, 256, "%s", base);
 
-    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
-    if (res) {
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (res.pipeline) {
         return res;
     }
 
@@ -1624,7 +1538,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad(ggml_metal_library_t l
     return res;
 }
 
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d(ggml_metal_library_t lib, const ggml_tensor * op) {
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad_reflect_1d(ggml_metal_library_t lib, const ggml_tensor * op) {
     assert(op->op == GGML_OP_PAD_REFLECT_1D);
 
     char base[256];
@@ -1633,17 +1547,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d(ggml_metal_
     snprintf(base, 256, "kernel_pad_reflect_1d_%s", ggml_type_name(op->src[0]->type));
     snprintf(name, 256, "%s", base);
 
-    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
-    if (res) {
-        return res;
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
     }
 
-    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
-
     return res;
 }
 
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange(ggml_metal_library_t lib, const ggml_tensor * op) {
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_arange(ggml_metal_library_t lib, const ggml_tensor * op) {
     assert(op->op == GGML_OP_ARANGE);
 
     char base[256];
@@ -1652,17 +1564,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange(ggml_metal_library_
     snprintf(base, 256, "kernel_arange_%s", ggml_type_name(op->type));
     snprintf(name, 256, "%s", base);
 
-    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
-    if (res) {
-        return res;
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
     }
 
-    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
-
     return res;
 }
 
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const ggml_tensor * op) {
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const ggml_tensor * op) {
     assert(op->op == GGML_OP_TIMESTEP_EMBEDDING);
 
     char base[256];
@@ -1671,17 +1581,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_me
     snprintf(base, 256, "kernel_timestep_embedding_%s", ggml_type_name(op->src[0]->type));
     snprintf(name, 256, "%s", base);
 
-    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
-    if (res) {
-        return res;
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
     }
 
-    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
-
     return res;
 }
 
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw(ggml_metal_library_t lib, const ggml_tensor * op) {
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_adamw(ggml_metal_library_t lib, const ggml_tensor * op) {
     assert(op->op == GGML_OP_OPT_STEP_ADAMW);
 
     char base[256];
@@ -1690,17 +1598,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw(ggml_metal_
     snprintf(base, 256, "kernel_opt_step_adamw_%s", ggml_type_name(op->src[0]->type));
     snprintf(name, 256, "%s", base);
 
-    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
-    if (res) {
-        return res;
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
     }
 
-    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
-
     return res;
 }
 
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_sgd(ggml_metal_library_t lib, const ggml_tensor * op) {
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_sgd(ggml_metal_library_t lib, const ggml_tensor * op) {
     assert(op->op == GGML_OP_OPT_STEP_SGD);
 
     char base[256];
@@ -1709,12 +1615,10 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_sgd(ggml_metal_li
     snprintf(base, 256, "kernel_opt_step_sgd_%s", ggml_type_name(op->src[0]->type));
     snprintf(name, 256, "%s", base);
 
-    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
-    if (res) {
-        return res;
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
     }
 
-    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
-
     return res;
 }
index 3976e622b9b9a672e3e100cdc6771300ca58934c..17baef2017f3cb199a22fc88e2f83bbbe2d0f6c1 100644 (file)
@@ -35,20 +35,6 @@ typedef struct ggml_metal_pipeline * ggml_metal_pipeline_t;
 ggml_metal_pipeline_t ggml_metal_pipeline_init(void);
 void ggml_metal_pipeline_free(ggml_metal_pipeline_t pipeline);
 
-void ggml_metal_pipeline_set_nsg(ggml_metal_pipeline_t pipeline, int nsg);
-int  ggml_metal_pipeline_get_nsg(ggml_metal_pipeline_t pipeline);
-
-void ggml_metal_pipeline_set_nr0(ggml_metal_pipeline_t pipeline, int nr0);
-int  ggml_metal_pipeline_get_nr0(ggml_metal_pipeline_t pipeline);
-
-void ggml_metal_pipeline_set_nr1(ggml_metal_pipeline_t pipeline, int nr1);
-int  ggml_metal_pipeline_get_nr1(ggml_metal_pipeline_t pipeline);
-
-void   ggml_metal_pipeline_set_smem(ggml_metal_pipeline_t pipeline, size_t smem);
-size_t ggml_metal_pipeline_get_smem(ggml_metal_pipeline_t pipeline);
-
-int ggml_metal_pipeline_max_theads_per_threadgroup(ggml_metal_pipeline_t pipeline);
-
 // a collection of pipelines
 typedef struct ggml_metal_pipelines * ggml_metal_pipelines_t;
 
@@ -58,6 +44,19 @@ void ggml_metal_pipelines_free(ggml_metal_pipelines_t ppls);
 void                  ggml_metal_pipelines_add(ggml_metal_pipelines_t ppls, const char * name, ggml_metal_pipeline_t pipeline);
 ggml_metal_pipeline_t ggml_metal_pipelines_get(ggml_metal_pipelines_t ppls, const char * name);
 
+struct ggml_metal_pipeline_with_params {
+    ggml_metal_pipeline_t pipeline;
+
+    int nsg;
+
+    int nr0;
+    int nr1;
+
+    size_t smem;
+};
+
+int ggml_metal_pipeline_max_theads_per_threadgroup(struct ggml_metal_pipeline_with_params pipeline);
+
 //
 // MTLCommandBuffer wrapper
 //
@@ -76,7 +75,7 @@ void ggml_metal_encoder_free(ggml_metal_encoder_t encoder);
 void ggml_metal_encoder_debug_group_push(ggml_metal_encoder_t encoder, const char * name);
 void ggml_metal_encoder_debug_group_pop (ggml_metal_encoder_t encoder);
 
-void ggml_metal_encoder_set_pipeline(ggml_metal_encoder_t encoder, ggml_metal_pipeline_t pipeline);
+void ggml_metal_encoder_set_pipeline(ggml_metal_encoder_t encoder, struct ggml_metal_pipeline_with_params pipeline);
 
 void ggml_metal_encoder_set_bytes (ggml_metal_encoder_t encoder, void * data, size_t size, int idx);
 void ggml_metal_encoder_set_buffer(ggml_metal_encoder_t encoder, struct ggml_metal_buffer_id buffer, int idx);
@@ -100,66 +99,66 @@ ggml_metal_library_t ggml_metal_library_init_from_source(ggml_metal_device_t dev
 
 void ggml_metal_library_free(ggml_metal_library_t lib);
 
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline    (ggml_metal_library_t lib, const char * name);
-ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv);
-
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_base              (ggml_metal_library_t lib, enum ggml_op op);
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cpy               (ggml_metal_library_t lib, enum ggml_type tsrc, enum ggml_type tdst);
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pool_2d           (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool);
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_get_rows          (ggml_metal_library_t lib, enum ggml_type tsrc);
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_set_rows          (ggml_metal_library_t lib, enum ggml_type tidx, enum ggml_type tdst);
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_repeat            (ggml_metal_library_t lib, enum ggml_type tsrc);
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary             (ggml_metal_library_t lib, const struct ggml_tensor * op);
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu               (ggml_metal_library_t lib, const struct ggml_tensor * op);
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum               (ggml_metal_library_t lib, const struct ggml_tensor * op);
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows          (ggml_metal_library_t lib, const struct ggml_tensor * op);
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_blk        (ggml_metal_library_t lib, const struct ggml_tensor * op);
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_add        (ggml_metal_library_t lib, const struct ggml_tensor * op);
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max          (ggml_metal_library_t lib, const struct ggml_tensor * op);
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv          (ggml_metal_library_t lib, const struct ggml_tensor * op);
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan          (ggml_metal_library_t lib, const struct ggml_tensor * op);
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rwkv              (ggml_metal_library_t lib, const struct ggml_tensor * op);
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext        (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg);
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm            (ggml_metal_library_t lib, const struct ggml_tensor * op);
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv            (ggml_metal_library_t lib, const struct ggml_tensor * op);
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0    (ggml_metal_library_t lib, int ne02, int ne20);
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id         (ggml_metal_library_t lib, const struct ggml_tensor * op);
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id         (ggml_metal_library_t lib, const struct ggml_tensor * op);
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax            (ggml_metal_library_t lib, const struct ggml_tensor * op);
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort           (ggml_metal_library_t lib, const struct ggml_tensor * op);
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort_merge     (ggml_metal_library_t lib, const struct ggml_tensor * op);
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k             (ggml_metal_library_t lib, const struct ggml_tensor * op);
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_top_k_merge       (ggml_metal_library_t lib, const struct ggml_tensor * op);
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin               (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm           (ggml_metal_library_t lib, const struct ggml_tensor * op);
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm        (ggml_metal_library_t lib, const struct ggml_tensor * op);
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm              (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse);
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rope              (ggml_metal_library_t lib, const struct ggml_tensor * op);
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_im2col            (ggml_metal_library_t lib, const struct ggml_tensor * op);
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_transpose_2d (ggml_metal_library_t lib, const struct ggml_tensor * op);
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_conv_2d           (ggml_metal_library_t lib, const struct ggml_tensor * op);
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_upscale           (ggml_metal_library_t lib, const struct ggml_tensor * op);
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad               (ggml_metal_library_t lib, const struct ggml_tensor * op);
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d    (ggml_metal_library_t lib, const struct ggml_tensor * op);
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange            (ggml_metal_library_t lib, const struct ggml_tensor * op);
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op);
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw    (ggml_metal_library_t lib, const struct ggml_tensor * op);
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_sgd      (ggml_metal_library_t lib, const struct ggml_tensor * op);
-
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline    (ggml_metal_library_t lib, const char * name);
+struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv);
+
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_base              (ggml_metal_library_t lib, enum ggml_op op);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cpy               (ggml_metal_library_t lib, enum ggml_type tsrc, enum ggml_type tdst);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pool_2d           (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_get_rows          (ggml_metal_library_t lib, enum ggml_type tsrc);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_set_rows          (ggml_metal_library_t lib, enum ggml_type tidx, enum ggml_type tdst);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat            (ggml_metal_library_t lib, enum ggml_type tsrc);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary             (ggml_metal_library_t lib, const struct ggml_tensor * op);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_glu               (ggml_metal_library_t lib, const struct ggml_tensor * op);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum               (ggml_metal_library_t lib, const struct ggml_tensor * op);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum_rows          (ggml_metal_library_t lib, const struct ggml_tensor * op);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_blk        (ggml_metal_library_t lib, const struct ggml_tensor * op);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_cumsum_add        (ggml_metal_library_t lib, const struct ggml_tensor * op);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_soft_max          (ggml_metal_library_t lib, const struct ggml_tensor * op);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv          (ggml_metal_library_t lib, const struct ggml_tensor * op);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan          (ggml_metal_library_t lib, const struct ggml_tensor * op);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv              (ggml_metal_library_t lib, const struct ggml_tensor * op);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext        (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm            (ggml_metal_library_t lib, const struct ggml_tensor * op);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv            (ggml_metal_library_t lib, const struct ggml_tensor * op);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id_map0    (ggml_metal_library_t lib, int ne02, int ne20);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm_id         (ggml_metal_library_t lib, const struct ggml_tensor * op);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_id         (ggml_metal_library_t lib, const struct ggml_tensor * op);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argmax            (ggml_metal_library_t lib, const struct ggml_tensor * op);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort           (ggml_metal_library_t lib, const struct ggml_tensor * op);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort_merge     (ggml_metal_library_t lib, const struct ggml_tensor * op);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k             (ggml_metal_library_t lib, const struct ggml_tensor * op);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge       (ggml_metal_library_t lib, const struct ggml_tensor * op);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin               (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm           (ggml_metal_library_t lib, const struct ggml_tensor * op);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm        (ggml_metal_library_t lib, const struct ggml_tensor * op);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm              (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope              (ggml_metal_library_t lib, const struct ggml_tensor * op);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_im2col            (ggml_metal_library_t lib, const struct ggml_tensor * op);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_2d (ggml_metal_library_t lib, const struct ggml_tensor * op);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_2d           (ggml_metal_library_t lib, const struct ggml_tensor * op);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_upscale           (ggml_metal_library_t lib, const struct ggml_tensor * op);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad               (ggml_metal_library_t lib, const struct ggml_tensor * op);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_pad_reflect_1d    (ggml_metal_library_t lib, const struct ggml_tensor * op);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_arange            (ggml_metal_library_t lib, const struct ggml_tensor * op);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_adamw    (ggml_metal_library_t lib, const struct ggml_tensor * op);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_sgd      (ggml_metal_library_t lib, const struct ggml_tensor * op);
+
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_pad(
         ggml_metal_library_t lib,
         const struct ggml_tensor * op,
         bool    has_mask,
         int32_t ncpsg);
 
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_blk(
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_blk(
         ggml_metal_library_t lib,
         const struct ggml_tensor * op,
         int32_t nqptg,
         int32_t ncpsg);
 
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext(
         ggml_metal_library_t lib,
         const struct ggml_tensor * op,
         bool    has_mask,
@@ -169,7 +168,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
         bool    has_kvpad,
         int32_t nsg);
 
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_vec(
         ggml_metal_library_t lib,
         const struct ggml_tensor * op,
         bool    has_mask,
@@ -180,7 +179,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
         int32_t nsg,
         int32_t nwg);
 
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(
         ggml_metal_library_t lib,
         const struct ggml_tensor * op,
         int32_t dv,
index 4d2bfcf91c644ad4a3b1225b106ffd687781d48d..d22672a8169ca4c96c66daf9f59aba565fad0cfe 100644 (file)
@@ -75,14 +75,6 @@ void ggml_metal_cv_set_bool(ggml_metal_cv_t cv, bool value, int32_t idx) {
 
 struct ggml_metal_pipeline {
     id<MTLComputePipelineState> obj;
-
-    // suggested dispatch sizes
-    int nsg;
-
-    int nr0;
-    int nr1;
-
-    size_t smem;
 };
 
 ggml_metal_pipeline_t ggml_metal_pipeline_init(void) {
@@ -90,10 +82,6 @@ ggml_metal_pipeline_t ggml_metal_pipeline_init(void) {
 
     *res = (struct ggml_metal_pipeline) {
         /*.obj  =*/ nil,
-        /*.nsg  =*/ 0,
-        /*.nr0  =*/ 0,
-        /*.nr1  =*/ 0,
-        /*.smem =*/ 0,
     };
 
     return res;
@@ -105,40 +93,8 @@ void ggml_metal_pipeline_free(ggml_metal_pipeline_t pipeline) {
     free(pipeline);
 }
 
-void ggml_metal_pipeline_set_nsg(ggml_metal_pipeline_t pipeline, int nsg) {
-    pipeline->nsg = nsg;
-}
-
-int ggml_metal_pipeline_get_nsg(ggml_metal_pipeline_t pipeline) {
-    return pipeline->nsg;
-}
-
-void ggml_metal_pipeline_set_nr0(ggml_metal_pipeline_t pipeline, int nr0) {
-    pipeline->nr0 = nr0;
-}
-
-int ggml_metal_pipeline_get_nr0(ggml_metal_pipeline_t pipeline) {
-    return pipeline->nr0;
-}
-
-void ggml_metal_pipeline_set_nr1(ggml_metal_pipeline_t pipeline, int nr1) {
-    pipeline->nr1 = nr1;
-}
-
-int ggml_metal_pipeline_get_nr1(ggml_metal_pipeline_t pipeline) {
-    return pipeline->nr1;
-}
-
-void   ggml_metal_pipeline_set_smem(ggml_metal_pipeline_t pipeline, size_t smem) {
-    pipeline->smem = smem;
-}
-
-size_t ggml_metal_pipeline_get_smem(ggml_metal_pipeline_t pipeline) {
-    return pipeline->smem;
-}
-
-int ggml_metal_pipeline_max_theads_per_threadgroup(ggml_metal_pipeline_t pipeline) {
-    return pipeline->obj.maxTotalThreadsPerThreadgroup;
+int ggml_metal_pipeline_max_theads_per_threadgroup(struct ggml_metal_pipeline_with_params pipeline) {
+    return pipeline.pipeline->obj.maxTotalThreadsPerThreadgroup;
 }
 
 struct ggml_metal_library {
@@ -389,28 +345,42 @@ void ggml_metal_library_free(ggml_metal_library_t lib) {
     free(lib);
 }
 
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline(ggml_metal_library_t lib, const char * name) {
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline(ggml_metal_library_t lib, const char * name) {
     [lib->lock lock];
 
-    ggml_metal_pipeline_t res = ggml_metal_pipelines_get(lib->pipelines, name);
+    struct ggml_metal_pipeline_with_params res = {
+        /*.pipeline =*/ nil,
+        /*.nr0      =*/ 0,
+        /*.nr1      =*/ 0,
+        /*.nsg      =*/ 0,
+        /*.smem     =*/ 0,
+    };
+
+    res.pipeline = ggml_metal_pipelines_get(lib->pipelines, name);
 
     [lib->lock unlock];
 
     return res;
 }
 
-ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv) {
+struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv) {
+    struct ggml_metal_pipeline_with_params res = {
+        /*.pipeline =*/ nil,
+        /*.nr0      =*/ 0,
+        /*.nr1      =*/ 0,
+        /*.nsg      =*/ 0,
+        /*.smem     =*/ 0,
+    };
+
     [lib->lock lock];
 
-    ggml_metal_pipeline_t res = ggml_metal_pipelines_get(lib->pipelines, name);
-    if (res) {
+    res.pipeline = ggml_metal_pipelines_get(lib->pipelines, name);
+    if (res.pipeline) {
         [lib->lock unlock];
 
         return res;
     }
 
-    res = ggml_metal_pipeline_init();
-
     @autoreleasepool {
         NSError * error = nil;
 
@@ -432,26 +402,43 @@ ggml_metal_pipeline_t ggml_metal_library_compile_pipeline(ggml_metal_library_t l
                 GGML_LOG_ERROR("%s: %s\n", __func__, [[error description] UTF8String]);
             }
 
-            return nil;
+            return res;
         }
 
-        res->obj = [lib->device newComputePipelineStateWithFunction:mtl_function error:&error];
+        id<MTLComputePipelineState> obj = [lib->device newComputePipelineStateWithFunction:mtl_function error:&error];
 
         [mtl_function release];
 
-        GGML_LOG_DEBUG("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, name, (void *) res->obj,
-                (int) res->obj.maxTotalThreadsPerThreadgroup,
-                (int) res->obj.threadExecutionWidth);
+        if (!obj) {
+            [lib->lock unlock];
+
+            GGML_LOG_ERROR("%s: failed to create pipeline state: base = '%s', name = '%s'\n", __func__, base, name);
+            if (error) {
+                GGML_LOG_ERROR("%s: %s\n", __func__, [[error description] UTF8String]);
+            }
+
+            return res;
+        }
+
+        GGML_LOG_DEBUG("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, name,
+                (void *) obj,
+                (int)    obj.maxTotalThreadsPerThreadgroup,
+                (int)    obj.threadExecutionWidth);
+
+        if (obj.maxTotalThreadsPerThreadgroup == 0 || obj.threadExecutionWidth == 0) {
+            [obj release];
 
-        if (res->obj.maxTotalThreadsPerThreadgroup == 0 || res->obj.threadExecutionWidth == 0) {
             [lib->lock unlock];
 
             GGML_LOG_ERROR("%s: incompatible pipeline %s\n", __func__, name);
 
-            return nil;
+            return res;
         }
 
-        ggml_metal_pipelines_add(lib->pipelines, name, res);
+        res.pipeline = ggml_metal_pipeline_init();
+        res.pipeline->obj = obj;
+
+        ggml_metal_pipelines_add(lib->pipelines, name, res.pipeline);
     }
 
     [lib->lock unlock];
@@ -496,8 +483,8 @@ void ggml_metal_encoder_debug_group_pop (ggml_metal_encoder_t encoder) {
     [encoder->obj popDebugGroup];
 }
 
-void ggml_metal_encoder_set_pipeline(ggml_metal_encoder_t encoder, ggml_metal_pipeline_t pipeline) {
-    [encoder->obj setComputePipelineState:pipeline->obj];
+void ggml_metal_encoder_set_pipeline(ggml_metal_encoder_t encoder, struct ggml_metal_pipeline_with_params pipeline) {
+    [encoder->obj setComputePipelineState:pipeline.pipeline->obj];
 }
 
 void ggml_metal_encoder_set_bytes(ggml_metal_encoder_t encoder, void * data, size_t size, int idx) {
@@ -622,8 +609,8 @@ ggml_metal_device_t ggml_metal_device_init(void) {
                     GGML_LOG_WARN("%s: - the tensor API is not supported in this environment - disabling\n", __func__);
                     dev->props.has_tensor = false;
                 } else {
-                    ggml_metal_pipeline_t ppl = ggml_metal_library_compile_pipeline(lib, "dummy_kernel", "dummy_kernel", nil);
-                    if (!ppl) {
+                    struct ggml_metal_pipeline_with_params ppl = ggml_metal_library_compile_pipeline(lib, "dummy_kernel", "dummy_kernel", nil);
+                    if (!ppl.pipeline) {
                         GGML_LOG_WARN("%s: - the tensor API is not supported in this environment - disabling\n", __func__);
                         dev->props.has_tensor = false;
                     }
@@ -672,8 +659,8 @@ ggml_metal_device_t ggml_metal_device_init(void) {
                     GGML_LOG_WARN("%s: - the tensor API does not support bfloat - disabling bfloat support\n", __func__);
                     dev->props.has_bfloat = false;
                 } else {
-                    ggml_metal_pipeline_t ppl = ggml_metal_library_compile_pipeline(lib, "dummy_kernel", "dummy_kernel", nil);
-                    if (!ppl) {
+                    struct ggml_metal_pipeline_with_params ppl = ggml_metal_library_compile_pipeline(lib, "dummy_kernel", "dummy_kernel", nil);
+                    if (!ppl.pipeline) {
                         GGML_LOG_WARN("%s: - the tensor API does not support bfloat - disabling bfloat support\n", __func__);
                         dev->props.has_bfloat = false;
                     }
index 9871e976f23d149b30ae459b01622f52865f4231..edb227a210071350fa664f9ae28591405dec4c2e 100644 (file)
@@ -524,7 +524,7 @@ int ggml_metal_op_concat(ggml_metal_op_t ctx, int idx) {
         /*.dim  =*/ dim,
     };
 
-    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_CONCAT);
+    auto pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_CONCAT);
 
     ggml_metal_encoder_set_pipeline(enc, pipeline);
     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
@@ -550,7 +550,7 @@ int ggml_metal_op_repeat(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
     GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
-    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_repeat(lib, op->type);
+    auto pipeline = ggml_metal_library_get_pipeline_repeat(lib, op->type);
 
     ggml_metal_kargs_repeat args = {
         /*.ne00 =*/ ne00,
@@ -616,7 +616,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
         // TODO: make a simpler cpy_bytes kernel
 
         //const id<MTLComputePipelineState> pipeline = ctx->pipelines[GGML_METAL_PIPELINE_TYPE_CPY_F32_F32].obj;
-        ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
+        auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
 
         ggml_metal_kargs_cpy args = {
             /*.nk0  =*/ ne00,
@@ -679,7 +679,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
         /*.o1   =*/ { 0 },
     };
 
-    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_bin(lib, GGML_OP_ADD, 1, false);
+    auto pipeline = ggml_metal_library_get_pipeline_bin(lib, GGML_OP_ADD, 1, false);
 
     ggml_metal_encoder_set_pipeline(enc, pipeline);
     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
@@ -721,7 +721,7 @@ int ggml_metal_op_scale(ggml_metal_op_t ctx, int idx) {
         n /= 4;
     }
 
-    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
+    auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
 
     ggml_metal_encoder_set_pipeline(enc, pipeline);
     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
@@ -760,7 +760,7 @@ int ggml_metal_op_clamp(ggml_metal_op_t ctx, int idx) {
         n /= 4;
     }
 
-    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
+    auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
 
     ggml_metal_encoder_set_pipeline(enc, pipeline);
     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
@@ -789,7 +789,7 @@ int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
         n /= 4;
     }
 
-    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
+    auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
 
     ggml_metal_encoder_set_pipeline(enc, pipeline);
     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 0);
@@ -817,7 +817,7 @@ int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) {
         GGML_ASSERT(ggml_are_same_shape(op->src[0], op->src[1]));
     }
 
-    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_glu(lib, op);
+    auto pipeline = ggml_metal_library_get_pipeline_glu(lib, op);
 
     const int32_t swp = ggml_get_op_params_i32(op, 1);
     const float alpha = ggml_get_op_params_f32(op, 2);
@@ -870,7 +870,7 @@ int ggml_metal_op_sum(ggml_metal_op_t ctx, int idx) {
         /*.np =*/ n,
     };
 
-    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_sum(lib, op);
+    auto pipeline = ggml_metal_library_get_pipeline_sum(lib, op);
 
     int nth = 32; // SIMD width
 
@@ -925,7 +925,7 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
         /*.nb3  =*/ nb3,
     };
 
-    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_sum_rows(lib, op);
+    auto pipeline = ggml_metal_library_get_pipeline_sum_rows(lib, op);
 
     int nth = 32; // SIMD width
 
@@ -936,7 +936,7 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
     nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
     nth = std::min(nth, ne00);
 
-    const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
+    const size_t smem = pipeline.smem;
 
     ggml_metal_encoder_set_pipeline(enc, pipeline);
     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
@@ -963,7 +963,7 @@ int ggml_metal_op_cumsum(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
     GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
-    ggml_metal_pipeline_t pipeline_blk = ggml_metal_library_get_pipeline_cumsum_blk(lib, op);
+    auto pipeline_blk = ggml_metal_library_get_pipeline_cumsum_blk(lib, op);
 
     int nth = 1;
     while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_blk)) {
@@ -1060,7 +1060,7 @@ int ggml_metal_op_cumsum(ggml_metal_op_t ctx, int idx) {
         ggml_metal_op_concurrency_reset(ctx);
 
         {
-            ggml_metal_pipeline_t pipeline_add = ggml_metal_library_get_pipeline_cumsum_add(lib, op);
+            auto pipeline_add = ggml_metal_library_get_pipeline_cumsum_add(lib, op);
 
             ggml_metal_kargs_cumsum_add args = {
                 /*.ne00 =*/ ne00,
@@ -1106,7 +1106,7 @@ int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
     GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
-    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type);
+    auto pipeline = ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type);
 
     ggml_metal_kargs_get_rows args = {
         /*.ne00t =*/ ggml_is_quantized(op->src[0]->type) ? ne00/16 : ne00,
@@ -1151,7 +1151,7 @@ int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
     GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
-    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_set_rows(lib, op->src[1]->type, op->type);
+    auto pipeline = ggml_metal_library_get_pipeline_set_rows(lib, op->src[1]->type, op->type);
 
     const int32_t nk0 = ne0/ggml_blck_size(op->type);
 
@@ -1252,7 +1252,7 @@ int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) {
         /*.n_head_log2 =*/ n_head_log2,
     };
 
-    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_soft_max(lib, op);
+    auto pipeline = ggml_metal_library_get_pipeline_soft_max(lib, op);
 
     int nth = 32; // SIMD width
 
@@ -1266,7 +1266,7 @@ int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) {
         }
     }
 
-    const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
+    const size_t smem = pipeline.smem;
 
     ggml_metal_encoder_set_pipeline(enc, pipeline);
     ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
@@ -1322,7 +1322,7 @@ int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) {
         /*.nb2  =*/ nb2,
     };
 
-    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_ssm_conv(lib, op);
+    auto pipeline = ggml_metal_library_get_pipeline_ssm_conv(lib, op);
 
     ggml_metal_encoder_set_pipeline(enc, pipeline);
     ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
@@ -1409,11 +1409,11 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
         /*.nb0          =*/ nb0,
     };
 
-    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_ssm_scan(lib, op);
+    auto pipeline = ggml_metal_library_get_pipeline_ssm_scan(lib, op);
 
     GGML_ASSERT(d_state <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
 
-    const size_t sms = ggml_metal_pipeline_get_smem(pipeline);
+    const size_t smem = pipeline.smem;
 
     ggml_metal_encoder_set_pipeline(enc, pipeline);
     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
@@ -1426,7 +1426,7 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[6]), 7);
     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         8);
 
-    ggml_metal_encoder_set_threadgroup_memory_size(enc, sms, 0);
+    ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
 
     ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1);
 
@@ -1449,7 +1449,7 @@ int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) {
     const int64_t C = op->ne[0];
     const int64_t H = op->src[0]->ne[1];
 
-    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_rwkv(lib, op);
+    auto pipeline = ggml_metal_library_get_pipeline_rwkv(lib, op);
 
     int ida = 0;
 
@@ -1485,7 +1485,7 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
     GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
-    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
+    auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
 
     GGML_ASSERT(ne00 % ggml_blck_size(op->src[0]->type) == 0);
 
@@ -1592,7 +1592,7 @@ int ggml_metal_op_pool_2d(ggml_metal_op_t ctx, int idx) {
         /* .np = */ np
     };
 
-    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_pool_2d(lib, op, op_pool);
+    auto pipeline = ggml_metal_library_get_pipeline_pool_2d(lib, op, op_pool);
 
     const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), (int) np);
     const int ntg = (np + nth - 1) / nth;
@@ -1701,7 +1701,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
                 GGML_ABORT("unsupported ne11");
         };
 
-        ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op->src[0]->type, op->src[1]->type, nsg, nxpsg, r1ptg);
+        auto pipeline = ggml_metal_library_get_pipeline_mul_mv_ext(lib, op->src[0]->type, op->src[1]->type, nsg, nxpsg, r1ptg);
 
         ggml_metal_kargs_mul_mv_ext args = {
             /*.ne00  =*/ ne00,
@@ -1748,7 +1748,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
         //    default: break;
         //}
 
-        ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm(lib, op);
+        auto pipeline = ggml_metal_library_get_pipeline_mul_mm(lib, op);
 
         ggml_metal_kargs_mul_mm args = {
             /*.ne00 =*/ ne00,
@@ -1773,18 +1773,18 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
         ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
         ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         3);
 
-        const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
+        const size_t smem = pipeline.smem;
 
         ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
         ggml_metal_encoder_dispatch_threadgroups(enc, ((ne11 + 31)/32), ((ne01 + 63)/64), ne12*ne13, 128, 1, 1);
     } else {
-        ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv(lib, op);
+        auto pipeline = ggml_metal_library_get_pipeline_mul_mv(lib, op);
 
-        const int nr0 = ggml_metal_pipeline_get_nr0(pipeline);
-        const int nr1 = ggml_metal_pipeline_get_nr1(pipeline);
-        const int nsg = ggml_metal_pipeline_get_nsg(pipeline);
+        const int nr0 = pipeline.nr0;
+        const int nr1 = pipeline.nr1;
+        const int nsg = pipeline.nsg;
 
-        const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
+        const size_t smem = pipeline.smem;
 
         ggml_metal_kargs_mul_mv args = {
             /*.ne00 =*/ ne00,
@@ -1915,9 +1915,9 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
                 nb21,
             };
 
-            ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm_id_map0(lib, ne02, ne20);
+            auto pipeline = ggml_metal_library_get_pipeline_mul_mm_id_map0(lib, ne02, ne20);
 
-            const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
+            const size_t smem = pipeline.smem;
 
             GGML_ASSERT(ne02 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
 
@@ -1938,7 +1938,7 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
         ggml_metal_op_concurrency_reset(ctx);
 
         {
-            ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm_id(lib, op);
+            auto pipeline = ggml_metal_library_get_pipeline_mul_mm_id(lib, op);
 
             ggml_metal_kargs_mul_mm_id args = {
                 /*.ne00  =*/ ne00,
@@ -1967,20 +1967,20 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
             ggml_metal_encoder_set_buffer  (enc, bid_ids,  4);
             ggml_metal_encoder_set_buffer  (enc, bid_dst,  5);
 
-            const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
+            const size_t smem = pipeline.smem;
 
             ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
 
             ggml_metal_encoder_dispatch_threadgroups(enc, (ne21 + 31)/32, (ne01 + 63)/64, ne02, 128, 1, 1);
         }
     } else {
-        ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv_id(lib, op);
+        auto pipeline = ggml_metal_library_get_pipeline_mul_mv_id(lib, op);
 
-        const int nr0 = ggml_metal_pipeline_get_nr0(pipeline);
-        const int nr1 = ggml_metal_pipeline_get_nr1(pipeline);
-        const int nsg = ggml_metal_pipeline_get_nsg(pipeline);
+        const int nr0 = pipeline.nr0;
+        const int nr1 = pipeline.nr1;
+        const int nsg = pipeline.nsg;
 
-        const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
+        const size_t smem = pipeline.smem;
 
         ggml_metal_kargs_mul_mv_id args = {
             /*.nei0 =*/ ne20,
@@ -2064,7 +2064,7 @@ int ggml_metal_op_add_id(ggml_metal_op_t ctx, int idx) {
         /*.nb21 =*/ nb21,
     };
 
-    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_ADD_ID);
+    auto pipeline = ggml_metal_library_get_pipeline_base(lib, GGML_OP_ADD_ID);
 
     ggml_metal_encoder_set_pipeline(enc, pipeline);
     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
@@ -2308,7 +2308,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
                 /*.nb33    =*/nb33,
             };
 
-            ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
+            auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
 
             ggml_metal_encoder_set_pipeline(enc, pipeline0);
             ggml_metal_encoder_set_bytes   (enc, &args0, sizeof(args0), 0);
@@ -2339,7 +2339,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
                 /*.nb33 =*/ nb33,
             };
 
-            ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_blk(lib, op, nqptg, ncpsg);
+            auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_blk(lib, op, nqptg, ncpsg);
 
             ggml_metal_encoder_set_pipeline(enc, pipeline0);
             ggml_metal_encoder_set_bytes   (enc, &args0, sizeof(args0), 0);
@@ -2424,7 +2424,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
             /*.logit_softcap =*/ logit_softcap,
         };
 
-        ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg);
+        auto pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg);
 
         ggml_metal_encoder_set_pipeline(enc, pipeline);
         ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
@@ -2476,7 +2476,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
                 /*.nb33    =*/nb33,
             };
 
-            ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
+            auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg);
 
             ggml_metal_encoder_set_pipeline(enc, pipeline0);
             ggml_metal_encoder_set_bytes   (enc, &args0, sizeof(args0), 0);
@@ -2578,7 +2578,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
             /*.logit_softcap =*/ logit_softcap,
         };
 
-        ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg, nwg);
+        auto pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg, nwg);
 
         GGML_ASSERT(nsg*32 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
 
@@ -2630,7 +2630,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
                     nrows,
                 };
 
-                ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(lib, op, ne20, nwg);
+                auto pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(lib, op, ne20, nwg);
 
                 ggml_metal_encoder_set_pipeline(enc, pipeline0);
                 ggml_metal_encoder_set_bytes   (enc, &args0, sizeof(args0), 0);
@@ -2762,7 +2762,7 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
     // the offsets of src1 and all fused buffers are relative to the start of the src1 buffer
     bid_src1.offs = 0;
 
-    ggml_metal_pipeline_t pipeline = nullptr;
+    struct ggml_metal_pipeline_with_params pipeline;
 
     if (ggml_nelements(op->src[1]) == ne10 && ggml_is_contiguous(op->src[1]) && ne00 % 4 == 0 && ne10 % 4 == 0) {
         GGML_ASSERT(ggml_is_contiguous(op->src[0]));
@@ -2835,7 +2835,7 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
         /*.eps    =*/ eps,
     };
 
-    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_l2_norm(lib, op);
+    auto pipeline = ggml_metal_library_get_pipeline_l2_norm(lib, op);
 
     while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
         nth *= 2;
@@ -2844,7 +2844,7 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
     nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
     nth = std::min(nth, ne00/4);
 
-    const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
+    const size_t smem = pipeline.smem;
 
     const int64_t nrows = ggml_nrows(op->src[0]);
 
@@ -2887,7 +2887,7 @@ int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) {
         /*.eps  =*/ eps,
     };
 
-    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_group_norm(lib, op);
+    auto pipeline = ggml_metal_library_get_pipeline_group_norm(lib, op);
 
     int nth = 32; // SIMD width
     //while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
@@ -2897,7 +2897,7 @@ int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) {
     //nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
     //nth = std::min(nth, ne00/4);
 
-    const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
+    const size_t smem = pipeline.smem;
 
     ggml_metal_encoder_set_pipeline(enc, pipeline);
     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
@@ -3022,7 +3022,7 @@ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) {
         }
     }
 
-    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_norm(lib, op, n_fuse);
+    auto pipeline = ggml_metal_library_get_pipeline_norm(lib, op, n_fuse);
 
     int nth = 32; // SIMD width
 
@@ -3033,7 +3033,7 @@ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) {
     nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
     nth = std::min(nth, args.ne00_t);
 
-    const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
+    const size_t smem = pipeline.smem;
 
     ggml_metal_encoder_set_pipeline(enc, pipeline);
     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
@@ -3127,7 +3127,7 @@ int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) {
         /* src2        =*/ op->src[2] != nullptr,
     };
 
-    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_rope(lib, op);
+    auto pipeline = ggml_metal_library_get_pipeline_rope(lib, op);
 
     ggml_metal_encoder_set_pipeline(enc, pipeline);
     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
@@ -3199,7 +3199,7 @@ int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) {
         /*.KHW  =*/ KH * KW,
     };
 
-    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_im2col(lib, op);
+    auto pipeline = ggml_metal_library_get_pipeline_im2col(lib, op);
 
     GGML_ASSERT(KH*KW <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
 
@@ -3270,7 +3270,7 @@ int ggml_metal_op_conv_2d(ggml_metal_op_t ctx, int idx) {
         /*.d1   =*/ d1,
     };
 
-    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_conv_2d(lib, op);
+    auto pipeline = ggml_metal_library_get_pipeline_conv_2d(lib, op);
 
     int nth = ggml_metal_pipeline_max_theads_per_threadgroup(pipeline);
     nth = std::min(nth, 256);
@@ -3325,7 +3325,7 @@ int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) {
         /*.nb1 =*/ nb1,
     };
 
-    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_conv_transpose_1d(lib, op);
+    auto pipeline = ggml_metal_library_get_pipeline_conv_transpose_1d(lib, op);
 
     ggml_metal_encoder_set_pipeline(enc, pipeline);
     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
@@ -3377,7 +3377,7 @@ int ggml_metal_op_conv_transpose_2d(ggml_metal_op_t ctx, int idx) {
         /*.nb2 =*/ nb2,
     };
 
-    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_conv_transpose_2d(lib, op);
+    auto pipeline = ggml_metal_library_get_pipeline_conv_transpose_2d(lib, op);
 
     ggml_metal_encoder_set_pipeline(enc, pipeline);
     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
@@ -3433,7 +3433,7 @@ int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) {
         /*.sf3 =*/ sf3
     };
 
-    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_upscale(lib, op);
+    auto pipeline = ggml_metal_library_get_pipeline_upscale(lib, op);
 
     const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
 
@@ -3477,7 +3477,7 @@ int ggml_metal_op_pad(ggml_metal_op_t ctx, int idx) {
         /*.nb3  =*/ nb3
     };
 
-    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_pad(lib, op);
+    auto pipeline = ggml_metal_library_get_pipeline_pad(lib, op);
 
     const int nth = std::min(1024, ne0);
 
@@ -3523,7 +3523,7 @@ int ggml_metal_op_pad_reflect_1d(ggml_metal_op_t ctx, int idx) {
         /*.p1 =*/ ((const int32_t *)(op->op_params))[1]
     };
 
-    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_pad_reflect_1d(lib, op);
+    auto pipeline = ggml_metal_library_get_pipeline_pad_reflect_1d(lib, op);
 
     const int nth = std::min(1024, ne0);
 
@@ -3560,7 +3560,7 @@ int ggml_metal_op_arange(ggml_metal_op_t ctx, int idx) {
 
     const int nth = std::min(1024, ne0);
 
-    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_arange(lib, op);
+    auto pipeline = ggml_metal_library_get_pipeline_arange(lib, op);
 
     ggml_metal_encoder_set_pipeline(enc, pipeline);
     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
@@ -3591,7 +3591,7 @@ int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx) {
         /*.max_period =*/ max_period,
     };
 
-    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_timestep_embedding(lib, op);
+    auto pipeline = ggml_metal_library_get_pipeline_timestep_embedding(lib, op);
 
     const int nth = std::max(1, std::min(1024, dim/2));
 
@@ -3621,7 +3621,7 @@ int ggml_metal_op_argmax(ggml_metal_op_t ctx, int idx) {
         /*.nb01 = */ nb01,
     };
 
-    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_argmax(lib, op);
+    auto pipeline = ggml_metal_library_get_pipeline_argmax(lib, op);
 
     const int64_t nrows = ggml_nrows(op->src[0]);
 
@@ -3630,7 +3630,7 @@ int ggml_metal_op_argmax(ggml_metal_op_t ctx, int idx) {
         nth *= 2;
     }
 
-    const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
+    const size_t smem = pipeline.smem;
 
     ggml_metal_encoder_set_pipeline(enc, pipeline);
     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
@@ -3657,7 +3657,7 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
     GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
-    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_argsort(lib, op);
+    auto pipeline = ggml_metal_library_get_pipeline_argsort(lib, op);
 
     // bitonic sort requires the number of elements to be power of 2
     int nth = 1;
@@ -3706,7 +3706,7 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
 
     ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1);
 
-    ggml_metal_pipeline_t pipeline_merge = ggml_metal_library_get_pipeline_argsort_merge(lib, op);
+    auto pipeline_merge = ggml_metal_library_get_pipeline_argsort_merge(lib, op);
 
     int len = nth;
 
@@ -3764,7 +3764,7 @@ int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
     GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
-    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_top_k(lib, op);
+    auto pipeline = ggml_metal_library_get_pipeline_top_k(lib, op);
 
     // bitonic sort requires the number of elements to be power of 2
     int nth = 1;
@@ -3818,7 +3818,7 @@ int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) {
 
     ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1);
 
-    ggml_metal_pipeline_t pipeline_merge = ggml_metal_library_get_pipeline_top_k_merge(lib, op);
+    auto pipeline_merge = ggml_metal_library_get_pipeline_top_k_merge(lib, op);
 
     int len = args.top_k;
 
@@ -3881,7 +3881,7 @@ int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
         /*.slope =*/ slope
     };
 
-    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
+    auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
 
     int64_t n = ggml_nelements(op);
 
@@ -3910,7 +3910,7 @@ int ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
     GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
-    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_adamw(lib, op);
+    auto pipeline = ggml_metal_library_get_pipeline_opt_step_adamw(lib, op);
 
     const int64_t np = ggml_nelements(op->src[0]);
     ggml_metal_kargs_opt_step_adamw args = {
@@ -3946,7 +3946,7 @@ int ggml_metal_op_opt_step_sgd(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
     GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
-    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_sgd(lib, op);
+    auto pipeline = ggml_metal_library_get_pipeline_opt_step_sgd(lib, op);
 
     const int64_t np = ggml_nelements(op->src[0]);
     ggml_metal_kargs_opt_step_sgd args = {