]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
metal : Add template specialization for mul_mm_id w/ ne20 == 10 (#15799)
authorGabe Goodhart <redacted>
Thu, 4 Sep 2025 15:53:22 +0000 (09:53 -0600)
committerGitHub <redacted>
Thu, 4 Sep 2025 15:53:22 +0000 (18:53 +0300)
Branch: GGMLMetalNE20

Signed-off-by: Gabe Goodhart <redacted>
ggml/src/ggml-metal/ggml-metal.m
ggml/src/ggml-metal/ggml-metal.metal

index 9b4006d987c3b9558139964ac855f6a51ec66642..c1a0a2bef171e5fa169b4d4bbfcabaa259379815 100644 (file)
@@ -407,6 +407,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_10,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16,
@@ -1439,6 +1440,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4,       mul_mm_id_map0_f16_ne20_4,       has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6,       mul_mm_id_map0_f16_ne20_6,       has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8,       mul_mm_id_map0_f16_ne20_8,       has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_10,      mul_mm_id_map0_f16_ne20_10,      has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16,      mul_mm_id_map0_f16_ne20_16,      has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16,               mul_mm_id_f32_f16,               has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16,               mul_mm_id_f16_f16,               has_simdgroup_mm);
@@ -3979,6 +3981,7 @@ static int ggml_metal_encode_node(
                             case 4:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4 ].pipeline; break;
                             case 6:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6 ].pipeline; break;
                             case 8:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8 ].pipeline; break;
+                            case 10: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_10].pipeline; break;
                             case 16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16].pipeline; break;
                             default: GGML_ABORT("missing specialization for ne20 = %d", (int) ne20);
                         }
index 9c5933d24a0e3140cc25f054af8953c6511bb053..2d56c62674c8e049b491f90c914a9491bd54a9dc 100644 (file)
@@ -7618,6 +7618,7 @@ template [[host_name("kernel_mul_mm_id_map0_f16_ne20_2" )]] kernel kernel_mul_mm
 template [[host_name("kernel_mul_mm_id_map0_f16_ne20_4" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<4>;
 template [[host_name("kernel_mul_mm_id_map0_f16_ne20_6" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<6>;
 template [[host_name("kernel_mul_mm_id_map0_f16_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
+template [[host_name("kernel_mul_mm_id_map0_f16_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>;
 template [[host_name("kernel_mul_mm_id_map0_f16_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>;
 
 template<typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>