]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
CUDA: fix im2col_3d to respect non-contiguous inputs (views) (llama/15956)
authorJake Karnes <redacted>
Mon, 15 Sep 2025 22:28:31 +0000 (16:28 -0600)
committerGeorgi Gerganov <redacted>
Sat, 20 Sep 2025 10:33:50 +0000 (13:33 +0300)
* fix im2col_3d to respect non-contiguous inputs (views)

The CUDA 3D im2col kernel computed source addresses assuming compact layout (products of dims), ignoring nb[] strides.

This patch switches im2col_3d source indexing to use true strides derived from src1->nb[] (in elements), mirroring the approach used in the 2D CUDA im2col path. Destination indexing is unchanged.

* use ggml_element_size() for src strides

Co-authored-by: Johannes Gäßler <redacted>
---------

Co-authored-by: Johannes Gäßler <redacted>
src/ggml-cuda/im2col.cu

index 7737d6a5d5230ad112c825441ec59a31e6c63a3e..56dc0545742e4745e6747abde8f9ac201d8c2c31 100644 (file)
@@ -122,11 +122,14 @@ static  __global__ void im2col_3d_kernel(
         int64_t OH_OW, int64_t KD_KH_KW, int64_t ID_IH_IW, int64_t KH_KW, int64_t IH_IW, int64_t IC_ID_IH_IW,
         int64_t IC_KD_KH_KW, int64_t OW_KD_KH_KW, int64_t OD_OH_OW_IC_KD_KH_KW, int64_t OH_OW_IC_KD_KH_KW,
         int64_t OW_IC_KD_KH_KW, int64_t N_OD_OH, int64_t OD_OH,
+        int64_t stride_q, int64_t stride_z, int64_t stride_y, int64_t stride_x,
         int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2) {
     const int64_t i = threadIdx.x + blockIdx.x * blockDim.x;
     if (i >= IC_KD_KH_KW) {
         return;
     }
+    GGML_UNUSED(N); GGML_UNUSED(OC); GGML_UNUSED(OH_OW); GGML_UNUSED(OD); GGML_UNUSED(OW); GGML_UNUSED(KD); GGML_UNUSED(KH);
+    GGML_UNUSED(ID_IH_IW); GGML_UNUSED(IH_IW); GGML_UNUSED(IC_ID_IH_IW); GGML_UNUSED(OW_KD_KH_KW);
 
     const int64_t iic = i / KD_KH_KW;
     const int64_t ikd = (i - iic * KD_KH_KW) / KH_KW;
@@ -148,7 +151,7 @@ static  __global__ void im2col_3d_kernel(
         if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) {
             dst[offset_dst] = 0.0f;
         } else {
-            const int64_t offset_src = in*IC_ID_IH_IW + iic*ID_IH_IW + iid*IH_IW + iih*IW + iiw;
+            const int64_t offset_src = ((in * IC + iic) * stride_q) + (iid * stride_z) + (iih * stride_y) + (iiw * stride_x);
             dst[offset_dst] = src[offset_src];
         }
     }
@@ -159,6 +162,7 @@ template <typename T>
 static void im2col_3d_cuda(const float * src, T* dst,
     int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC,
     int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW,
+    int64_t stride_q, int64_t stride_z, int64_t stride_y, int64_t stride_x,
     int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) {
     const int64_t OH_OW = OH*OW;
     const int64_t KD_KH_KW = KD*KH*KW;
@@ -179,23 +183,30 @@ static void im2col_3d_cuda(const float * src, T* dst,
                                                                                            OH_OW, KD_KH_KW, ID_IH_IW, KH_KW, IH_IW, IC_ID_IH_IW,
                                                                                            IC_KD_KH_KW, OW_KD_KH_KW, OD_OH_OW_IC_KD_KH_KW,
                                                                                            OH_OW_IC_KD_KH_KW, OW_IC_KD_KH_KW, N_OD_OH, OD_OH,
+                                                                                           stride_q, stride_z, stride_y, stride_x,
                                                                                            s0, s1, s2, p0, p1, p2, d0, d1, d2);
 }
 
 static void im2col_3d_cuda_f16(const float * src, half * dst,
     int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC,
     int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW,
+    int64_t stride_q, int64_t stride_z, int64_t stride_y, int64_t stride_x,
     int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) {
 
-    im2col_3d_cuda<half>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
+    im2col_3d_cuda<half>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW,
+                         stride_q, stride_z, stride_y, stride_x,
+                         s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
 }
 
 static void im2col_3d_cuda_f32(const float * src, float * dst,
     int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC,
     int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW,
+    int64_t stride_q, int64_t stride_z, int64_t stride_y, int64_t stride_x,
     int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) {
 
-    im2col_3d_cuda<float>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
+    im2col_3d_cuda<float>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW,
+                          stride_q, stride_z, stride_y, stride_x,
+                          s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
 }
 
 void ggml_cuda_op_im2col_3d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -235,9 +246,19 @@ void ggml_cuda_op_im2col_3d(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
     const int64_t OH = ne2;
     const int64_t OW = ne1;
 
+    const size_t  es       = ggml_element_size(src1);
+    const int64_t stride_x = src1->nb[0] / es;
+    const int64_t stride_y = src1->nb[1] / es;
+    const int64_t stride_z = src1->nb[2] / es;
+    const int64_t stride_q = src1->nb[3] / es;
+
     if(dst->type == GGML_TYPE_F16) {
-        im2col_3d_cuda_f16(src1_d, (half *) dst_d, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
+        im2col_3d_cuda_f16(src1_d, (half *) dst_d, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW,
+                           stride_q, stride_z, stride_y, stride_x,
+                           s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
     } else {
-        im2col_3d_cuda_f32(src1_d, (float *) dst_d, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
+        im2col_3d_cuda_f32(src1_d, (float *) dst_d, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW,
+                           stride_q, stride_z, stride_y, stride_x,
+                           s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
     }
 }