]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
metal : fix kernel requirements (#15983)
authorGeorgi Gerganov <redacted>
Sun, 14 Sep 2025 12:33:22 +0000 (15:33 +0300)
committerGitHub <redacted>
Sun, 14 Sep 2025 12:33:22 +0000 (15:33 +0300)
* metal : fix kernel requirements

ggml-ci

* cont : fix supports_op

* cont : fix supports_op for ARGMAX

ggml/src/ggml-metal/ggml-metal.m

index 82d8077a0f0d8a479398a84ddfda28326163a814..13f9de297eae480d05f9cefdcbc5330177093315 100644 (file)
@@ -1219,10 +1219,10 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL,                 set_rows_iq4_nl,                 true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM,                         l2_norm,                         has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM,                      group_norm,                      has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM,                            norm,                            true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM,                            norm,                            has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,                    ssm_conv_f32,                    true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,                    ssm_scan_f32,                    true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP,              ssm_scan_f32_group,              true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,                    ssm_scan_f32,                    has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP,              ssm_scan_f32_group,              has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,                   rwkv_wkv6_f32,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,                   rwkv_wkv7_f32,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,                  mul_mv_f32_f32,                  has_simdgroup_reduction);
@@ -1443,9 +1443,9 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU_OAI,                      swiglu_oai,                      true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_ERF,                       geglu_erf,                       true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_QUICK,                     geglu_quick,                     true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS,                        sum_rows,                        true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN,                            mean,                            true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX,                          argmax,                          true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS,                        sum_rows,                        has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN,                            mean,                            has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX,                          argmax,                          has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,                 pool_2d_avg_f32,                 true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,                 pool_2d_max_f32,                 true);
     }
@@ -1982,7 +1982,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         case GGML_OP_L2_NORM:
             return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
         case GGML_OP_ARGMAX:
-            return true;
+            return has_simdgroup_reduction;
         case GGML_OP_NORM:
             return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
         case GGML_OP_ROPE:
@@ -2028,6 +2028,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
             return has_simdgroup_mm; // TODO: over-restricted for vec-kernels
         case GGML_OP_SSM_CONV:
         case GGML_OP_SSM_SCAN:
+            return has_simdgroup_reduction;
         case GGML_OP_RWKV_WKV6:
         case GGML_OP_RWKV_WKV7:
             return true;