]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
CUDA: fix GET_ROWS for large tensors (llama/15882)
authorJohannes Gäßler <redacted>
Tue, 9 Sep 2025 06:11:01 +0000 (08:11 +0200)
committerGeorgi Gerganov <redacted>
Sat, 20 Sep 2025 10:33:50 +0000 (13:33 +0300)
src/ggml-cuda/getrows.cu
src/ggml-cuda/ggml-cuda.cu

index 83d02474f5d4851cb099ff6cdfa045e4cb4cd54d..2fab33243ddadce3e2b23a20361d43492106272e 100644 (file)
@@ -2,39 +2,39 @@
 #include "dequantize.cuh"
 #include "convert.cuh"
 
-#define MAX_GRIDDIM_Y 65535
-
 template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
 static __global__ void k_get_rows(
         const void * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst,
         const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/
-        /*const int64_t ne10, const int64_t ne11,*/ const int64_t ne12, /*const int64_t ne13,*/
+        /*const int64_t ne10,*/ const int64_t ne11, const int64_t ne12, /*const int64_t ne13,*/
         /*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3,
         /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
         const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
 
-    for (int64_t i00 = 2*(blockIdx.y*blockDim.x + threadIdx.x); i00 < ne00; i00 += gridDim.y*blockDim.x) {
-        // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
-        const int i10 =  blockIdx.x;
-        const int i11 =  blockIdx.z / ne12;
-        const int i12 =  blockIdx.z % ne12;
+    for (int64_t z = blockIdx.z; z < ne11*ne12; z += gridDim.z) {
+        for (int64_t i00 = 2*(blockIdx.y*blockDim.x + threadIdx.x); i00 < ne00; i00 += gridDim.y*blockDim.x) {
+            // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
+            const int i10 =  blockIdx.x;
+            const int i11 =  z / ne12; // TODO fastdiv
+            const int i12 =  z % ne12;
 
-        const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
+            const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
 
-        dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
-        const void * src0_row = (const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03;
+            dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
+            const void * src0_row = (const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03;
 
-        const int ib   =  i00/qk;      // block index
-        const int iqs  = (i00%qk)/qr;  // quant index
-        const int iybs = i00 - i00%qk; // dst block start index
-        const int y_offset = qr == 1 ? 1 : qk/2;
+            const int ib   =  i00/qk;      // block index
+            const int iqs  = (i00%qk)/qr;  // quant index
+            const int iybs = i00 - i00%qk; // dst block start index
+            const int y_offset = qr == 1 ? 1 : qk/2;
 
-        // dequantize
-        float2 v;
-        dequantize_kernel(src0_row, ib, iqs, v);
+            // dequantize
+            float2 v;
+            dequantize_kernel(src0_row, ib, iqs, v);
 
-        dst_row[iybs + iqs + 0]        = ggml_cuda_cast<dst_t>(v.x);
-        dst_row[iybs + iqs + y_offset] = ggml_cuda_cast<dst_t>(v.y);
+            dst_row[iybs + iqs + 0]        = ggml_cuda_cast<dst_t>(v.x);
+            dst_row[iybs + iqs + y_offset] = ggml_cuda_cast<dst_t>(v.y);
+        }
     }
 }
 
@@ -42,27 +42,29 @@ template<typename src0_t, typename dst_t>
 static __global__ void k_get_rows_float(
         const src0_t * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst,
         const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/
-        /*const int64_t ne10, const int64_t ne11,*/ const int64_t ne12, /*const int64_t ne13,*/
+        /*const int64_t ne10,*/ const int64_t ne11, const int64_t ne12, /*const int64_t ne13,*/
         /*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3,
         /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
         const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
 
-    for (int64_t i00 = blockIdx.y*blockDim.x + threadIdx.x; i00 < ne00; i00 += gridDim.y*blockDim.x) {
-        // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
-        const int i10 = blockIdx.x;
-        const int i11 = blockIdx.z / ne12;
-        const int i12 = blockIdx.z % ne12;
+    for (int64_t z = blockIdx.z; z < ne11*ne12; z += gridDim.z) {
+        for (int64_t i00 = blockIdx.y*blockDim.x + threadIdx.x; i00 < ne00; i00 += gridDim.y*blockDim.x) {
+            // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
+            const int i10 = blockIdx.x;
+            const int i11 = z / ne12; // TODO fastdiv
+            const int i12 = z % ne12;
 
-        if (i00 >= ne00) {
-            return;
-        }
+            if (i00 >= ne00) {
+                return;
+            }
 
-        const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
+            const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
 
-        dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
-        const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03);
+            dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
+            const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03);
 
-        dst_row[i00] = ggml_cuda_cast<dst_t>(src0_row[i00]);
+            dst_row[i00] = ggml_cuda_cast<dst_t>(src0_row[i00]);
+        }
     }
 }
 
@@ -98,7 +100,7 @@ static void get_rows_cuda_q(
         cudaStream_t stream) {
     const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
     const int block_num_y = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
-    const dim3 block_nums(ne10, MIN(block_num_y, MAX_GRIDDIM_Y), ne11*ne12);
+    const dim3 block_nums(ne10, MIN(block_num_y, UINT16_MAX), MIN(ne11*ne12, UINT16_MAX));
 
     // strides in elements
     // const size_t s0 = nb0 / sizeof(dst_t);
@@ -116,7 +118,7 @@ static void get_rows_cuda_q(
     k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(
         src0_d, src1_d, dst_d,
         ne00, /*ne01, ne02, ne03,*/
-        /*ne10, ne11,*/ ne12, /*ne13,*/
+        /*ne10,*/ ne11, ne12, /*ne13,*/
         /* s0,*/ s1, s2, s3,
         /* nb00,*/ nb01, nb02, nb03,
         s10, s11, s12/*, s13*/);
@@ -131,7 +133,7 @@ static void get_rows_cuda_float(
         cudaStream_t stream) {
     const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
     const int block_num_y = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
-    const dim3 block_nums(ne10, MIN(block_num_y, MAX_GRIDDIM_Y), ne11*ne12);
+    const dim3 block_nums(ne10, MIN(block_num_y, UINT16_MAX), MIN(ne11*ne12, UINT16_MAX));
 
     // strides in elements
     // const size_t s0 = nb0 / sizeof(dst_t);
@@ -147,7 +149,7 @@ static void get_rows_cuda_float(
     k_get_rows_float<<<block_nums, block_dims, 0, stream>>>(
         src0_d, src1_d, dst_d,
         ne00, /*ne01, ne02, ne03,*/
-        /*ne10, ne11,*/ ne12, /*ne13,*/
+        /*ne10,*/ ne11, ne12, /*ne13,*/
         /* s0,*/ s1, s2, s3,
         /* nb00,*/ nb01, nb02, nb03,
         s10, s11, s12/*, s13*/);
index efca2c775a4006df3a06f27ad6e1e9831644b746..95170ae11bad559a5d9a3fbccbbbfb0468ac1313 100644 (file)
@@ -3393,10 +3393,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
         case GGML_OP_GET_ROWS:
             {
-                // FIXME: https://github.com/ggml-org/llama.cpp/pull/15868
-                if (op->src[1]->ne[1]*op->src[1]->ne[2] > 65535) {
-                    return false;
-                }
                 switch (op->src[0]->type) {
                     case GGML_TYPE_F16:
                     case GGML_TYPE_F32: