]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
metal : correct fix of kernel_norm (#3060)
authorKawrakow <redacted>
Thu, 7 Sep 2023 13:42:42 +0000 (15:42 +0200)
committerGitHub <redacted>
Thu, 7 Sep 2023 13:42:42 +0000 (16:42 +0300)
Co-authored-by: Iwan Kawrakow <redacted>
Co-authored-by: Georgi Gerganov <redacted>
ggml-metal.metal

index d66ff340adec2cd9ef19576c1fb1f56e41de4ac6..5edf6d521c2d3ed52175b005adcef8923af4124e 100644 (file)
@@ -220,29 +220,14 @@ kernel void kernel_norm(
         }
         threadgroup_barrier(mem_flags::mem_threadgroup);
     }
-    // broadcast
-    if (tpitg == 0) {
-        sum[0] /= ne00;
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    const float mean  = sum[0];
+    const float mean  = sum[0] / ne00;
 
-    // recenter
+    // recenter and VARIANCE
+    threadgroup_barrier(mem_flags::mem_threadgroup);
     device float * y = dst + tgpig*ne00;
-    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
-        y[i00] = x[i00] - mean;
-    }
-
-    // VARIANCE
-    // parallel sum
-    //
-    // WARNING: combining this loop with the one above will give you wrong results for nth == 256
-    //          I have no idea why, so for now I am keeping them separate. But this behavior is very concerning.
-    //          Tested with:
-    //          ./perplexity -m ./falcon-7b/ggml-model-q4_0.gguf -f wiki.test.raw -ngl 1 -t 4
-    //
     sum[tpitg] = 0.0f;
     for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        y[i00] = x[i00] - mean;
         sum[tpitg] += y[i00] * y[i00];
     }
 
@@ -254,12 +239,7 @@ kernel void kernel_norm(
         }
         threadgroup_barrier(mem_flags::mem_threadgroup);
     }
-    // broadcast
-    if (tpitg == 0) {
-        sum[0] /= ne00;
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    const float variance = sum[0];
+    const float variance = sum[0] / ne00;
 
     const float scale = 1.0f/sqrt(variance + eps);
     for (int i00 = tpitg; i00 < ne00; i00 += ntg) {