From: Jake Karnes Date: Mon, 15 Sep 2025 22:28:31 +0000 (-0600) Subject: CUDA: fix im2col_3d to respect non-contiguous inputs (views) (llama/15956) X-Git-Tag: v0.9.1~25 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=f771037ffe29b2c666ab5716f7d5e42cde1ec746;p=pkg%2Fggml%2Fsources%2Fggml CUDA: fix im2col_3d to respect non-contiguous inputs (views) (llama/15956) * fix im2col_3d to respect non-contiguous inputs (views) The CUDA 3D im2col kernel computed source addresses assuming compact layout (products of dims), ignoring nb[] strides. This patch switches im2col_3d source indexing to use true strides derived from src1->nb[] (in elements), mirroring the approach used in the 2D CUDA im2col path. Destination indexing is unchanged. * use ggml_element_size() for src strides Co-authored-by: Johannes Gäßler --------- Co-authored-by: Johannes Gäßler --- diff --git a/src/ggml-cuda/im2col.cu b/src/ggml-cuda/im2col.cu index 7737d6a5..56dc0545 100644 --- a/src/ggml-cuda/im2col.cu +++ b/src/ggml-cuda/im2col.cu @@ -122,11 +122,14 @@ static __global__ void im2col_3d_kernel( int64_t OH_OW, int64_t KD_KH_KW, int64_t ID_IH_IW, int64_t KH_KW, int64_t IH_IW, int64_t IC_ID_IH_IW, int64_t IC_KD_KH_KW, int64_t OW_KD_KH_KW, int64_t OD_OH_OW_IC_KD_KH_KW, int64_t OH_OW_IC_KD_KH_KW, int64_t OW_IC_KD_KH_KW, int64_t N_OD_OH, int64_t OD_OH, + int64_t stride_q, int64_t stride_z, int64_t stride_y, int64_t stride_x, int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2) { const int64_t i = threadIdx.x + blockIdx.x * blockDim.x; if (i >= IC_KD_KH_KW) { return; } + GGML_UNUSED(N); GGML_UNUSED(OC); GGML_UNUSED(OH_OW); GGML_UNUSED(OD); GGML_UNUSED(OW); GGML_UNUSED(KD); GGML_UNUSED(KH); + GGML_UNUSED(ID_IH_IW); GGML_UNUSED(IH_IW); GGML_UNUSED(IC_ID_IH_IW); GGML_UNUSED(OW_KD_KH_KW); const int64_t iic = i / KD_KH_KW; const int64_t ikd = (i - iic * KD_KH_KW) / KH_KW; @@ -148,7 +151,7 @@ static __global__ void im2col_3d_kernel( if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) { dst[offset_dst] = 0.0f; } else { - const int64_t offset_src = in*IC_ID_IH_IW + iic*ID_IH_IW + iid*IH_IW + iih*IW + iiw; + const int64_t offset_src = ((in * IC + iic) * stride_q) + (iid * stride_z) + (iih * stride_y) + (iiw * stride_x); dst[offset_dst] = src[offset_src]; } } @@ -159,6 +162,7 @@ template static void im2col_3d_cuda(const float * src, T* dst, int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC, int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW, + int64_t stride_q, int64_t stride_z, int64_t stride_y, int64_t stride_x, int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) { const int64_t OH_OW = OH*OW; const int64_t KD_KH_KW = KD*KH*KW; @@ -179,23 +183,30 @@ static void im2col_3d_cuda(const float * src, T* dst, OH_OW, KD_KH_KW, ID_IH_IW, KH_KW, IH_IW, IC_ID_IH_IW, IC_KD_KH_KW, OW_KD_KH_KW, OD_OH_OW_IC_KD_KH_KW, OH_OW_IC_KD_KH_KW, OW_IC_KD_KH_KW, N_OD_OH, OD_OH, + stride_q, stride_z, stride_y, stride_x, s0, s1, s2, p0, p1, p2, d0, d1, d2); } static void im2col_3d_cuda_f16(const float * src, half * dst, int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC, int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW, + int64_t stride_q, int64_t stride_z, int64_t stride_y, int64_t stride_x, int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) { - im2col_3d_cuda(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, s0, s1, s2, p0, p1, p2, d0, d1, d2, stream); + im2col_3d_cuda(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, + stride_q, stride_z, stride_y, stride_x, + s0, s1, s2, p0, p1, p2, d0, d1, d2, stream); } static void im2col_3d_cuda_f32(const float * src, float * dst, int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC, int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW, + int64_t stride_q, int64_t stride_z, int64_t stride_y, int64_t stride_x, int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) { - im2col_3d_cuda(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, s0, s1, s2, p0, p1, p2, d0, d1, d2, stream); + im2col_3d_cuda(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, + stride_q, stride_z, stride_y, stride_x, + s0, s1, s2, p0, p1, p2, d0, d1, d2, stream); } void ggml_cuda_op_im2col_3d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { @@ -235,9 +246,19 @@ void ggml_cuda_op_im2col_3d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) const int64_t OH = ne2; const int64_t OW = ne1; + const size_t es = ggml_element_size(src1); + const int64_t stride_x = src1->nb[0] / es; + const int64_t stride_y = src1->nb[1] / es; + const int64_t stride_z = src1->nb[2] / es; + const int64_t stride_q = src1->nb[3] / es; + if(dst->type == GGML_TYPE_F16) { - im2col_3d_cuda_f16(src1_d, (half *) dst_d, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, s0, s1, s2, p0, p1, p2, d0, d1, d2, stream); + im2col_3d_cuda_f16(src1_d, (half *) dst_d, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, + stride_q, stride_z, stride_y, stride_x, + s0, s1, s2, p0, p1, p2, d0, d1, d2, stream); } else { - im2col_3d_cuda_f32(src1_d, (float *) dst_d, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, s0, s1, s2, p0, p1, p2, d0, d1, d2, stream); + im2col_3d_cuda_f32(src1_d, (float *) dst_d, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, + stride_q, stride_z, stride_y, stride_x, + s0, s1, s2, p0, p1, p2, d0, d1, d2, stream); } }