]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
cuda : fix copy of large tensors (ggml_nbytes <= INT_MAX assertion) (#18433)
authorMeeMin <redacted>
Thu, 1 Jan 2026 23:24:20 +0000 (04:54 +0530)
committerGitHub <redacted>
Thu, 1 Jan 2026 23:24:20 +0000 (00:24 +0100)
* ggml-cuda: fixed assertion in ggml_cuda_cpy (#18140)

* ggml-cuda: changes in data types to int64_t

* ggml-cuda: added asserts for CUDA block numbers

* ggml-cuda: changed the condition for y and z dimension

ggml/src/ggml-cuda/cpy.cu

index c4ceb4fc5794d26e97a48757e202987f43a0b6ef..ee84303ef0e99df8dc63f556e4ebdb7c7c6e3aa9 100644 (file)
@@ -12,11 +12,11 @@ const int CUDA_CPY_BLOCK_NM = 8;     // block size of 3rd dimension if available
 const int CUDA_CPY_BLOCK_ROWS = 8;   // block dimension for marching through rows
 
 template <cpy_kernel_t cpy_1>
-static __global__ void cpy_scalar(const char * cx, char * cdst, const int ne,
-                                  const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-                                  const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
-                                  const int nb12, const int nb13) {
-    const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
+static __global__ void cpy_scalar(const char * cx, char * cdst, const int64_t ne,
+                                  const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
+                                  const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11,
+                                  const int64_t nb12, const int64_t nb13) {
+    const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= ne) {
         return;
@@ -40,10 +40,10 @@ static __global__ void cpy_scalar(const char * cx, char * cdst, const int ne,
 }
 
 template <typename T>
-static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const int ne,
-                               const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-                               const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
-                               const int nb12, const int nb13) {
+static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const int64_t ne,
+                               const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
+                               const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11,
+                               const int64_t nb12, const int64_t nb13) {
 
     const T* src = reinterpret_cast<const T*>(cx);
     T* dst = reinterpret_cast<T*>(cdst);
@@ -117,60 +117,60 @@ static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) {
 }
 
 template <cpy_kernel_t cpy_blck, int qk>
-static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
-                                 const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-                                 const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
-                                 const int nb12, const int nb13) {
-    const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
+static __global__ void cpy_f32_q(const char * cx, char * cdst, const int64_t ne,
+                                 const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
+                                 const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11,
+                                 const int64_t nb12, const int64_t nb13) {
+    const int64_t i = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*qk;
 
     if (i >= ne) {
         return;
     }
 
-    const int i03 = i/(ne00 * ne01 * ne02);
-    const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
-    const int i01 = (i - i03*ne00*ne01*ne02  -  i02*ne01*ne00) / ne00;
-    const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
-    const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
+    const int64_t i03 = i/(ne00 * ne01 * ne02);
+    const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
+    const int64_t i01 = (i - i03*ne00*ne01*ne02  -  i02*ne01*ne00) / ne00;
+    const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
+    const int64_t x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
 
-    const int i13 = i/(ne10 * ne11 * ne12);
-    const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
-    const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
-    const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
-    const int dst_offset = (i10/qk)*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
+    const int64_t i13 = i/(ne10 * ne11 * ne12);
+    const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
+    const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
+    const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
+    const int64_t dst_offset = (i10/qk)*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
 
     cpy_blck(cx + x_offset, cdst + dst_offset);
 }
 
 template <cpy_kernel_t cpy_blck, int qk>
-static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne,
-                                 const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-                                 const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
-                                 const int nb12, const int nb13) {
-    const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
+static __global__ void cpy_q_f32(const char * cx, char * cdst, const int64_t ne,
+                                 const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
+                                 const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11,
+                                 const int64_t nb12, const int64_t nb13) {
+    const int64_t i = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*qk;
 
     if (i >= ne) {
         return;
     }
 
-    const int i03 = i/(ne00 * ne01 * ne02);
-    const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
-    const int i01 = (i - i03*ne00*ne01*ne02  -  i02*ne01*ne00) / ne00;
-    const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
-    const int x_offset = (i00/qk)*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
+    const int64_t i03 = i/(ne00 * ne01 * ne02);
+    const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
+    const int64_t i01 = (i - i03*ne00*ne01*ne02  -  i02*ne01*ne00) / ne00;
+    const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
+    const int64_t x_offset = (i00/qk)*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
 
-    const int i13 = i/(ne10 * ne11 * ne12);
-    const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
-    const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
-    const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
-    const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
+    const int64_t i13 = i/(ne10 * ne11 * ne12);
+    const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
+    const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
+    const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
+    const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
 
     cpy_blck(cx + x_offset, cdst + dst_offset);
 }
 
 template<typename src_t, typename dst_t>
 static __global__ void cpy_scalar_contiguous(const char * cx, char * cdst, const int64_t ne) {
-    const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
+    const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= ne) {
         return;
@@ -188,19 +188,20 @@ static void ggml_cpy_scalar_contiguous_cuda(
 cudaStream_t stream) {
 
     const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+    GGML_ASSERT(num_blocks < UINT_MAX);
     cpy_scalar_contiguous<src_t, dst_t><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
         (cx, cdst, ne);
 }
 
 template<typename src_t, typename dst_t, bool transposed = false>
 static void ggml_cpy_scalar_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+    const char * cx, char * cdst, const int64_t ne,
+    const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
+    const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
 
     if (transposed) {
         GGML_ASSERT(ne == ne00*ne01*ne02);  // ne[3] is 1 assumed
-        int ne00n, ne01n, ne02n;
+        int64_t ne00n, ne01n, ne02n;
         if (nb00 <= nb02) { // most likely safe to handle nb00 = nb02 case here
             ne00n = ne00;
             ne01n = ne01;
@@ -211,143 +212,159 @@ static void ggml_cpy_scalar_cuda(
             ne02n = 1;
         }
 
-        dim3 dimGrid( (ne01n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D,
-                      (ne00n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D,
-                      (ne/(ne01n*ne00n) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM);
+        int64_t grid_x = (ne01n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D;
+        int64_t grid_y = (ne00n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D;
+        int64_t grid_z = (ne/(ne01n*ne00n) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM;
+        GGML_ASSERT(grid_x < UINT_MAX);
+        GGML_ASSERT(grid_y < USHRT_MAX);
+        GGML_ASSERT(grid_z < USHRT_MAX);
+        dim3 dimGrid(grid_x, grid_y, grid_z);
         dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1);
         cpy_scalar_transpose<dst_t><<<dimGrid, dimBlock, 0, stream>>>
             (cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
     } else {
-        const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+        const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+        GGML_ASSERT(num_blocks < UINT_MAX);
         cpy_scalar<cpy_1_scalar<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
             (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
     }
 }
 
 static void ggml_cpy_f32_q8_0_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+    const char * cx, char * cdst, const int64_t ne,
+    const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
+    const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
 
     GGML_ASSERT(ne % QK8_0 == 0);
-    const int num_blocks = ne / QK8_0;
+    const int64_t num_blocks = ne / QK8_0;
+    GGML_ASSERT(num_blocks < UINT_MAX);
     cpy_f32_q<cpy_blck_f32_q8_0, QK8_0><<<num_blocks, 1, 0, stream>>>
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_q8_0_f32_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+    const char * cx, char * cdst, const int64_t ne,
+    const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
+    const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
 
-    const int num_blocks = ne;
+    const int64_t num_blocks = ne;
+    GGML_ASSERT(num_blocks < UINT_MAX);
     cpy_q_f32<cpy_blck_q8_0_f32, QK8_0><<<num_blocks, 1, 0, stream>>>
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_f32_q4_0_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+    const char * cx, char * cdst, const int64_t ne,
+    const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
+    const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
 
     GGML_ASSERT(ne % QK4_0 == 0);
-    const int num_blocks = ne / QK4_0;
+    const int64_t num_blocks = ne / QK4_0;
+    GGML_ASSERT(num_blocks < UINT_MAX);
     cpy_f32_q<cpy_blck_f32_q4_0, QK4_0><<<num_blocks, 1, 0, stream>>>
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_q4_0_f32_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02,
-    const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12,
-    const int nb10, const int nb11, const int nb12, const int nb13,
+    const char * cx, char * cdst, const int64_t ne,
+    const int64_t ne00, const int64_t ne01, const int64_t ne02,
+    const int64_t nb00, const int64_t nb01, const int64_t nb02,
+    const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12,
+    const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13,
     cudaStream_t stream) {
-    const int num_blocks = ne;
+    const int64_t num_blocks = ne;
+    GGML_ASSERT(num_blocks < UINT_MAX);
     cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0><<<num_blocks, 1, 0, stream>>>(
         cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
          ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_f32_q4_1_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+    const char * cx, char * cdst, const int64_t ne,
+    const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
+    const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
 
     GGML_ASSERT(ne % QK4_1 == 0);
-    const int num_blocks = ne / QK4_1;
+    const int64_t num_blocks = ne / QK4_1;
+    GGML_ASSERT(num_blocks < UINT_MAX);
     cpy_f32_q<cpy_blck_f32_q4_1, QK4_1><<<num_blocks, 1, 0, stream>>>
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_q4_1_f32_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02,
-    const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12,
-    const int nb10, const int nb11, const int nb12, const int nb13,
+    const char * cx, char * cdst, const int64_t ne,
+    const int64_t ne00, const int64_t ne01, const int64_t ne02,
+    const int64_t nb00, const int64_t nb01, const int64_t nb02,
+    const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12,
+    const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13,
     cudaStream_t stream) {
-    const int num_blocks = ne;
+    const int64_t num_blocks = ne;
+    GGML_ASSERT(num_blocks < UINT_MAX);
     cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1><<<num_blocks, 1, 0, stream>>>(
         cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
          ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_f32_q5_0_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+    const char * cx, char * cdst, const int64_t ne,
+    const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
+    const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
 
     GGML_ASSERT(ne % QK5_0 == 0);
-    const int num_blocks = ne / QK5_0;
+    const int64_t num_blocks = ne / QK5_0;
+    GGML_ASSERT(num_blocks < UINT_MAX);
     cpy_f32_q<cpy_blck_f32_q5_0, QK5_0><<<num_blocks, 1, 0, stream>>>
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_q5_0_f32_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02,
-    const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12,
-    const int nb10, const int nb11, const int nb12, const int nb13,
+    const char * cx, char * cdst, const int64_t ne,
+    const int64_t ne00, const int64_t ne01, const int64_t ne02,
+    const int64_t nb00, const int64_t nb01, const int64_t nb02,
+    const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12,
+    const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13,
     cudaStream_t stream) {
-    const int num_blocks = ne;
+    const int64_t num_blocks = ne;
+    GGML_ASSERT(num_blocks < UINT_MAX);
     cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0><<<num_blocks, 1, 0, stream>>>(
         cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
         ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_f32_q5_1_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+    const char * cx, char * cdst, const int64_t ne,
+    const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
+    const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
 
     GGML_ASSERT(ne % QK5_1 == 0);
-    const int num_blocks = ne / QK5_1;
+    const int64_t num_blocks = ne / QK5_1;
+    GGML_ASSERT(num_blocks < UINT_MAX);
     cpy_f32_q<cpy_blck_f32_q5_1, QK5_1><<<num_blocks, 1, 0, stream>>>
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_q5_1_f32_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02,
-    const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12,
-    const int nb10, const int nb11, const int nb12, const int nb13,
+    const char * cx, char * cdst, const int64_t ne,
+    const int64_t ne00, const int64_t ne01, const int64_t ne02,
+    const int64_t nb00, const int64_t nb01, const int64_t nb02,
+    const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12,
+    const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13,
     cudaStream_t stream) {
-    const int num_blocks = ne;
+    const int64_t num_blocks = ne;
+    GGML_ASSERT(num_blocks < UINT_MAX);
     cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1><<<num_blocks, 1, 0, stream>>>(
         cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
         ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_f32_iq4_nl_cuda(
-    const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+    const char * cx, char * cdst, const int64_t ne,
+    const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02,
+    const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) {
 
     GGML_ASSERT(ne % QK4_NL == 0);
-    const int num_blocks = ne / QK4_NL;
+    const int64_t num_blocks = ne / QK4_NL;
+    GGML_ASSERT(num_blocks < UINT_MAX);
     cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL><<<num_blocks, 1, 0, stream>>>
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
@@ -356,9 +373,6 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
     const int64_t ne = ggml_nelements(src0);
     GGML_ASSERT(ne == ggml_nelements(src1));
 
-    GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX);
-    GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX);
-
     const int64_t ne00 = src0->ne[0];
     const int64_t ne01 = src0->ne[1];
     const int64_t ne02 = src0->ne[2];