]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CUDA: fix crash on large batch size for MoE models (#13384)
authorJohannes Gäßler <redacted>
Fri, 9 May 2025 10:14:04 +0000 (12:14 +0200)
committerGitHub <redacted>
Fri, 9 May 2025 10:14:04 +0000 (12:14 +0200)
ggml/src/ggml-cuda/getrows.cu

index ea8bf691609966ced806b0682631e6ad486eedbf..963e4d03dd77b11aedc87479c7da96bf9487cd60 100644 (file)
@@ -10,10 +10,11 @@ static __global__ void k_get_rows(
         /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
         const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
 
-    const int i00 = (blockIdx.x*blockDim.x + threadIdx.x)*2;
-    const int i10 =  blockDim.y*blockIdx.y + threadIdx.y;
-    const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
-    const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
+    // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
+    const int i00 = (blockIdx.y * blockDim.x + threadIdx.x)*2;
+    const int i10 =  blockIdx.x;
+    const int i11 =  blockIdx.z / ne12;
+    const int i12 =  blockIdx.z % ne12;
 
     if (i00 >= ne00) {
         return;
@@ -46,10 +47,11 @@ static __global__ void k_get_rows_float(
         /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
         const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
 
-    const int i00 =  blockIdx.x*blockDim.x + threadIdx.x;
-    const int i10 =  blockDim.y*blockIdx.y + threadIdx.y;
-    const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
-    const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
+    // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
+    const int i00 = blockIdx.y * blockDim.x + threadIdx.x;
+    const int i10 = blockIdx.x;
+    const int i11 = blockIdx.z / ne12;
+    const int i12 = blockIdx.z % ne12;
 
     if (i00 >= ne00) {
         return;
@@ -94,8 +96,8 @@ static void get_rows_cuda_q(
         const size_t nb1, const size_t nb2, const size_t nb3,
         cudaStream_t stream) {
     const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
-    const int block_num_x = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
-    const dim3 block_nums(block_num_x, ne10, ne11*ne12);
+    const int block_num_y = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
+    const dim3 block_nums(ne10, block_num_y, ne11*ne12);
 
     // strides in elements
     // const size_t s0 = nb0 / sizeof(dst_t);
@@ -127,8 +129,8 @@ static void get_rows_cuda_float(
         const size_t nb1, const size_t nb2, const size_t nb3,
         cudaStream_t stream) {
     const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
-    const int block_num_x = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
-    const dim3 block_nums(block_num_x, ne10, ne11*ne12);
+    const int block_num_y = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
+    const dim3 block_nums(ne10, block_num_y, ne11*ne12);
 
     // strides in elements
     // const size_t s0 = nb0 / sizeof(dst_t);