]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
metal : refactor kernel loading code (llama/4794)
authorGeorgi Gerganov <redacted>
Sat, 13 Jan 2024 16:03:45 +0000 (18:03 +0200)
committerGeorgi Gerganov <redacted>
Sat, 13 Jan 2024 22:11:44 +0000 (00:11 +0200)
* metal : detect more GPU families

* metal : refactor kernel loading

* metal : set kernel family requirements

* metal : fix kernel init + fix compile options

* metal : take into account simdgroup reduction support

* metal : print only skipped kernels

* metal : fix check for simdgroup reduction support

* metal : check for Metal 3

* metal : free allocations

* metal : normalize encoder:setComputePipelineStatus calls

ggml-ci

* metal : fix Metal3 family check

ggml-ci

* metal : check for simdgroup matrix mul. feature

ggml-ci

ggml-metal.m

index c03624073fb61f6e728e1c30c0e9361ca49e4da0..6c28a7ee32d3f1942ea0155442fc157fb62c3006 100644 (file)
@@ -26,6 +26,8 @@
 
 #define GGML_MAX_CONCUR (2*GGML_DEFAULT_GRAPH_SIZE)
 
+#define GGML_METAL_MAX_KERNELS 256
+
 struct ggml_metal_buffer {
     const char * name;
 
@@ -35,6 +37,134 @@ struct ggml_metal_buffer {
     id<MTLBuffer> metal;
 };
 
+struct ggml_metal_kernel {
+    id<MTLFunction>             function;
+    id<MTLComputePipelineState> pipeline;
+};
+
+enum ggml_metal_kernel_type {
+    GGML_METAL_KERNEL_TYPE_ADD,
+    GGML_METAL_KERNEL_TYPE_ADD_ROW,
+    GGML_METAL_KERNEL_TYPE_MUL,
+    GGML_METAL_KERNEL_TYPE_MUL_ROW,
+    GGML_METAL_KERNEL_TYPE_DIV,
+    GGML_METAL_KERNEL_TYPE_DIV_ROW,
+    GGML_METAL_KERNEL_TYPE_SCALE,
+    GGML_METAL_KERNEL_TYPE_SCALE_4,
+    GGML_METAL_KERNEL_TYPE_TANH,
+    GGML_METAL_KERNEL_TYPE_RELU,
+    GGML_METAL_KERNEL_TYPE_GELU,
+    GGML_METAL_KERNEL_TYPE_GELU_QUICK,
+    GGML_METAL_KERNEL_TYPE_SILU,
+    GGML_METAL_KERNEL_TYPE_SOFT_MAX,
+    GGML_METAL_KERNEL_TYPE_SOFT_MAX_4,
+    GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,
+    GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_F16,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
+    GGML_METAL_KERNEL_TYPE_RMS_NORM,
+    GGML_METAL_KERNEL_TYPE_GROUP_NORM,
+    GGML_METAL_KERNEL_TYPE_NORM,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
+  //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,
+  //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW,
+  //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,
+    GGML_METAL_KERNEL_TYPE_ROPE_F32,
+    GGML_METAL_KERNEL_TYPE_ROPE_F16,
+    GGML_METAL_KERNEL_TYPE_ALIBI_F32,
+    GGML_METAL_KERNEL_TYPE_IM2COL_F16,
+    GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
+    GGML_METAL_KERNEL_TYPE_PAD_F32,
+    GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
+    GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,
+    GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,
+    GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
+    GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
+    GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
+    GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
+    GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
+  //GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
+  //GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
+    GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
+    GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
+    GGML_METAL_KERNEL_TYPE_CONCAT,
+    GGML_METAL_KERNEL_TYPE_SQR,
+    GGML_METAL_KERNEL_TYPE_SUM_ROWS,
+
+    GGML_METAL_KERNEL_TYPE_COUNT
+};
+
 struct ggml_metal_context {
     int n_cb;
 
@@ -50,134 +180,13 @@ struct ggml_metal_context {
     int n_buffers;
     struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
 
+    struct ggml_metal_kernel kernels[GGML_METAL_MAX_KERNELS];
+
     int concur_list[GGML_MAX_CONCUR];
     int concur_list_len;
 
-    // custom kernels
-#define GGML_METAL_DECL_KERNEL(name) \
-    id<MTLFunction>             function_##name; \
-    id<MTLComputePipelineState> pipeline_##name
-
-    GGML_METAL_DECL_KERNEL(add);
-    GGML_METAL_DECL_KERNEL(add_row); // TODO: avoid this extra kernel, instead extend the "add" kernel to support broadcast
-    GGML_METAL_DECL_KERNEL(mul);
-    GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
-    GGML_METAL_DECL_KERNEL(div);
-    GGML_METAL_DECL_KERNEL(div_row);
-    GGML_METAL_DECL_KERNEL(scale);
-    GGML_METAL_DECL_KERNEL(scale_4);
-    GGML_METAL_DECL_KERNEL(tanh);
-    GGML_METAL_DECL_KERNEL(relu);
-    GGML_METAL_DECL_KERNEL(gelu);
-    GGML_METAL_DECL_KERNEL(gelu_quick);
-    GGML_METAL_DECL_KERNEL(silu);
-    GGML_METAL_DECL_KERNEL(soft_max);
-    GGML_METAL_DECL_KERNEL(soft_max_4);
-    GGML_METAL_DECL_KERNEL(diag_mask_inf);
-    GGML_METAL_DECL_KERNEL(diag_mask_inf_8);
-    GGML_METAL_DECL_KERNEL(get_rows_f32);
-    GGML_METAL_DECL_KERNEL(get_rows_f16);
-    GGML_METAL_DECL_KERNEL(get_rows_q4_0);
-    GGML_METAL_DECL_KERNEL(get_rows_q4_1);
-    GGML_METAL_DECL_KERNEL(get_rows_q5_0);
-    GGML_METAL_DECL_KERNEL(get_rows_q5_1);
-    GGML_METAL_DECL_KERNEL(get_rows_q8_0);
-    GGML_METAL_DECL_KERNEL(get_rows_q2_K);
-    GGML_METAL_DECL_KERNEL(get_rows_q3_K);
-    GGML_METAL_DECL_KERNEL(get_rows_q4_K);
-    GGML_METAL_DECL_KERNEL(get_rows_q5_K);
-    GGML_METAL_DECL_KERNEL(get_rows_q6_K);
-    GGML_METAL_DECL_KERNEL(get_rows_i32);
-    GGML_METAL_DECL_KERNEL(get_rows_iq2_xxs);
-    GGML_METAL_DECL_KERNEL(get_rows_iq2_xs);
-    GGML_METAL_DECL_KERNEL(rms_norm);
-    GGML_METAL_DECL_KERNEL(group_norm);
-    GGML_METAL_DECL_KERNEL(norm);
-    GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
-    GGML_METAL_DECL_KERNEL(mul_mv_f16_f16);
-    GGML_METAL_DECL_KERNEL(mul_mv_f16_f32);
-    GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_1row);
-    GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4);
-    GGML_METAL_DECL_KERNEL(mul_mv_q4_0_f32);
-    GGML_METAL_DECL_KERNEL(mul_mv_q4_1_f32);
-    GGML_METAL_DECL_KERNEL(mul_mv_q5_0_f32);
-    GGML_METAL_DECL_KERNEL(mul_mv_q5_1_f32);
-    GGML_METAL_DECL_KERNEL(mul_mv_q8_0_f32);
-    GGML_METAL_DECL_KERNEL(mul_mv_q2_K_f32);
-    GGML_METAL_DECL_KERNEL(mul_mv_q3_K_f32);
-    GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32);
-    GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
-    GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
-    GGML_METAL_DECL_KERNEL(mul_mv_iq2_xxs_f32);
-    GGML_METAL_DECL_KERNEL(mul_mv_iq2_xs_f32);
-    GGML_METAL_DECL_KERNEL(mul_mv_id_f32_f32);
-    //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f16);
-    GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32);
-    //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_1row);
-    //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_l4);
-    GGML_METAL_DECL_KERNEL(mul_mv_id_q4_0_f32);
-    GGML_METAL_DECL_KERNEL(mul_mv_id_q4_1_f32);
-    GGML_METAL_DECL_KERNEL(mul_mv_id_q5_0_f32);
-    GGML_METAL_DECL_KERNEL(mul_mv_id_q5_1_f32);
-    GGML_METAL_DECL_KERNEL(mul_mv_id_q8_0_f32);
-    GGML_METAL_DECL_KERNEL(mul_mv_id_q2_K_f32);
-    GGML_METAL_DECL_KERNEL(mul_mv_id_q3_K_f32);
-    GGML_METAL_DECL_KERNEL(mul_mv_id_q4_K_f32);
-    GGML_METAL_DECL_KERNEL(mul_mv_id_q5_K_f32);
-    GGML_METAL_DECL_KERNEL(mul_mv_id_q6_K_f32);
-    GGML_METAL_DECL_KERNEL(mul_mv_id_iq2_xxs_f32);
-    GGML_METAL_DECL_KERNEL(mul_mv_id_iq2_xs_f32);
-    GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
-    GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
-    GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
-    GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
-    GGML_METAL_DECL_KERNEL(mul_mm_q5_0_f32);
-    GGML_METAL_DECL_KERNEL(mul_mm_q5_1_f32);
-    GGML_METAL_DECL_KERNEL(mul_mm_q8_0_f32);
-    GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32);
-    GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32);
-    GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
-    GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
-    GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
-    GGML_METAL_DECL_KERNEL(mul_mm_iq2_xxs_f32);
-    GGML_METAL_DECL_KERNEL(mul_mm_iq2_xs_f32);
-    GGML_METAL_DECL_KERNEL(mul_mm_id_f32_f32);
-    GGML_METAL_DECL_KERNEL(mul_mm_id_f16_f32);
-    GGML_METAL_DECL_KERNEL(mul_mm_id_q4_0_f32);
-    GGML_METAL_DECL_KERNEL(mul_mm_id_q4_1_f32);
-    GGML_METAL_DECL_KERNEL(mul_mm_id_q5_0_f32);
-    GGML_METAL_DECL_KERNEL(mul_mm_id_q5_1_f32);
-    GGML_METAL_DECL_KERNEL(mul_mm_id_q8_0_f32);
-    GGML_METAL_DECL_KERNEL(mul_mm_id_q2_K_f32);
-    GGML_METAL_DECL_KERNEL(mul_mm_id_q3_K_f32);
-    GGML_METAL_DECL_KERNEL(mul_mm_id_q4_K_f32);
-    GGML_METAL_DECL_KERNEL(mul_mm_id_q5_K_f32);
-    GGML_METAL_DECL_KERNEL(mul_mm_id_q6_K_f32);
-    GGML_METAL_DECL_KERNEL(mul_mm_id_iq2_xxs_f32);
-    GGML_METAL_DECL_KERNEL(mul_mm_id_iq2_xs_f32);
-    GGML_METAL_DECL_KERNEL(rope_f32);
-    GGML_METAL_DECL_KERNEL(rope_f16);
-    GGML_METAL_DECL_KERNEL(alibi_f32);
-    GGML_METAL_DECL_KERNEL(im2col_f16);
-    GGML_METAL_DECL_KERNEL(upscale_f32);
-    GGML_METAL_DECL_KERNEL(pad_f32);
-    GGML_METAL_DECL_KERNEL(argsort_f32_i32_asc);
-    GGML_METAL_DECL_KERNEL(argsort_f32_i32_desc);
-    GGML_METAL_DECL_KERNEL(leaky_relu_f32);
-    GGML_METAL_DECL_KERNEL(cpy_f32_f16);
-    GGML_METAL_DECL_KERNEL(cpy_f32_f32);
-    GGML_METAL_DECL_KERNEL(cpy_f32_q8_0);
-    GGML_METAL_DECL_KERNEL(cpy_f32_q4_0);
-    GGML_METAL_DECL_KERNEL(cpy_f32_q4_1);
-    //GGML_METAL_DECL_KERNEL(cpy_f32_q5_0);
-    //GGML_METAL_DECL_KERNEL(cpy_f32_q5_1);
-    GGML_METAL_DECL_KERNEL(cpy_f16_f16);
-    GGML_METAL_DECL_KERNEL(cpy_f16_f32);
-    GGML_METAL_DECL_KERNEL(concat);
-    GGML_METAL_DECL_KERNEL(sqr);
-    GGML_METAL_DECL_KERNEL(sum_rows);
-
-#undef GGML_METAL_DECL_KERNEL
+    bool support_simdgroup_reduction;
+    bool support_simdgroup_mm;
 };
 
 // MSL code
@@ -298,19 +307,22 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
                 return NULL;
             }
 
-            MTLCompileOptions* options = nil;
+            // dictionary of preprocessor macros
+            NSMutableDictionary * prep = [NSMutableDictionary dictionary];
+
 #ifdef GGML_QKK_64
-            options = [MTLCompileOptions new];
-            options.preprocessorMacros = @{ @"QK_K" : @(64) };
+            prep[@"QK_K"] = @(64);
 #endif
-            // try to disable fast-math
-            // NOTE: this seems to have no effect whatsoever
-            //       instead, in order to disable fast-math, we have to build default.metallib from the command line
-            //       using xcrun -sdk macosx metal -fno-fast-math -c ggml-metal.metal -o ggml-metal.air
-            //       and go through the "pre-compiled library found" path above
+
+            MTLCompileOptions* options = [MTLCompileOptions new];
+            options.preprocessorMacros = prep;
+
             //[options setFastMathEnabled:false];
 
             ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error];
+
+            [options release];
+            [prep release];
         }
 
         if (error) {
@@ -323,16 +335,41 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
     // print MTL GPU family:
     GGML_METAL_LOG_INFO("%s: GPU name:   %s\n", __func__, [[ctx->device name] UTF8String]);
 
+    const NSInteger MTLGPUFamilyMetal3 = 5001;
+
     // determine max supported GPU family
     // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
     // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
-    for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
-        if ([ctx->device supportsFamily:i]) {
-            GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
-            break;
+    {
+        for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
+            if ([ctx->device supportsFamily:i]) {
+                GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d  (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
+                break;
+            }
+        }
+
+        for (int i = MTLGPUFamilyCommon1 + 5; i >= MTLGPUFamilyCommon1; --i) {
+            if ([ctx->device supportsFamily:i]) {
+                GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyCommon%d (%d)\n", __func__, i - (int) MTLGPUFamilyCommon1 + 1, i);
+                break;
+            }
+        }
+
+        for (int i = MTLGPUFamilyMetal3 + 5; i >= MTLGPUFamilyMetal3; --i) {
+            if ([ctx->device supportsFamily:i]) {
+                GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyMetal%d  (%d)\n", __func__, i - (int) MTLGPUFamilyMetal3 + 3, i);
+                break;
+            }
         }
     }
 
+    ctx->support_simdgroup_reduction  = [ctx->device supportsFamily:MTLGPUFamilyApple7];
+    ctx->support_simdgroup_reduction |= [ctx->device supportsFamily:MTLGPUFamilyMetal3];
+
+    ctx->support_simdgroup_mm = [ctx->device supportsFamily:MTLGPUFamilyApple7];
+
+    GGML_METAL_LOG_INFO("%s: simdgroup reduction support   = %s\n",       __func__, ctx->support_simdgroup_reduction ? "true" : "false");
+    GGML_METAL_LOG_INFO("%s: simdgroup matrix mul. support = %s\n",       __func__, ctx->support_simdgroup_mm ? "true" : "false");
     GGML_METAL_LOG_INFO("%s: hasUnifiedMemory              = %s\n",       __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
     GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize  = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1e6);
     if (ctx->device.maxTransferRate != 0) {
@@ -346,141 +383,152 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
     {
         NSError * error = nil;
 
+        for (int i = 0; i < GGML_METAL_MAX_KERNELS; ++i) {
+            ctx->kernels[i].function = nil;
+            ctx->kernels[i].pipeline = nil;
+        }
+
         /*
-        GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) ctx->pipeline_##name, \
-                (int) ctx->pipeline_##name.maxTotalThreadsPerThreadgroup, \
-                (int) ctx->pipeline_##name.threadExecutionWidth); \
+            GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \
+                    (int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \
+                    (int) kernel->pipeline.threadExecutionWidth); \
         */
-#define GGML_METAL_ADD_KERNEL(name) \
-        ctx->function_##name = [ctx->library newFunctionWithName:@"kernel_"#name]; \
-        ctx->pipeline_##name = [ctx->device newComputePipelineStateWithFunction:ctx->function_##name error:&error]; \
-        if (error) { \
-            GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
-            return NULL; \
+#define GGML_METAL_ADD_KERNEL(e, name, supported) \
+        if (supported) { \
+            struct ggml_metal_kernel * kernel = &ctx->kernels[e]; \
+            kernel->function = [ctx->library newFunctionWithName:@"kernel_"#name]; \
+            kernel->pipeline = [ctx->device newComputePipelineStateWithFunction:kernel->function error:&error]; \
+            GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \
+                    (int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \
+                    (int) kernel->pipeline.threadExecutionWidth); \
+            if (error) { \
+                GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
+                return NULL; \
+            } \
+        } else { \
+            GGML_METAL_LOG_WARN("%s: skipping %-32s (not supported)\n", __func__, "kernel_"#name); \
         }
 
-        GGML_METAL_ADD_KERNEL(add);
-        GGML_METAL_ADD_KERNEL(add_row);
-        GGML_METAL_ADD_KERNEL(mul);
-        GGML_METAL_ADD_KERNEL(mul_row);
-        GGML_METAL_ADD_KERNEL(div);
-        GGML_METAL_ADD_KERNEL(div_row);
-        GGML_METAL_ADD_KERNEL(scale);
-        GGML_METAL_ADD_KERNEL(scale_4);
-        GGML_METAL_ADD_KERNEL(tanh);
-        GGML_METAL_ADD_KERNEL(relu);
-        GGML_METAL_ADD_KERNEL(gelu);
-        GGML_METAL_ADD_KERNEL(gelu_quick);
-        GGML_METAL_ADD_KERNEL(silu);
-        GGML_METAL_ADD_KERNEL(soft_max);
-        GGML_METAL_ADD_KERNEL(soft_max_4);
-        GGML_METAL_ADD_KERNEL(diag_mask_inf);
-        GGML_METAL_ADD_KERNEL(diag_mask_inf_8);
-        GGML_METAL_ADD_KERNEL(get_rows_f32);
-        GGML_METAL_ADD_KERNEL(get_rows_f16);
-        GGML_METAL_ADD_KERNEL(get_rows_q4_0);
-        GGML_METAL_ADD_KERNEL(get_rows_q4_1);
-        GGML_METAL_ADD_KERNEL(get_rows_q5_0);
-        GGML_METAL_ADD_KERNEL(get_rows_q5_1);
-        GGML_METAL_ADD_KERNEL(get_rows_q8_0);
-        GGML_METAL_ADD_KERNEL(get_rows_q2_K);
-        GGML_METAL_ADD_KERNEL(get_rows_q3_K);
-        GGML_METAL_ADD_KERNEL(get_rows_q4_K);
-        GGML_METAL_ADD_KERNEL(get_rows_q5_K);
-        GGML_METAL_ADD_KERNEL(get_rows_q6_K);
-        GGML_METAL_ADD_KERNEL(get_rows_i32);
-        GGML_METAL_ADD_KERNEL(get_rows_iq2_xxs);
-        GGML_METAL_ADD_KERNEL(get_rows_iq2_xs);
-        GGML_METAL_ADD_KERNEL(rms_norm);
-        GGML_METAL_ADD_KERNEL(group_norm);
-        GGML_METAL_ADD_KERNEL(norm);
-        GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
-        GGML_METAL_ADD_KERNEL(mul_mv_f16_f16);
-        GGML_METAL_ADD_KERNEL(mul_mv_f16_f32);
-        GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_1row);
-        GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4);
-        GGML_METAL_ADD_KERNEL(mul_mv_q4_0_f32);
-        GGML_METAL_ADD_KERNEL(mul_mv_q4_1_f32);
-        GGML_METAL_ADD_KERNEL(mul_mv_q5_0_f32);
-        GGML_METAL_ADD_KERNEL(mul_mv_q5_1_f32);
-        GGML_METAL_ADD_KERNEL(mul_mv_q8_0_f32);
-        GGML_METAL_ADD_KERNEL(mul_mv_q2_K_f32);
-        GGML_METAL_ADD_KERNEL(mul_mv_q3_K_f32);
-        GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32);
-        GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
-        GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
-        GGML_METAL_ADD_KERNEL(mul_mv_iq2_xxs_f32);
-        GGML_METAL_ADD_KERNEL(mul_mv_iq2_xs_f32);
-        GGML_METAL_ADD_KERNEL(mul_mv_id_f32_f32);
-        //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f16);
-        GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32);
-        //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_1row);
-        //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_l4);
-        GGML_METAL_ADD_KERNEL(mul_mv_id_q4_0_f32);
-        GGML_METAL_ADD_KERNEL(mul_mv_id_q4_1_f32);
-        GGML_METAL_ADD_KERNEL(mul_mv_id_q5_0_f32);
-        GGML_METAL_ADD_KERNEL(mul_mv_id_q5_1_f32);
-        GGML_METAL_ADD_KERNEL(mul_mv_id_q8_0_f32);
-        GGML_METAL_ADD_KERNEL(mul_mv_id_q2_K_f32);
-        GGML_METAL_ADD_KERNEL(mul_mv_id_q3_K_f32);
-        GGML_METAL_ADD_KERNEL(mul_mv_id_q4_K_f32);
-        GGML_METAL_ADD_KERNEL(mul_mv_id_q5_K_f32);
-        GGML_METAL_ADD_KERNEL(mul_mv_id_q6_K_f32);
-        GGML_METAL_ADD_KERNEL(mul_mv_id_iq2_xxs_f32);
-        GGML_METAL_ADD_KERNEL(mul_mv_id_iq2_xs_f32);
-        if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
-            GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
-            GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
-            GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
-            GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
-            GGML_METAL_ADD_KERNEL(mul_mm_q5_0_f32);
-            GGML_METAL_ADD_KERNEL(mul_mm_q5_1_f32);
-            GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
-            GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
-            GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
-            GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
-            GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
-            GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
-            GGML_METAL_ADD_KERNEL(mul_mm_iq2_xxs_f32);
-            GGML_METAL_ADD_KERNEL(mul_mm_iq2_xs_f32);
-            GGML_METAL_ADD_KERNEL(mul_mm_id_f32_f32);
-            GGML_METAL_ADD_KERNEL(mul_mm_id_f16_f32);
-            GGML_METAL_ADD_KERNEL(mul_mm_id_q4_0_f32);
-            GGML_METAL_ADD_KERNEL(mul_mm_id_q4_1_f32);
-            GGML_METAL_ADD_KERNEL(mul_mm_id_q5_0_f32);
-            GGML_METAL_ADD_KERNEL(mul_mm_id_q5_1_f32);
-            GGML_METAL_ADD_KERNEL(mul_mm_id_q8_0_f32);
-            GGML_METAL_ADD_KERNEL(mul_mm_id_q2_K_f32);
-            GGML_METAL_ADD_KERNEL(mul_mm_id_q3_K_f32);
-            GGML_METAL_ADD_KERNEL(mul_mm_id_q4_K_f32);
-            GGML_METAL_ADD_KERNEL(mul_mm_id_q5_K_f32);
-            GGML_METAL_ADD_KERNEL(mul_mm_id_q6_K_f32);
-            GGML_METAL_ADD_KERNEL(mul_mm_id_iq2_xxs_f32);
-            GGML_METAL_ADD_KERNEL(mul_mm_id_iq2_xs_f32);
-        }
-        GGML_METAL_ADD_KERNEL(rope_f32);
-        GGML_METAL_ADD_KERNEL(rope_f16);
-        GGML_METAL_ADD_KERNEL(alibi_f32);
-        GGML_METAL_ADD_KERNEL(im2col_f16);
-        GGML_METAL_ADD_KERNEL(upscale_f32);
-        GGML_METAL_ADD_KERNEL(pad_f32);
-        GGML_METAL_ADD_KERNEL(argsort_f32_i32_asc);
-        GGML_METAL_ADD_KERNEL(argsort_f32_i32_desc);
-        GGML_METAL_ADD_KERNEL(leaky_relu_f32);
-        GGML_METAL_ADD_KERNEL(cpy_f32_f16);
-        GGML_METAL_ADD_KERNEL(cpy_f32_f32);
-        GGML_METAL_ADD_KERNEL(cpy_f32_q8_0);
-        GGML_METAL_ADD_KERNEL(cpy_f32_q4_0);
-        GGML_METAL_ADD_KERNEL(cpy_f32_q4_1);
-        //GGML_METAL_ADD_KERNEL(cpy_f32_q5_0);
-        //GGML_METAL_ADD_KERNEL(cpy_f32_q5_1);
-        GGML_METAL_ADD_KERNEL(cpy_f16_f16);
-        GGML_METAL_ADD_KERNEL(cpy_f16_f32);
-        GGML_METAL_ADD_KERNEL(concat);
-        GGML_METAL_ADD_KERNEL(sqr);
-        GGML_METAL_ADD_KERNEL(sum_rows);
-
-#undef GGML_METAL_ADD_KERNEL
+        // simd_sum and simd_max requires MTLGPUFamilyApple7
+
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD,                       add,                    true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW,                   add_row,                true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL,                       mul,                    true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW,                   mul_row,                true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV,                       div,                    true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW,                   div_row,                true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE,                     scale,                  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4,                   scale_4,                true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH,                      tanh,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU,                      relu,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU,                      gelu,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK,                gelu_quick,             true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU,                      silu,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX,                  soft_max,               ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_4,                soft_max_4,             ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,             diag_mask_inf,          true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,           diag_mask_inf_8,        true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,              get_rows_f32,           true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16,              get_rows_f16,           true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0,             get_rows_q4_0,          true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1,             get_rows_q4_1,          true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0,             get_rows_q5_0,          true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1,             get_rows_q5_1,          true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0,             get_rows_q8_0,          true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K,             get_rows_q2_K,          true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K,             get_rows_q3_K,          true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K,             get_rows_q4_K,          true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K,             get_rows_q5_K,          true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K,             get_rows_q6_K,          true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS,          get_rows_iq2_xxs,       true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,           get_rows_iq2_xs,        true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,              get_rows_i32,           true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM,                  rms_norm,               ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM,                group_norm,             ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM,                      norm,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,            mul_mv_f32_f32,         ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,            mul_mv_f16_f16,         ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,            mul_mv_f16_f32,         ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,       mul_mv_f16_f32_1row,    ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,         mul_mv_f16_f32_l4,      ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32,           mul_mv_q4_0_f32,        ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32,           mul_mv_q4_1_f32,        ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,           mul_mv_q5_0_f32,        ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,           mul_mv_q5_1_f32,        ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,           mul_mv_q8_0_f32,        ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32,           mul_mv_q2_K_f32,        ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32,           mul_mv_q3_K_f32,        ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32,           mul_mv_q4_K_f32,        ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32,           mul_mv_q5_K_f32,        ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32,           mul_mv_q6_K_f32,        ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,        mul_mv_iq2_xxs_f32,     ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,         mul_mv_iq2_xs_f32,      ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,         mul_mv_id_f32_f32,      ctx->support_simdgroup_reduction);
+      //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,         mul_mv_id_f16_f16,      ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,         mul_mv_id_f16_f32,      ctx->support_simdgroup_reduction);
+      //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW,    mul_mv_id_f16_f32_1row, ctx->support_simdgroup_reduction);
+      //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4,      mul_mv_id_f16_f32_l4,   ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32,        mul_mv_id_q4_0_f32,     ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32,        mul_mv_id_q4_1_f32,     ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32,        mul_mv_id_q5_0_f32,     ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32,        mul_mv_id_q5_1_f32,     ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32,        mul_mv_id_q8_0_f32,     ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32,        mul_mv_id_q2_K_f32,     ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32,        mul_mv_id_q3_K_f32,     ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32,        mul_mv_id_q4_K_f32,     ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32,        mul_mv_id_q5_K_f32,     ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32,        mul_mv_id_q6_K_f32,     ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,     mul_mv_id_iq2_xxs_f32,  ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,      mul_mv_id_iq2_xs_f32,   ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,            mul_mm_f32_f32,         ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,            mul_mm_f16_f32,         ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,           mul_mm_q4_0_f32,        ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32,           mul_mm_q4_1_f32,        ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,           mul_mm_q5_0_f32,        ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32,           mul_mm_q5_1_f32,        ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32,           mul_mm_q8_0_f32,        ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32,           mul_mm_q2_K_f32,        ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32,           mul_mm_q3_K_f32,        ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32,           mul_mm_q4_K_f32,        ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32,           mul_mm_q5_K_f32,        ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32,           mul_mm_q6_K_f32,        ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,        mul_mm_iq2_xxs_f32,     ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,         mul_mm_iq2_xs_f32,      ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,         mul_mm_id_f32_f32,      ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,         mul_mm_id_f16_f32,      ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,        mul_mm_id_q4_0_f32,     ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32,        mul_mm_id_q4_1_f32,     ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32,        mul_mm_id_q5_0_f32,     ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32,        mul_mm_id_q5_1_f32,     ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32,        mul_mm_id_q8_0_f32,     ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32,        mul_mm_id_q2_K_f32,     ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32,        mul_mm_id_q3_K_f32,     ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32,        mul_mm_id_q4_K_f32,     ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32,        mul_mm_id_q5_K_f32,     ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32,        mul_mm_id_q6_K_f32,     ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,     mul_mm_id_iq2_xxs_f32,  ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,      mul_mm_id_iq2_xs_f32,   ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32,                  rope_f32,               true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16,                  rope_f16,               true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32,                 alibi_f32,              true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16,                im2col_f16,             true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32,               upscale_f32,            true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32,                   pad_f32,                true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,       argsort_f32_i32_asc,    true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,      argsort_f32_i32_desc,   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,            leaky_relu_f32,         true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16,               cpy_f32_f16,            true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32,               cpy_f32_f32,            true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,              cpy_f32_q8_0,           true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,              cpy_f32_q4_0,           true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,              cpy_f32_q4_1,           true);
+      //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,              cpy_f32_q5_0,           true);
+      //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,              cpy_f32_q5_1,           true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16,               cpy_f16_f16,            true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32,               cpy_f16_f32,            true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT,                    concat,                 true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR,                       sqr,                    true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS,                  sum_rows,               true);
     }
 
     return ctx;
@@ -488,137 +536,21 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
 
 void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_LOG_INFO("%s: deallocating\n", __func__);
-#define GGML_METAL_DEL_KERNEL(name) \
-    [ctx->function_##name release]; \
-    [ctx->pipeline_##name release];
-
-    GGML_METAL_DEL_KERNEL(add);
-    GGML_METAL_DEL_KERNEL(add_row);
-    GGML_METAL_DEL_KERNEL(mul);
-    GGML_METAL_DEL_KERNEL(mul_row);
-    GGML_METAL_DEL_KERNEL(div);
-    GGML_METAL_DEL_KERNEL(div_row);
-    GGML_METAL_DEL_KERNEL(scale);
-    GGML_METAL_DEL_KERNEL(scale_4);
-    GGML_METAL_DEL_KERNEL(tanh);
-    GGML_METAL_DEL_KERNEL(relu);
-    GGML_METAL_DEL_KERNEL(gelu);
-    GGML_METAL_DEL_KERNEL(gelu_quick);
-    GGML_METAL_DEL_KERNEL(silu);
-    GGML_METAL_DEL_KERNEL(soft_max);
-    GGML_METAL_DEL_KERNEL(soft_max_4);
-    GGML_METAL_DEL_KERNEL(diag_mask_inf);
-    GGML_METAL_DEL_KERNEL(diag_mask_inf_8);
-    GGML_METAL_DEL_KERNEL(get_rows_f32);
-    GGML_METAL_DEL_KERNEL(get_rows_f16);
-    GGML_METAL_DEL_KERNEL(get_rows_q4_0);
-    GGML_METAL_DEL_KERNEL(get_rows_q4_1);
-    GGML_METAL_DEL_KERNEL(get_rows_q5_0);
-    GGML_METAL_DEL_KERNEL(get_rows_q5_1);
-    GGML_METAL_DEL_KERNEL(get_rows_q8_0);
-    GGML_METAL_DEL_KERNEL(get_rows_q2_K);
-    GGML_METAL_DEL_KERNEL(get_rows_q3_K);
-    GGML_METAL_DEL_KERNEL(get_rows_q4_K);
-    GGML_METAL_DEL_KERNEL(get_rows_q5_K);
-    GGML_METAL_DEL_KERNEL(get_rows_q6_K);
-    GGML_METAL_DEL_KERNEL(get_rows_i32);
-    GGML_METAL_DEL_KERNEL(get_rows_iq2_xxs);
-    GGML_METAL_DEL_KERNEL(get_rows_iq2_xs);
-    GGML_METAL_DEL_KERNEL(rms_norm);
-    GGML_METAL_DEL_KERNEL(group_norm);
-    GGML_METAL_DEL_KERNEL(norm);
-    GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
-    GGML_METAL_DEL_KERNEL(mul_mv_f16_f16);
-    GGML_METAL_DEL_KERNEL(mul_mv_f16_f32);
-    GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_1row);
-    GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4);
-    GGML_METAL_DEL_KERNEL(mul_mv_q4_0_f32);
-    GGML_METAL_DEL_KERNEL(mul_mv_q4_1_f32);
-    GGML_METAL_DEL_KERNEL(mul_mv_q5_0_f32);
-    GGML_METAL_DEL_KERNEL(mul_mv_q5_1_f32);
-    GGML_METAL_DEL_KERNEL(mul_mv_q8_0_f32);
-    GGML_METAL_DEL_KERNEL(mul_mv_q2_K_f32);
-    GGML_METAL_DEL_KERNEL(mul_mv_q3_K_f32);
-    GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32);
-    GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
-    GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
-    GGML_METAL_DEL_KERNEL(mul_mv_iq2_xxs_f32);
-    GGML_METAL_DEL_KERNEL(mul_mv_iq2_xs_f32);
-    GGML_METAL_DEL_KERNEL(mul_mv_id_f32_f32);
-    //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f16);
-    GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32);
-    //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_1row);
-    //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_l4);
-    GGML_METAL_DEL_KERNEL(mul_mv_id_q4_0_f32);
-    GGML_METAL_DEL_KERNEL(mul_mv_id_q4_1_f32);
-    GGML_METAL_DEL_KERNEL(mul_mv_id_q5_0_f32);
-    GGML_METAL_DEL_KERNEL(mul_mv_id_q5_1_f32);
-    GGML_METAL_DEL_KERNEL(mul_mv_id_q8_0_f32);
-    GGML_METAL_DEL_KERNEL(mul_mv_id_q2_K_f32);
-    GGML_METAL_DEL_KERNEL(mul_mv_id_q3_K_f32);
-    GGML_METAL_DEL_KERNEL(mul_mv_id_q4_K_f32);
-    GGML_METAL_DEL_KERNEL(mul_mv_id_q5_K_f32);
-    GGML_METAL_DEL_KERNEL(mul_mv_id_q6_K_f32);
-    GGML_METAL_DEL_KERNEL(mul_mv_id_iq2_xxs_f32);
-    GGML_METAL_DEL_KERNEL(mul_mv_id_iq2_xs_f32);
-    if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
-        GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
-        GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
-        GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
-        GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
-        GGML_METAL_DEL_KERNEL(mul_mm_q5_0_f32);
-        GGML_METAL_DEL_KERNEL(mul_mm_q5_1_f32);
-        GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
-        GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
-        GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
-        GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
-        GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
-        GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
-        GGML_METAL_DEL_KERNEL(mul_mm_iq2_xxs_f32);
-        GGML_METAL_DEL_KERNEL(mul_mm_iq2_xs_f32);
-        GGML_METAL_DEL_KERNEL(mul_mm_id_f32_f32);
-        GGML_METAL_DEL_KERNEL(mul_mm_id_f16_f32);
-        GGML_METAL_DEL_KERNEL(mul_mm_id_q4_0_f32);
-        GGML_METAL_DEL_KERNEL(mul_mm_id_q4_1_f32);
-        GGML_METAL_DEL_KERNEL(mul_mm_id_q5_0_f32);
-        GGML_METAL_DEL_KERNEL(mul_mm_id_q5_1_f32);
-        GGML_METAL_DEL_KERNEL(mul_mm_id_q8_0_f32);
-        GGML_METAL_DEL_KERNEL(mul_mm_id_q2_K_f32);
-        GGML_METAL_DEL_KERNEL(mul_mm_id_q3_K_f32);
-        GGML_METAL_DEL_KERNEL(mul_mm_id_q4_K_f32);
-        GGML_METAL_DEL_KERNEL(mul_mm_id_q5_K_f32);
-        GGML_METAL_DEL_KERNEL(mul_mm_id_q6_K_f32);
-        GGML_METAL_DEL_KERNEL(mul_mm_id_iq2_xxs_f32);
-        GGML_METAL_DEL_KERNEL(mul_mm_id_iq2_xs_f32);
-    }
-    GGML_METAL_DEL_KERNEL(rope_f32);
-    GGML_METAL_DEL_KERNEL(rope_f16);
-    GGML_METAL_DEL_KERNEL(alibi_f32);
-    GGML_METAL_DEL_KERNEL(im2col_f16);
-    GGML_METAL_DEL_KERNEL(upscale_f32);
-    GGML_METAL_DEL_KERNEL(pad_f32);
-    GGML_METAL_DEL_KERNEL(argsort_f32_i32_asc);
-    GGML_METAL_DEL_KERNEL(argsort_f32_i32_desc);
-    GGML_METAL_DEL_KERNEL(leaky_relu_f32);
-    GGML_METAL_DEL_KERNEL(cpy_f32_f16);
-    GGML_METAL_DEL_KERNEL(cpy_f32_f32);
-    GGML_METAL_DEL_KERNEL(cpy_f32_q8_0);
-    GGML_METAL_DEL_KERNEL(cpy_f32_q4_0);
-    GGML_METAL_DEL_KERNEL(cpy_f32_q4_1);
-    //GGML_METAL_DEL_KERNEL(cpy_f32_q5_0);
-    //GGML_METAL_DEL_KERNEL(cpy_f32_q5_1);
-    GGML_METAL_DEL_KERNEL(cpy_f16_f16);
-    GGML_METAL_DEL_KERNEL(cpy_f16_f32);
-    GGML_METAL_DEL_KERNEL(concat);
-    GGML_METAL_DEL_KERNEL(sqr);
-    GGML_METAL_DEL_KERNEL(sum_rows);
-
-#undef GGML_METAL_DEL_KERNEL
 
     for (int i = 0; i < ctx->n_buffers; ++i) {
         [ctx->buffers[i].metal release];
     }
 
+    for (int i = 0; i < GGML_METAL_MAX_KERNELS; ++i) {
+        if (ctx->kernels[i].pipeline) {
+            [ctx->kernels[i].pipeline release];
+        }
+
+        if (ctx->kernels[i].function) {
+            [ctx->kernels[i].function release];
+        }
+    }
+
     [ctx->library release];
     [ctx->queue release];
     [ctx->device release];
@@ -930,7 +862,7 @@ void ggml_metal_graph_find_concurrency(
     }
 }
 
-static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
+static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const struct ggml_tensor * op) {
     switch (op->op) {
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(op)) {
@@ -956,9 +888,11 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
         case GGML_OP_SCALE:
         case GGML_OP_SQR:
         case GGML_OP_SUM_ROWS:
+            return true;
         case GGML_OP_SOFT_MAX:
         case GGML_OP_RMS_NORM:
         case GGML_OP_GROUP_NORM:
+            return ctx->support_simdgroup_reduction;
         case GGML_OP_NORM:
         case GGML_OP_ALIBI:
         case GGML_OP_ROPE:
@@ -967,9 +901,10 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
         case GGML_OP_PAD:
         case GGML_OP_ARGSORT:
         case GGML_OP_LEAKY_RELU:
+            return true;
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
-            return true;
+            return ctx->support_simdgroup_reduction;
         case GGML_OP_CPY:
         case GGML_OP_DUP:
         case GGML_OP_CONT:
@@ -1007,6 +942,7 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
             return false;
     }
 }
+
 bool ggml_metal_graph_compute(
         struct ggml_metal_context * ctx,
                struct ggml_cgraph * gf) {
@@ -1077,7 +1013,7 @@ bool ggml_metal_graph_compute(
                         } break;
                 }
 
-                if (!ggml_metal_supports_op(dst)) {
+                if (!ggml_metal_supports_op(ctx, dst)) {
                     GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
                     GGML_ASSERT(!"unsupported op");
                 }
@@ -1143,7 +1079,9 @@ bool ggml_metal_graph_compute(
                         {
                             const int64_t nb = ne00;
 
-                            [encoder setComputePipelineState:ctx->pipeline_concat];
+                            id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline;
+
+                            [encoder setComputePipelineState:pipeline];
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                             [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
@@ -1197,18 +1135,18 @@ bool ggml_metal_graph_compute(
 
                                 nb = ne00 / 4;
                                 switch (dst->op) {
-                                    case GGML_OP_ADD: pipeline = ctx->pipeline_add_row; break;
-                                    case GGML_OP_MUL: pipeline = ctx->pipeline_mul_row; break;
-                                    case GGML_OP_DIV: pipeline = ctx->pipeline_div_row; break;
+                                    case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break;
+                                    case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW].pipeline; break;
+                                    case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW].pipeline; break;
                                     default: GGML_ASSERT(false);
                                 }
 
                                 bcast_row = true;
                             } else {
                                 switch (dst->op) {
-                                    case GGML_OP_ADD: pipeline = ctx->pipeline_add; break;
-                                    case GGML_OP_MUL: pipeline = ctx->pipeline_mul; break;
-                                    case GGML_OP_DIV: pipeline = ctx->pipeline_div; break;
+                                    case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; break;
+                                    case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
+                                    case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
                                     default: GGML_ASSERT(false);
                                 }
                             }
@@ -1275,9 +1213,9 @@ bool ggml_metal_graph_compute(
                                 // not sure how to avoid this
                                 // TODO: make a simpler cpy_bytes kernel
 
-                                const int nth = MIN((int) ctx->pipeline_cpy_f32_f32.maxTotalThreadsPerThreadgroup, ne00);
+                                const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline;
 
-                                [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32];
+                                [encoder setComputePipelineState:pipeline];
                                 [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
                                 [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
                                 [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
@@ -1297,10 +1235,14 @@ bool ggml_metal_graph_compute(
                                 [encoder setBytes:&nb2     length:sizeof(uint64_t) atIndex:16];
                                 [encoder setBytes:&nb3     length:sizeof(uint64_t) atIndex:17];
 
+                                const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
+
                                 [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                             }
 
-                            [encoder setComputePipelineState:ctx->pipeline_add];
+                            const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline;
+
+                            [encoder setComputePipelineState:pipeline];
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                             [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
@@ -1330,7 +1272,7 @@ bool ggml_metal_graph_compute(
                             [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26];
                             [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
 
-                            const int nth = MIN((int) ctx->pipeline_add.maxTotalThreadsPerThreadgroup, ne00);
+                            const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
 
                             [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                         } break;
@@ -1342,13 +1284,16 @@ bool ggml_metal_graph_compute(
 
                             int64_t n = ggml_nelements(dst);
 
+                            id<MTLComputePipelineState> pipeline = nil;
+
                             if (n % 4 == 0) {
                                 n /= 4;
-                                [encoder setComputePipelineState:ctx->pipeline_scale_4];
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE_4].pipeline;
                             } else {
-                                [encoder setComputePipelineState:ctx->pipeline_scale];
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SCALE].pipeline;
                             }
 
+                            [encoder setComputePipelineState:pipeline];
                             [encoder setBuffer:id_src0   offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_dst    offset:offs_dst  atIndex:1];
                             [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
@@ -1359,7 +1304,9 @@ bool ggml_metal_graph_compute(
                         switch (ggml_get_unary_op(gf->nodes[i])) {
                             case GGML_UNARY_OP_TANH:
                                 {
-                                    [encoder setComputePipelineState:ctx->pipeline_tanh];
+                                    id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TANH].pipeline;
+
+                                    [encoder setComputePipelineState:pipeline];
                                     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                                     [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
 
@@ -1369,7 +1316,9 @@ bool ggml_metal_graph_compute(
                                 } break;
                             case GGML_UNARY_OP_RELU:
                                 {
-                                    [encoder setComputePipelineState:ctx->pipeline_relu];
+                                    id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RELU].pipeline;
+
+                                    [encoder setComputePipelineState:pipeline];
                                     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                                     [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
 
@@ -1379,7 +1328,9 @@ bool ggml_metal_graph_compute(
                                 } break;
                             case GGML_UNARY_OP_GELU:
                                 {
-                                    [encoder setComputePipelineState:ctx->pipeline_gelu];
+                                    id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU].pipeline;
+
+                                    [encoder setComputePipelineState:pipeline];
                                     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                                     [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
 
@@ -1390,7 +1341,9 @@ bool ggml_metal_graph_compute(
                                 } break;
                             case GGML_UNARY_OP_GELU_QUICK:
                                 {
-                                    [encoder setComputePipelineState:ctx->pipeline_gelu_quick];
+                                    id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GELU_QUICK].pipeline;
+
+                                    [encoder setComputePipelineState:pipeline];
                                     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                                     [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
 
@@ -1401,7 +1354,9 @@ bool ggml_metal_graph_compute(
                                 } break;
                             case GGML_UNARY_OP_SILU:
                                 {
-                                    [encoder setComputePipelineState:ctx->pipeline_silu];
+                                    id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SILU].pipeline;
+
+                                    [encoder setComputePipelineState:pipeline];
                                     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                                     [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
 
@@ -1420,18 +1375,23 @@ bool ggml_metal_graph_compute(
                         {
                             GGML_ASSERT(ggml_is_contiguous(src0));
 
-                            [encoder setComputePipelineState:ctx->pipeline_sqr];
+                            id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SQR].pipeline;
+
+                            [encoder setComputePipelineState:pipeline];
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_dst  offset:offs_dst atIndex:1];
 
                             const int64_t n = ggml_nelements(dst);
+
                             [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                         } break;
                     case GGML_OP_SUM_ROWS:
                         {
                             GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
 
-                            [encoder setComputePipelineState:ctx->pipeline_sum_rows];
+                            id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline;
+
+                            [encoder setComputePipelineState:pipeline];
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
                             [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
@@ -1465,20 +1425,23 @@ bool ggml_metal_graph_compute(
                         {
                             int nth = 32; // SIMD width
 
+                            id<MTLComputePipelineState> pipeline = nil;
+
                             if (ne00%4 == 0) {
                                 while (nth < ne00/4 && nth < 256) {
                                     nth *= 2;
                                 }
-                                [encoder setComputePipelineState:ctx->pipeline_soft_max_4];
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX_4].pipeline;
                             } else {
                                 while (nth < ne00 && nth < 1024) {
                                     nth *= 2;
                                 }
-                                [encoder setComputePipelineState:ctx->pipeline_soft_max];
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline;
                             }
 
                             const float scale = ((float *) dst->op_params)[0];
 
+                            [encoder setComputePipelineState:pipeline];
                             [encoder setBuffer:id_src0 offset:offs_src0   atIndex:0];
                             if (id_src1) {
                                 [encoder setBuffer:id_src1 offset:offs_src1   atIndex:1];
@@ -1498,11 +1461,15 @@ bool ggml_metal_graph_compute(
                         {
                             const int n_past = ((int32_t *)(dst->op_params))[0];
 
+                            id<MTLComputePipelineState> pipeline = nil;
+
                             if (ne00%8 == 0) {
-                                [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf_8];
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8].pipeline;
                             } else {
-                                [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline;
                             }
+
+                            [encoder setComputePipelineState:pipeline];
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
                             [encoder setBytes:&ne00   length:sizeof(ne00) atIndex:2];
@@ -1562,23 +1529,28 @@ bool ggml_metal_graph_compute(
                                 ne00 % 32 == 0 && ne00 >= 64 &&
                                 (ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) {
                                 //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
+
+                                id<MTLComputePipelineState> pipeline = nil;
+
                                 switch (src0->type) {
-                                    case GGML_TYPE_F32:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32];  break;
-                                    case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32];  break;
-                                    case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
-                                    case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
-                                    case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_0_f32]; break;
-                                    case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_1_f32]; break;
-                                    case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q8_0_f32]; break;
-                                    case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
-                                    case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
-                                    case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_K_f32]; break;
-                                    case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break;
-                                    case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break;
-                                    case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState:ctx->pipeline_mul_mm_iq2_xxs_f32]; break;
-                                    case GGML_TYPE_IQ2_XS : [encoder setComputePipelineState:ctx->pipeline_mul_mm_iq2_xs_f32]; break;
+                                    case GGML_TYPE_F32:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32    ].pipeline; break;
+                                    case GGML_TYPE_F16:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32    ].pipeline; break;
+                                    case GGML_TYPE_Q4_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32   ].pipeline; break;
+                                    case GGML_TYPE_Q4_1:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32   ].pipeline; break;
+                                    case GGML_TYPE_Q5_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32   ].pipeline; break;
+                                    case GGML_TYPE_Q5_1:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32   ].pipeline; break;
+                                    case GGML_TYPE_Q8_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32   ].pipeline; break;
+                                    case GGML_TYPE_Q2_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32   ].pipeline; break;
+                                    case GGML_TYPE_Q3_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32   ].pipeline; break;
+                                    case GGML_TYPE_Q4_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32   ].pipeline; break;
+                                    case GGML_TYPE_Q5_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32   ].pipeline; break;
+                                    case GGML_TYPE_Q6_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32   ].pipeline; break;
+                                    case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break;
+                                    case GGML_TYPE_IQ2_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break;
                                     default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
                                 }
+
+                                [encoder setComputePipelineState:pipeline];
                                 [encoder setBuffer:id_src0 offset:offs_src0    atIndex:0];
                                 [encoder setBuffer:id_src1 offset:offs_src1    atIndex:1];
                                 [encoder setBuffer:id_dst  offset:offs_dst     atIndex:2];
@@ -1602,12 +1574,14 @@ bool ggml_metal_graph_compute(
                                 int nrows = 1;
                                 //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
 
+                                id<MTLComputePipelineState> pipeline = nil;
+
                                 // use custom matrix x vector kernel
                                 switch (src0t) {
                                     case GGML_TYPE_F32:
                                         {
                                             GGML_ASSERT(src1t == GGML_TYPE_F32);
-                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_f32_f32];
+                                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline;
                                             nrows = 4;
                                         } break;
                                     case GGML_TYPE_F16:
@@ -1616,16 +1590,16 @@ bool ggml_metal_graph_compute(
                                             nth1 = 1;
                                             if (src1t == GGML_TYPE_F32) {
                                                 if (ne11 * ne12 < 4) {
-                                                    [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
+                                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline;
                                                 } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
-                                                    [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
+                                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline;
                                                     nrows = ne11;
                                                 } else {
-                                                    [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
+                                                    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline;
                                                     nrows = 4;
                                                 }
                                             } else {
-                                                [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f16];
+                                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline;
                                                 nrows = 4;
                                             }
                                         } break;
@@ -1633,73 +1607,73 @@ bool ggml_metal_graph_compute(
                                         {
                                             nth0 = 8;
                                             nth1 = 8;
-                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_0_f32];
+                                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline;
                                         } break;
                                     case GGML_TYPE_Q4_1:
                                         {
                                             nth0 = 8;
                                             nth1 = 8;
-                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32];
+                                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline;
                                         } break;
                                     case GGML_TYPE_Q5_0:
                                         {
                                             nth0 = 8;
                                             nth1 = 8;
-                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_0_f32];
+                                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline;
                                         } break;
                                     case GGML_TYPE_Q5_1:
                                         {
                                             nth0 = 8;
                                             nth1 = 8;
-                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_1_f32];
+                                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline;
                                         } break;
                                     case GGML_TYPE_Q8_0:
                                         {
                                             nth0 = 8;
                                             nth1 = 8;
-                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_q8_0_f32];
+                                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline;
                                         } break;
                                     case GGML_TYPE_Q2_K:
                                         {
                                             nth0 = 2;
                                             nth1 = 32;
-                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_q2_K_f32];
+                                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline;
                                         } break;
                                     case GGML_TYPE_Q3_K:
                                         {
                                             nth0 = 2;
                                             nth1 = 32;
-                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_q3_K_f32];
+                                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline;
                                         } break;
                                     case GGML_TYPE_Q4_K:
                                         {
                                             nth0 = 4; //1;
                                             nth1 = 8; //32;
-                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_K_f32];
+                                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline;
                                         } break;
                                     case GGML_TYPE_Q5_K:
                                         {
                                             nth0 = 2;
                                             nth1 = 32;
-                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_K_f32];
+                                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline;
                                         } break;
                                     case GGML_TYPE_Q6_K:
                                         {
                                             nth0 = 2;
                                             nth1 = 32;
-                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_q6_K_f32];
+                                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline;
                                         } break;
                                     case GGML_TYPE_IQ2_XXS:
                                         {
                                             nth0 = 4;
                                             nth1 = 16;
-                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_iq2_xxs_f32];
+                                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32].pipeline;
                                         } break;
                                     case GGML_TYPE_IQ2_XS:
                                         {
                                             nth0 = 4;
                                             nth1 = 16;
-                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_iq2_xs_f32];
+                                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline;
                                         } break;
                                     default:
                                         {
@@ -1712,6 +1686,7 @@ bool ggml_metal_graph_compute(
                                     GGML_ASSERT(ne00 >= nth0*nth1);
                                 }
 
+                                [encoder setComputePipelineState:pipeline];
                                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                                 [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
@@ -1818,23 +1793,28 @@ bool ggml_metal_graph_compute(
                             if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
                                 ne20 % 32 == 0 && ne20 >= 64 &&
                                 ne11 > ne11_mm_min) {
+
+                                id<MTLComputePipelineState> pipeline = nil;
+
                                 switch (src2->type) {
-                                    case GGML_TYPE_F32:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32];  break;
-                                    case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32];  break;
-                                    case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_0_f32]; break;
-                                    case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_1_f32]; break;
-                                    case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_0_f32]; break;
-                                    case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_1_f32]; break;
-                                    case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q8_0_f32]; break;
-                                    case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q2_K_f32]; break;
-                                    case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q3_K_f32]; break;
-                                    case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_K_f32]; break;
-                                    case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_K_f32]; break;
-                                    case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q6_K_f32]; break;
-                                    case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_iq2_xxs_f32]; break;
-                                    case GGML_TYPE_IQ2_XS : [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_iq2_xs_f32]; break;
+                                    case GGML_TYPE_F32:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32    ].pipeline; break;
+                                    case GGML_TYPE_F16:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32    ].pipeline; break;
+                                    case GGML_TYPE_Q4_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32   ].pipeline; break;
+                                    case GGML_TYPE_Q4_1:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32   ].pipeline; break;
+                                    case GGML_TYPE_Q5_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32   ].pipeline; break;
+                                    case GGML_TYPE_Q5_1:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32   ].pipeline; break;
+                                    case GGML_TYPE_Q8_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32   ].pipeline; break;
+                                    case GGML_TYPE_Q2_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32   ].pipeline; break;
+                                    case GGML_TYPE_Q3_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32   ].pipeline; break;
+                                    case GGML_TYPE_Q4_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32   ].pipeline; break;
+                                    case GGML_TYPE_Q5_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32   ].pipeline; break;
+                                    case GGML_TYPE_Q6_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32   ].pipeline; break;
+                                    case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32].pipeline; break;
+                                    case GGML_TYPE_IQ2_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline; break;
                                     default: GGML_ASSERT(false && "MUL_MAT_ID not implemented");
                                 }
+
+                                [encoder setComputePipelineState:pipeline];
                                 [encoder setBuffer:id_src0 offset:offs_src0    atIndex:0];
                                 [encoder setBuffer:id_src1 offset:offs_src1    atIndex:1];
                                 [encoder setBuffer:id_dst  offset:offs_dst     atIndex:2];
@@ -1874,91 +1854,93 @@ bool ggml_metal_graph_compute(
                                 int nrows = 1;
                                 //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
 
+                                id<MTLComputePipelineState> pipeline = nil;
+
                                 // use custom matrix x vector kernel
                                 switch (src2t) {
                                     case GGML_TYPE_F32:
                                         {
                                             GGML_ASSERT(src1t == GGML_TYPE_F32);
-                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f32_f32];
+                                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32].pipeline;
                                         } break;
                                     case GGML_TYPE_F16:
                                         {
                                             GGML_ASSERT(src1t == GGML_TYPE_F32);
                                             nth0 = 32;
                                             nth1 = 1;
-                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f16_f32];
+                                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline;
                                         } break;
                                     case GGML_TYPE_Q4_0:
                                         {
                                             nth0 = 8;
                                             nth1 = 8;
-                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_0_f32];
+                                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32].pipeline;
                                         } break;
                                     case GGML_TYPE_Q4_1:
                                         {
                                             nth0 = 8;
                                             nth1 = 8;
-                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_1_f32];
+                                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32].pipeline;
                                         } break;
                                     case GGML_TYPE_Q5_0:
                                         {
                                             nth0 = 8;
                                             nth1 = 8;
-                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_0_f32];
+                                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32].pipeline;
                                         } break;
                                     case GGML_TYPE_Q5_1:
                                         {
                                             nth0 = 8;
                                             nth1 = 8;
-                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_1_f32];
+                                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline;
                                         } break;
                                     case GGML_TYPE_Q8_0:
                                         {
                                             nth0 = 8;
                                             nth1 = 8;
-                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q8_0_f32];
+                                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline;
                                         } break;
                                     case GGML_TYPE_Q2_K:
                                         {
                                             nth0 = 2;
                                             nth1 = 32;
-                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q2_K_f32];
+                                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32].pipeline;
                                         } break;
                                     case GGML_TYPE_Q3_K:
                                         {
                                             nth0 = 2;
                                             nth1 = 32;
-                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q3_K_f32];
+                                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32].pipeline;
                                         } break;
                                     case GGML_TYPE_Q4_K:
                                         {
                                             nth0 = 4; //1;
                                             nth1 = 8; //32;
-                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_K_f32];
+                                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32].pipeline;
                                         } break;
                                     case GGML_TYPE_Q5_K:
                                         {
                                             nth0 = 2;
                                             nth1 = 32;
-                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_K_f32];
+                                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32].pipeline;
                                         } break;
                                     case GGML_TYPE_Q6_K:
                                         {
                                             nth0 = 2;
                                             nth1 = 32;
-                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q6_K_f32];
+                                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline;
                                         } break;
                                     case GGML_TYPE_IQ2_XXS:
                                         {
                                             nth0 = 4;
                                             nth1 = 16;
-                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_iq2_xxs_f32];
+                                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32].pipeline;
                                         } break;
                                     case GGML_TYPE_IQ2_XS:
                                         {
                                             nth0 = 4;
                                             nth1 = 16;
-                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_iq2_xs_f32];
+                                            pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline;
                                         } break;
                                     default:
                                         {
@@ -1973,6 +1955,7 @@ bool ggml_metal_graph_compute(
 
                                 const int64_t _ne1 = 1; // kernels needs a reference in constant memory
 
+                                [encoder setComputePipelineState:pipeline];
                                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                                 [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
@@ -2040,25 +2023,28 @@ bool ggml_metal_graph_compute(
                         } break;
                     case GGML_OP_GET_ROWS:
                         {
+                            id<MTLComputePipelineState> pipeline = nil;
+
                             switch (src0->type) {
-                                case GGML_TYPE_F32:  [encoder setComputePipelineState:ctx->pipeline_get_rows_f32];  break;
-                                case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_get_rows_f16];  break;
-                                case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
-                                case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
-                                case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_0]; break;
-                                case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_1]; break;
-                                case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q8_0]; break;
-                                case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break;
-                                case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break;
-                                case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_K]; break;
-                                case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_K]; break;
-                                case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_K]; break;
-                                case GGML_TYPE_I32:  [encoder setComputePipelineState:ctx->pipeline_get_rows_i32];  break;
-                                case GGML_TYPE_IQ2_XXS: [encoder setComputePipelineState:ctx->pipeline_get_rows_iq2_xxs]; break;
-                                case GGML_TYPE_IQ2_XS : [encoder setComputePipelineState:ctx->pipeline_get_rows_iq2_xs]; break;
+                                case GGML_TYPE_F32:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F32    ].pipeline; break;
+                                case GGML_TYPE_F16:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F16    ].pipeline; break;
+                                case GGML_TYPE_Q4_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0   ].pipeline; break;
+                                case GGML_TYPE_Q4_1:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1   ].pipeline; break;
+                                case GGML_TYPE_Q5_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0   ].pipeline; break;
+                                case GGML_TYPE_Q5_1:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1   ].pipeline; break;
+                                case GGML_TYPE_Q8_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0   ].pipeline; break;
+                                case GGML_TYPE_Q2_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K   ].pipeline; break;
+                                case GGML_TYPE_Q3_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K   ].pipeline; break;
+                                case GGML_TYPE_Q4_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K   ].pipeline; break;
+                                case GGML_TYPE_Q5_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K   ].pipeline; break;
+                                case GGML_TYPE_Q6_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K   ].pipeline; break;
+                                case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS].pipeline; break;
+                                case GGML_TYPE_IQ2_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline; break;
+                                case GGML_TYPE_I32:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_I32    ].pipeline; break;
                                 default: GGML_ASSERT(false && "not implemented");
                             }
 
+                            [encoder setComputePipelineState:pipeline];
                             [encoder setBuffer:id_src0     offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_src1     offset:offs_src1 atIndex:1];
                             [encoder setBuffer:id_dst      offset:offs_dst  atIndex:2];
@@ -2086,7 +2072,9 @@ bool ggml_metal_graph_compute(
                                 nth *= 2;
                             }
 
-                            [encoder setComputePipelineState:ctx->pipeline_rms_norm];
+                            id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline;
+
+                            [encoder setComputePipelineState:pipeline];
                             [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
                             [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
                             [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
@@ -2115,7 +2103,9 @@ bool ggml_metal_graph_compute(
                             //    nth *= 2;
                             //}
 
-                            [encoder setComputePipelineState:ctx->pipeline_group_norm];
+                            id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline;
+
+                            [encoder setComputePipelineState:pipeline];
                             [encoder setBuffer:id_src0  offset:offs_src0        atIndex:0];
                             [encoder setBuffer:id_dst   offset:offs_dst         atIndex:1];
                             [encoder setBytes:&ne00     length:sizeof( int64_t) atIndex:2];
@@ -2137,7 +2127,9 @@ bool ggml_metal_graph_compute(
 
                             const int nth = MIN(256, ne00);
 
-                            [encoder setComputePipelineState:ctx->pipeline_norm];
+                            id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NORM].pipeline;
+
+                            [encoder setComputePipelineState:pipeline];
                             [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
                             [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
                             [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
@@ -2164,7 +2156,9 @@ bool ggml_metal_graph_compute(
                             const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
                             const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
 
-                            [encoder setComputePipelineState:ctx->pipeline_alibi_f32];
+                            id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ALIBI_F32].pipeline;
+
+                            [encoder setComputePipelineState:pipeline];
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
                             [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
@@ -2209,12 +2203,15 @@ bool ggml_metal_graph_compute(
                             memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
                             memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
 
+                            id<MTLComputePipelineState> pipeline = nil;
+
                             switch (src0->type) {
-                                case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_rope_f32]; break;
-                                case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_rope_f16]; break;
+                                case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F32].pipeline; break;
+                                case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_F16].pipeline; break;
                                 default: GGML_ASSERT(false);
                             };
 
+                            [encoder setComputePipelineState:pipeline];
                             [encoder setBuffer:id_src0     offset:offs_src0        atIndex:0];
                             [encoder setBuffer:id_src1     offset:offs_src1        atIndex:1];
                             [encoder setBuffer:id_dst      offset:offs_dst         atIndex:2];
@@ -2277,12 +2274,15 @@ bool ggml_metal_graph_compute(
                             const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
                             const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
 
+                            id<MTLComputePipelineState> pipeline = nil;
+
                             switch (src0->type) {
                                 case GGML_TYPE_F32: GGML_ASSERT(false && "not implemented"); break;
-                                case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_im2col_f16]; break;
+                                case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break;
                                 default: GGML_ASSERT(false);
                             };
 
+                            [encoder setComputePipelineState:pipeline];
                             [encoder setBuffer:id_src1 offset:offs_src1        atIndex:0];
                             [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
                             [encoder setBytes:&ofs0    length:sizeof( int32_t) atIndex:2];
@@ -2305,7 +2305,9 @@ bool ggml_metal_graph_compute(
 
                             const int sf = dst->op_params[0];
 
-                            [encoder setComputePipelineState:ctx->pipeline_upscale_f32];
+                            const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline;
+
+                            [encoder setComputePipelineState:pipeline];
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
                             [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
@@ -2326,7 +2328,7 @@ bool ggml_metal_graph_compute(
                             [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:17];
                             [encoder setBytes:&sf   length:sizeof(sf)   atIndex:18];
 
-                            const int nth = MIN((int) ctx->pipeline_upscale_f32.maxTotalThreadsPerThreadgroup, ne0);
+                            const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
 
                             [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                         } break;
@@ -2334,7 +2336,9 @@ bool ggml_metal_graph_compute(
                         {
                             GGML_ASSERT(src0->type == GGML_TYPE_F32);
 
-                            [encoder setComputePipelineState:ctx->pipeline_pad_f32];
+                            id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline;
+
+                            [encoder setComputePipelineState:pipeline];
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
                             [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
@@ -2367,12 +2371,15 @@ bool ggml_metal_graph_compute(
 
                             enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
 
+                            id<MTLComputePipelineState> pipeline = nil;
+
                             switch (order) {
-                                case GGML_SORT_ASC:  [encoder setComputePipelineState:ctx->pipeline_argsort_f32_i32_asc];  break;
-                                case GGML_SORT_DESC: [encoder setComputePipelineState:ctx->pipeline_argsort_f32_i32_desc]; break;
+                                case GGML_SORT_ASC:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC].pipeline;  break;
+                                case GGML_SORT_DESC: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC].pipeline; break;
                                 default: GGML_ASSERT(false);
                             };
 
+                            [encoder setComputePipelineState:pipeline];
                             [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
                             [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
                             [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
@@ -2386,7 +2393,9 @@ bool ggml_metal_graph_compute(
                             float slope;
                             memcpy(&slope, dst->op_params, sizeof(float));
 
-                            [encoder setComputePipelineState:ctx->pipeline_leaky_relu_f32];
+                            id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline;
+
+                            [encoder setComputePipelineState:pipeline];
                             [encoder setBuffer:id_src0 offset:offs_src0   atIndex:0];
                             [encoder setBuffer:id_dst  offset:offs_dst    atIndex:1];
                             [encoder setBytes:&slope length:sizeof(slope) atIndex:2];
@@ -2403,33 +2412,36 @@ bool ggml_metal_graph_compute(
 
                             int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
 
+                            id<MTLComputePipelineState> pipeline = nil;
+
                             switch (src0t) {
                                 case GGML_TYPE_F32:
                                     {
                                         GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);
 
                                         switch (dstt) {
-                                            case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16];  break;
-                                            case GGML_TYPE_F32:  [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32];  break;
-                                            case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q8_0]; break;
-                                            case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_0]; break;
-                                            case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q4_1]; break;
-                                            //case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_0]; break;
-                                            //case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_q5_1]; break;
+                                            case GGML_TYPE_F16:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline;  break;
+                                            case GGML_TYPE_F32:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline;  break;
+                                            case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
+                                            case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
+                                            case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
+                                          //case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0].pipeline; break;
+                                          //case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1].pipeline; break;
                                             default: GGML_ASSERT(false && "not implemented");
                                         };
                                     } break;
                                 case GGML_TYPE_F16:
                                     {
                                         switch (dstt) {
-                                            case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f16]; break;
-                                            case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f32]; break;
+                                            case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break;
+                                            case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break;
                                             default: GGML_ASSERT(false && "not implemented");
                                         };
                                     } break;
                                 default: GGML_ASSERT(false && "not implemented");
                             }
 
+                            [encoder setComputePipelineState:pipeline];
                             [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
                             [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
                             [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
@@ -2794,9 +2806,9 @@ static bool ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml
 }
 
 static bool ggml_backend_metal_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
-    return ggml_metal_supports_op(op);
+    struct ggml_metal_context * metal_ctx = (struct ggml_metal_context *)backend->context;
 
-    UNUSED(backend);
+    return ggml_metal_supports_op(metal_ctx, op);
 }
 
 static struct ggml_backend_i ggml_backend_metal_i = {