git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
metal : fix l2 norm scale (llama/20493)
author Georgi Gerganov <redacted>
Fri, 13 Mar 2026 09:43:20 +0000 (11:43 +0200)
committer Georgi Gerganov <redacted>
Mon, 16 Mar 2026 11:10:15 +0000 (13:10 +0200)
ggml/src/ggml-metal/ggml-metal-device.m
ggml/src/ggml-metal/ggml-metal.metal

index 05b826a61b833f217d00b8f0ac803fe3226305d6..b7d587f3bd99e38ab65a37f60434f705a0d45711 100644 (file)
@@ -1156,7 +1156,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
         case GGML_OP_RWKV_WKV7:
             return true;
         case GGML_OP_GATED_DELTA_NET:
-            return op->src[2]->ne[0] % 32 == 0;
+            return has_simdgroup_reduction && op->src[2]->ne[0] % 32 == 0;
         case GGML_OP_SOLVE_TRI:
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
index 24a3092af22c65fac7e5b87606d2e40132dbc399..107e7cf2ff36bc18ac69a4cd8dede47621b21453 100644 (file)
@@ -3006,7 +3006,7 @@ kernel void kernel_l2_norm_impl(
     sumf = shmem_f32[tiisg];
     sumf = simd_sum(sumf);
 
-    const float scale = 1.0f/sqrt(max(sumf, args.eps));
+    const float scale = 1.0f/max(sqrt(sumf), args.eps);
 
     for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
         y[i00] = x[i00] * scale;