From: Georgi Gerganov Date: Fri, 15 Sep 2023 17:17:24 +0000 (+0300) Subject: metal : fix bug in soft_max kernels (out-of-bounds access) (#3194) X-Git-Tag: gguf-v0.4.0~64 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=c6f1491da032238241e01021c8c58d7b540a043f;p=pkg%2Fggml%2Fsources%2Fllama.cpp metal : fix bug in soft_max kernels (out-of-bounds access) (#3194) --- diff --git a/ggml-metal.metal b/ggml-metal.metal index 3087ecda..7f1c3d9e 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -118,7 +118,7 @@ kernel void kernel_soft_max( device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; // parallel max - float lmax = psrc0[tpitg[0]]; + float lmax = tpitg[0] < ne00 ? psrc0[tpitg[0]] : -INFINITY; for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) { lmax = MAX(lmax, psrc0[i00]); } @@ -158,7 +158,7 @@ kernel void kernel_soft_max_4( device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); // parallel max - float4 lmax4 = psrc4[tpitg[0]]; + float4 lmax4 = tpitg[0] < ne00/4 ? psrc4[tpitg[0]] : -INFINITY; for (int i00 = tpitg[0] + ntg[0]; i00 < ne00/4; i00 += ntg[0]) { lmax4 = fmax(lmax4, psrc4[i00]); }