]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
metal : fox offset integer overflows in im2col (ggml/1015)
authorPlamen Minev <redacted>
Mon, 18 Nov 2024 13:02:27 +0000 (15:02 +0200)
committerGeorgi Gerganov <redacted>
Tue, 19 Nov 2024 18:03:21 +0000 (20:03 +0200)
-- While running StableDiffusion.cpp locally with Metal some offsets overflow and results in incorrect calculations

ggml/src/ggml-metal/ggml-metal.metal

index 819b20ba89bec6ffc79696680306716e6d4eb324..971f5054bce2a6384bcebaf8619d8a7f9eb56ca2 100644 (file)
@@ -2145,20 +2145,34 @@ kernel void kernel_im2col(
         uint3  tgpg[[threadgroups_per_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]],
         uint3   ntg[[threads_per_threadgroup]]) {
-    const int32_t iiw = tgpig[2] * s0 + tpitg[2] * d0 - p0;
-    const int32_t iih = tgpig[1] * s1 + tpitg[1] * d1 - p1;
+//    const int64_t IC = tgpg[0];
+    const int64_t OH = tgpg[1];
+    const int64_t OW = tgpg[2];
 
-    const int32_t offset_dst =
-        (tpitg[0] * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
-        (tgpig[0] * (ntg[1] * ntg[2]) + tpitg[1] * ntg[2] + tpitg[2]);
+//    const int64_t N  = ntg[0];
+    const int64_t KH = ntg[1];
+    const int64_t KW = ntg[2];
+
+    const int64_t in  = tpitg[0];
+    const int64_t ikh = tpitg[1];
+    const int64_t ikw = tpitg[2];
+
+    const int64_t iic = tgpig[0];
+    const int64_t ioh = tgpig[1];
+    const int64_t iow = tgpig[2];
+
+    const int64_t iiw = iow*s0 + ikw*d0 - p0;
+    const int64_t iih = ioh*s1 + ikh*d1 - p1;
+
+    const int64_t offset_dst = (in*OH*OW + ioh*OW + iow)*CHW + (iic*(KH*KW) + ikh*KW + ikw);
 
     device T * pdst = (device T *) (dst);
 
     if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
         pdst[offset_dst] = 0.0f;
     } else {
-        const int32_t offset_src = tpitg[0] * ofs0 + tgpig[0] * ofs1;
-        pdst[offset_dst] = x[offset_src + iih * IW + iiw];
+        const int64_t offset_src = in*ofs0 + iic*ofs1 + iih*IW + iiw;
+        pdst[offset_dst] = x[offset_src];
     }
 }
 
@@ -2209,25 +2223,25 @@ kernel void kernel_im2col_ext(
         uint3  tgpg[[threadgroups_per_grid]],      // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
         uint3 tpitg[[thread_position_in_threadgroup]],
         uint3   ntg[[threads_per_threadgroup]]) {  // [M, 1, 1]
-    const int32_t KHW = KH * KW;             // KHW == ntg[1] * ntg[2], KW == ntg[2]
+    const int64_t KHW = KH * KW;             // KHW == ntg[1] * ntg[2], KW == ntg[2]
 
-    const int32_t d = tgpig[0] / CHW;
-    const int32_t chw = tgpig[0] % CHW;
-    const int32_t tgpig_0 = chw / KHW;  // 0 ~ (IC - 1)
-    const int32_t HW = tgpig[0] % KHW;
+    const int64_t d = tgpig[0] / CHW;
+    const int64_t chw = tgpig[0] % CHW;
+    const int64_t tgpig_0 = chw / KHW;  // 0 ~ (IC - 1)
+    const int64_t HW = tgpig[0] % KHW;
 
-    const int32_t tpitg_0 = (d * ntg[0]) + tpitg[0];
+    const int64_t tpitg_0 = (d * ntg[0]) + tpitg[0];
     if (tpitg_0 >= N) {
         return;
     }
 
-    const int32_t tpitg_1 = HW / KW;
-    const int32_t tpitg_2 = HW % KW;
+    const int64_t tpitg_1 = HW / KW;
+    const int64_t tpitg_2 = HW % KW;
 
-    const int32_t iiw = tgpig[2] * s0 + tpitg_2 * d0 - p0;
-    const int32_t iih = tgpig[1] * s1 + tpitg_1 * d1 - p1;
+    const int64_t iiw = tgpig[2] * s0 + tpitg_2 * d0 - p0;
+    const int64_t iih = tgpig[1] * s1 + tpitg_1 * d1 - p1;
 
-    const int32_t offset_dst =
+    const int64_t offset_dst =
         (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
         (tgpig_0 * KHW + tpitg_1 * KW + tpitg_2);
 
@@ -2236,7 +2250,7 @@ kernel void kernel_im2col_ext(
     if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
         pdst[offset_dst] = 0.0f;
     } else {
-        const int32_t offset_src = tpitg_0 * ofs0 + tgpig_0 * ofs1;
+        const int64_t offset_src = tpitg_0 * ofs0 + tgpig_0 * ofs1;
         pdst[offset_dst] = x[offset_src + iih * IW + iiw];
     }
 }