]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
cuda : support non-contiguous i32 to i32 copy (llama/17326)
authorSigbjørn Skjæret <redacted>
Sun, 23 Nov 2025 10:13:34 +0000 (11:13 +0100)
committerGeorgi Gerganov <redacted>
Fri, 12 Dec 2025 15:53:06 +0000 (17:53 +0200)
* support non-contiguous i32 to i32 copy

* add tests

* rename cpy_flt to cpy_scalar and reindent params

ggml/src/ggml-cuda/cpy-utils.cuh
ggml/src/ggml-cuda/cpy.cu
ggml/src/ggml-cuda/ggml-cuda.cu

index e621cb9811ab62a12d32a848be1de66afdb77725..7697c292dd6f839286b5e246c8254aca5837ad51 100644 (file)
@@ -212,6 +212,6 @@ static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
 }
 
 template<typename src_t, typename dst_t>
-static __device__ void cpy_1_flt(const char * cxi, char * cdsti) {
+static __device__ void cpy_1_scalar(const char * cxi, char * cdsti) {
     *(dst_t *) cdsti = ggml_cuda_cast<dst_t>(*(const src_t *) cxi);
 }
index c1afde9627f09b8b7680281439e8a21ddd8412a3..82367ad3fb0d53e5b46711bde0e51a7f74a54591 100644 (file)
@@ -12,10 +12,10 @@ const int CUDA_CPY_BLOCK_NM = 8;     // block size of 3rd dimension if available
 const int CUDA_CPY_BLOCK_ROWS = 8;   // block dimension for marching through rows
 
 template <cpy_kernel_t cpy_1>
-static __global__ void cpy_flt(const char * cx, char * cdst, const int ne,
-                               const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-                               const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
-                               const int nb12, const int nb13) {
+static __global__ void cpy_scalar(const char * cx, char * cdst, const int ne,
+                                  const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+                                  const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
+                                  const int nb12, const int nb13) {
     const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= ne) {
@@ -40,7 +40,7 @@ static __global__ void cpy_flt(const char * cx, char * cdst, const int ne,
 }
 
 template <typename T>
-static __global__ void cpy_flt_transpose(const char * cx, char * cdst, const int ne,
+static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const int ne,
                                const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
                                const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
                                const int nb12, const int nb13) {
@@ -166,7 +166,7 @@ static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne,
 }
 
 template<typename src_t, typename dst_t>
-static __global__ void cpy_flt_contiguous(const char * cx, char * cdst, const int64_t ne) {
+static __global__ void cpy_scalar_contiguous(const char * cx, char * cdst, const int64_t ne) {
     const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= ne) {
@@ -180,17 +180,17 @@ static __global__ void cpy_flt_contiguous(const char * cx, char * cdst, const in
 }
 
 template<typename src_t, typename dst_t>
-static void ggml_cpy_flt_contiguous_cuda(
+static void ggml_cpy_scalar_contiguous_cuda(
     const char * cx, char * cdst, const int64_t ne,
 cudaStream_t stream) {
 
     const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
-    cpy_flt_contiguous<src_t, dst_t><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+    cpy_scalar_contiguous<src_t, dst_t><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
         (cx, cdst, ne);
 }
 
 template<typename src_t, typename dst_t, bool transposed = false>
-static void ggml_cpy_flt_cuda(
+static void ggml_cpy_scalar_cuda(
     const char * cx, char * cdst, const int ne,
     const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
     const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
@@ -212,11 +212,11 @@ static void ggml_cpy_flt_cuda(
                       (ne00n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D,
                       (ne/(ne01n*ne00n) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM);
         dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1);
-        cpy_flt_transpose<dst_t><<<dimGrid, dimBlock, 0, stream>>>
+        cpy_scalar_transpose<dst_t><<<dimGrid, dimBlock, 0, stream>>>
             (cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
     } else {
         const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
-        cpy_flt<cpy_1_flt<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+        cpy_scalar<cpy_1_scalar<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
             (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
     }
 }
@@ -399,94 +399,132 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
         }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
         if (can_be_transposed) {
-            ggml_cpy_flt_cuda<float, float, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+            ggml_cpy_scalar_cuda<float, float, true>
+                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
         } else {
-            ggml_cpy_flt_cuda<float, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+            ggml_cpy_scalar_cuda<float, float>
+                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
         }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) {
         if (contiguous_srcs) {
-            ggml_cpy_flt_contiguous_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, main_stream);
+            ggml_cpy_scalar_contiguous_cuda<float, nv_bfloat16>
+                (src0_ddc, src1_ddc, ne, main_stream);
         } else {
-            ggml_cpy_flt_cuda<float, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+            ggml_cpy_scalar_cuda<float, nv_bfloat16>
+                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
         }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
         if (contiguous_srcs) {
-            ggml_cpy_flt_contiguous_cuda<float, half>        (src0_ddc, src1_ddc, ne, main_stream);
+            ggml_cpy_scalar_contiguous_cuda<float, half>
+                (src0_ddc, src1_ddc, ne, main_stream);
         } else {
-            ggml_cpy_flt_cuda<float, half>        (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+            ggml_cpy_scalar_cuda<float, half>
+                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
         }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
-        ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_f32_q8_0_cuda
+                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_q8_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_q8_0_f32_cuda
+                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
-        ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_f32_q4_0_cuda
+                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_Q4_0 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_q4_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
-            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_q4_0_f32_cuda
+                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
-        ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_f32_q4_1_cuda
+                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_Q4_1 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_q4_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
-            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_q4_1_f32_cuda
+                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
-        ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_f32_q5_0_cuda
+                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_Q5_0 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_q5_0_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02,
-            nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_q5_0_f32_cuda
+                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
-        ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_f32_iq4_nl_cuda
+                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
-        ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_f32_q5_1_cuda
+                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        ggml_cpy_q5_1_f32_cuda
+                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
         if (can_be_transposed) {
-            ggml_cpy_flt_cuda<half, half, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+            ggml_cpy_scalar_cuda<half, half, true>
+                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
         } else {
-            ggml_cpy_flt_cuda<half, half>       (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+            ggml_cpy_scalar_cuda<half, half>
+                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
         }
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) {
         if (contiguous_srcs) {
-            ggml_cpy_flt_contiguous_cuda<half, nv_bfloat16>  (src0_ddc, src1_ddc, ne, main_stream);
+            ggml_cpy_scalar_contiguous_cuda<half, nv_bfloat16>
+                (src0_ddc, src1_ddc, ne, main_stream);
         } else {
-            ggml_cpy_flt_cuda<half, nv_bfloat16>    (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+            ggml_cpy_scalar_cuda<half, nv_bfloat16>
+                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
         }
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
         if (contiguous_srcs) {
-            ggml_cpy_flt_contiguous_cuda<half, float>        (src0_ddc, src1_ddc, ne, main_stream);
+            ggml_cpy_scalar_contiguous_cuda<half, float>
+                (src0_ddc, src1_ddc, ne, main_stream);
         } else {
-            ggml_cpy_flt_cuda<half, float>          (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+            ggml_cpy_scalar_cuda<half, float>
+                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
         }
     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
         if (can_be_transposed) {
-            ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16, true> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+            ggml_cpy_scalar_cuda<nv_bfloat16, nv_bfloat16, true>
+                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
         } else {
-            ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+            ggml_cpy_scalar_cuda<nv_bfloat16, nv_bfloat16>
+                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
         }
     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {
         if (contiguous_srcs) {
-            ggml_cpy_flt_contiguous_cuda<nv_bfloat16, half>  (src0_ddc, src1_ddc, ne, main_stream);
+            ggml_cpy_scalar_contiguous_cuda<nv_bfloat16, half>
+                (src0_ddc, src1_ddc, ne, main_stream);
         } else {
-            ggml_cpy_flt_cuda<nv_bfloat16, half>    (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+            ggml_cpy_scalar_cuda<nv_bfloat16, half>
+                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
         }
     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
         if (contiguous_srcs) {
-            ggml_cpy_flt_contiguous_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, main_stream);
+            ggml_cpy_scalar_contiguous_cuda<nv_bfloat16, float>
+                (src0_ddc, src1_ddc, ne, main_stream);
+        } else {
+            ggml_cpy_scalar_cuda<nv_bfloat16, float>
+                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+        }
+    } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) {
+        if (can_be_transposed) {
+            ggml_cpy_scalar_cuda<int32_t, int32_t, true>
+                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
         } else {
-            ggml_cpy_flt_cuda<nv_bfloat16, float>   (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+            ggml_cpy_scalar_cuda<int32_t, int32_t>
+                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
         }
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) {
         if (contiguous_srcs) {
-            ggml_cpy_flt_contiguous_cuda<float, int32_t>     (src0_ddc, src1_ddc, ne, main_stream);
+            ggml_cpy_scalar_contiguous_cuda<float, int32_t>
+                (src0_ddc, src1_ddc, ne, main_stream);
         } else {
-            ggml_cpy_flt_cuda<float, int32_t>       (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+            ggml_cpy_scalar_cuda<float, int32_t>
+                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
         }
     } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) {
         if (contiguous_srcs) {
-            ggml_cpy_flt_contiguous_cuda<int32_t, float>     (src0_ddc, src1_ddc, ne, main_stream);
+            ggml_cpy_scalar_contiguous_cuda<int32_t, float>
+                (src0_ddc, src1_ddc, ne, main_stream);
         } else {
-            ggml_cpy_flt_cuda<int32_t, float>       (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+            ggml_cpy_scalar_cuda<int32_t, float>
+                (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
         }
     } else {
         GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
index 8fe0899bb5aac95ebecfca8f75fd46aab0d89015..0b29074f33d75fbc3b5962dee6d25380c89e7093 100644 (file)
@@ -4115,6 +4115,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 if (src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_F32) {
                     return true;
                 }
+                if (src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_I32) {
+                    return true;
+                }
                 if (src0_type == src1_type && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1])) {
                     return true;
                 }