]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
metal : fix bug in soft_max kernels (out-of-bounds access) (#3194)
authorGeorgi Gerganov <redacted>
Fri, 15 Sep 2023 17:17:24 +0000 (20:17 +0300)
committerGitHub <redacted>
Fri, 15 Sep 2023 17:17:24 +0000 (20:17 +0300)
ggml-metal.metal

index 3087ecda812d90d510ffc6859ff3d4dc52f12e7e..7f1c3d9ea74bd031ae831bfe916ab8fa30d211a9 100644 (file)
@@ -118,7 +118,7 @@ kernel void kernel_soft_max(
     device       float * pdst  = dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
 
     // parallel max
-    float lmax = psrc0[tpitg[0]];
+    float lmax = tpitg[0] < ne00 ? psrc0[tpitg[0]] : -INFINITY;
     for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) {
         lmax = MAX(lmax, psrc0[i00]);
     }
@@ -158,7 +158,7 @@ kernel void kernel_soft_max_4(
     device       float4 * pdst4 = (device       float4 *)(dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
 
     // parallel max
-    float4 lmax4 = psrc4[tpitg[0]];
+    float4 lmax4 = tpitg[0] < ne00/4 ? psrc4[tpitg[0]] : -INFINITY;
     for (int i00 = tpitg[0] + ntg[0]; i00 < ne00/4; i00 += ntg[0]) {
         lmax4 = fmax(lmax4, psrc4[i00]);
     }