]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
vulkan: fix rms_norm_mul to handle broadcasting dim0 (#14817)
authorJeff Bolz <redacted>
Tue, 22 Jul 2025 15:35:21 +0000 (10:35 -0500)
committerGitHub <redacted>
Tue, 22 Jul 2025 15:35:21 +0000 (17:35 +0200)
ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp

index c3f1369b66315eaef2a7ca656f35453ba1efcc68..1a7a381ce59216796929bd60f4806f7316f72b36 100644 (file)
@@ -10248,7 +10248,7 @@ static bool ggml_vk_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, st
         }
         // if rms_norm is the B operand, then we don't handle broadcast
         if (rms_norm == mul->src[1] &&
-            mul->src[0]->ne[1] != rms_norm->ne[1]) {
+            !ggml_are_same_shape(mul->src[0], rms_norm)) {
             return false;
         }
         // rms_norm shader assumes contiguous rows
index 6428ca7ba330098cc916e1b74733b22608edf634..bdd7db2d6987a7d6b46feb3b5c9d611cc4e41e16 100644 (file)
@@ -50,8 +50,14 @@ void main() {
     const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1));
 
     if (do_multiply) {
-        [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
-            data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col]));
+        if (ncols > p.ne10) {
+            [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
+                data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + fastmod(col, p.ne10)]));
+            }
+        } else {
+            [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
+                data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col]));
+            }
         }
     } else {
         [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {