]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
metal : fix kernel_norm (fixes Falcon on Metal) (#3057)
authorGeorgi Gerganov <redacted>
Thu, 7 Sep 2023 12:49:09 +0000 (15:49 +0300)
committerGitHub <redacted>
Thu, 7 Sep 2023 12:49:09 +0000 (15:49 +0300)
* metal : fix kernel_norm

ggml-ci

* metal : put warning in kernel_norm to not combine the loops

* metal : restore original F16 mat-vec multiplication

It works after the norm fixes

* common : don't do warm-up with more than n_batch tokens (close #3058)

ggml-ci

* metal : minor

common/common.cpp
ggml-metal.metal

index 22f65ac469b50c876f323df1606d6056de0ac983..28b7c6300fa514d9651d5a7ed5b5b5f2cad7e85b 100644 (file)
@@ -773,7 +773,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
         LOG("warming up the model with an empty run\n");
 
         const std::vector<llama_token> tmp = { llama_token_bos(lctx), llama_token_eos(lctx), };
-        llama_eval(lctx, tmp.data(), tmp.size(), 0, params.n_threads);
+        llama_eval(lctx, tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, params.n_threads);
         llama_reset_timings(lctx);
     }
 
index 119fcbeb623c11ca25a1c59c08564d7c3da32fba..d66ff340adec2cd9ef19576c1fb1f56e41de4ac6 100644 (file)
@@ -220,27 +220,32 @@ kernel void kernel_norm(
         }
         threadgroup_barrier(mem_flags::mem_threadgroup);
     }
-    //// broadcast
-    //if (tpitg == 0) {
-    //    sum[0] /= ne00;
-    //}
-    //threadgroup_barrier(mem_flags::mem_threadgroup);
+    // broadcast
+    if (tpitg == 0) {
+        sum[0] /= ne00;
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
     const float mean  = sum[0];
 
-    // recenter and VARIANCE
+    // recenter
     device float * y = dst + tgpig*ne00;
-    sum[tpitg] = 0.0f;
     for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
         y[i00] = x[i00] - mean;
+    }
+
+    // VARIANCE
+    // parallel sum
+    //
+    // WARNING: combining this loop with the one above will give you wrong results for nth == 256
+    //          I have no idea why, so for now I am keeping them separate. But this behavior is very concerning.
+    //          Tested with:
+    //          ./perplexity -m ./falcon-7b/ggml-model-q4_0.gguf -f wiki.test.raw -ngl 1 -t 4
+    //
+    sum[tpitg] = 0.0f;
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
         sum[tpitg] += y[i00] * y[i00];
     }
 
-    //// VARIANCE
-    //// parallel sum
-    //sum[tpitg] = 0.0f;
-    //for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
-    //    sum[tpitg] += y[i00] * y[i00];
-    //}
     // reduce
     threadgroup_barrier(mem_flags::mem_threadgroup);
     for (uint i = ntg/2; i > 0; i /= 2) {
@@ -249,11 +254,11 @@ kernel void kernel_norm(
         }
         threadgroup_barrier(mem_flags::mem_threadgroup);
     }
-    //// broadcast
-    //if (tpitg == 0) {
-    //    sum[0] /= ne00;
-    //}
-    //threadgroup_barrier(mem_flags::mem_threadgroup);
+    // broadcast
+    if (tpitg == 0) {
+        sum[0] /= ne00;
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
     const float variance = sum[0];
 
     const float scale = 1.0f/sqrt(variance + eps);
@@ -262,7 +267,6 @@ kernel void kernel_norm(
     }
 }
 
-
 kernel void kernel_rms_norm(
         device const  void * src0,
         device       float * dst,
@@ -630,7 +634,6 @@ kernel void kernel_mul_mat_f16_f32(
             }
         }
     }
-
 }
 
 kernel void kernel_alibi_f32(