]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
metal : fix out-of-bounds write (llama/11314)
authorGeorgi Gerganov <redacted>
Tue, 21 Jan 2025 06:48:13 +0000 (08:48 +0200)
committerGeorgi Gerganov <redacted>
Mon, 3 Feb 2025 20:00:57 +0000 (22:00 +0200)
ggml-ci

ggml/src/ggml-metal/ggml-metal.metal

index 8ba43904d0c1b2760f962110125b39457318c6cb..44f04c909bfb2d59c4e17ac9b938275467e79ed7 100644 (file)
@@ -4416,7 +4416,6 @@ void kernel_mul_mv_q2_K_f32_impl(
         device const half     * dh = &x[ib].d;
 
         for (int row = 0; row < N_DST; row++) {
-
             float4 acc1 = {0.f, 0.f, 0.f, 0.f};
             float4 acc2 = {0.f, 0.f, 0.f, 0.f};
             for (int i = 0; i < 8; i += 2) {
@@ -4447,7 +4446,7 @@ void kernel_mul_mv_q2_K_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST; ++row) {
+    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = all_sum;
@@ -4613,7 +4612,7 @@ void kernel_mul_mv_q3_K_f32_impl(
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
     if (tiisg == 0) {
-        for (int row = 0; row < 2; ++row) {
+        for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
             dst_f32[first_row + row] = sumf1[row];
         }
     }
@@ -4729,7 +4728,7 @@ void kernel_mul_mv_q4_K_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1 + (int64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST; ++row) {
+    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = all_sum;
@@ -4861,7 +4860,7 @@ void kernel_mul_mv_q5_K_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < 2; ++row) {
+    for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
         const float tot = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = tot;
@@ -4906,6 +4905,10 @@ void kernel_mul_mv_q6_K_f32_impl(
 
     const int row = 2*r0 + sgitg;
 
+    if (row >= args.ne0) {
+        return;
+    }
+
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
 
@@ -5061,7 +5064,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST; ++row) {
+    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = all_sum * 0.25f;
@@ -5179,7 +5182,7 @@ void kernel_mul_mv_iq2_xs_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST; ++row) {
+    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = all_sum * 0.25f;
@@ -5289,7 +5292,7 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST; ++row) {
+    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = all_sum * 0.5f;
@@ -5401,7 +5404,7 @@ void kernel_mul_mv_iq3_s_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST; ++row) {
+    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = all_sum;
@@ -5514,7 +5517,7 @@ void kernel_mul_mv_iq2_s_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST; ++row) {
+    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = all_sum * 0.25f;
@@ -5614,7 +5617,7 @@ void kernel_mul_mv_iq1_s_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST; ++row) {
+    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = all_sum;
@@ -5709,7 +5712,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < N_DST; ++row) {
+    for (int row = 0; row < N_DST && first_row + row < args.ne0; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = all_sum;
@@ -5799,7 +5802,7 @@ void kernel_mul_mv_iq4_nl_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < 2 && first_row + row < args.ne01; ++row) {
+    for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = all_sum;
@@ -5888,7 +5891,7 @@ void kernel_mul_mv_iq4_xs_f32_impl(
 
     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
 
-    for (int row = 0; row < 2; ++row) {
+    for (int row = 0; row < 2 && first_row + row < args.ne0; ++row) {
         all_sum = simd_sum(sumf[row]);
         if (tiisg == 0) {
             dst_f32[first_row + row] = all_sum;