]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
`ggml_cuda_cpy` support for 4d tensors and float16->float32 upcasting (ggml/686)
authorJohn Balis <redacted>
Mon, 29 Jan 2024 12:37:33 +0000 (06:37 -0600)
committerGeorgi Gerganov <redacted>
Tue, 30 Jan 2024 19:27:59 +0000 (21:27 +0200)
* added cuda float16->float32 upcasting to ggml_cuda_cpy

* added ability to copy 4d tensors with the cuda backend

* added tests for float16_>float32 upcast and 4d tensor cuda copys

* added 4d copy test for float32->float16 copy

* applied patch suggested by @iamlemec

* simplify cpy tests

---------

Co-authored-by: slaren <redacted>
ggml-cuda.cu

index 7695b86b20fb9a85fb9bf5744368d1fe68901b40..8ff9fbd564775e5aac224291944f996dcfaae28f 100644 (file)
@@ -5357,27 +5357,37 @@ static __device__ void cpy_1_f16_f16(const char * cxi, char * cdsti) {
     *dsti = *xi;
 }
 
+static __device__ void cpy_1_f16_f32(const char * cxi, char * cdsti) {
+    const half * xi = (const half *) cxi;
+    float * dsti = (float *) cdsti;
+
+    *dsti = *xi;
+}
+
 template <cpy_kernel_t cpy_1>
 static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
-                                   const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
-                                   const int ne10, const int ne11, const int nb10, const int nb11, const int nb12) {
+                                   const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+                                   const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
+                                   const int nb12, const int nb13) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
     if (i >= ne) {
         return;
     }
 
-    // determine indices i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
+    // determine indices i03/i13, i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
     // then combine those indices with the corresponding byte offsets to get the total offsets
-    const int i02 = i / (ne00*ne01);
-    const int i01 = (i - i02*ne01*ne00) / ne00;
-    const int i00 = i - i02*ne01*ne00 - i01*ne00;
-    const int x_offset = i00*nb00 + i01*nb01 + i02*nb02;
-
-    const int i12 = i / (ne10*ne11);
-    const int i11 = (i - i12*ne10*ne11) / ne10;
-    const int i10 = i - i12*ne10*ne11 - i11*ne10;
-    const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12;
+    const int i03 = i/(ne00 * ne01 * ne02);
+    const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
+    const int i01 = (i - i03*ne00*ne01*ne02  -  i02*ne01*ne00) / ne00;
+    const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
+    const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
+
+    const int i13 = i/(ne10 * ne11 * ne12);
+    const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
+    const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
+    const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
+    const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13 * nb13;
 
     cpy_1(cx + x_offset, cdst + dst_offset);
 }
@@ -5471,23 +5481,26 @@ static __device__ void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
 
 template <cpy_kernel_t cpy_blck, int qk>
 static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
-                                 const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
-                                 const int ne10, const int ne11, const int nb10, const int nb11, const int nb12) {
+                                 const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+                                 const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
+                                 const int nb12, const int nb13) {
     const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
 
     if (i >= ne) {
         return;
     }
 
-    const int i02 = i / (ne00*ne01);
-    const int i01 = (i - i02*ne01*ne00) / ne00;
-    const int i00 = (i - i02*ne01*ne00 - i01*ne00);
-    const int x_offset = i00*nb00 + i01*nb01 + i02*nb02;
+    const int i03 = i/(ne00 * ne01 * ne02);
+    const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
+    const int i01 = (i - i03*ne00*ne01*ne02  -  i02*ne01*ne00) / ne00;
+    const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
+    const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
 
-    const int i12 = i / (ne10*ne11);
-    const int i11 = (i - i12*ne10*ne11) / ne10;
-    const int i10 = (i - i12*ne10*ne11 - i11*ne10)/qk;
-    const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12;
+    const int i13 = i/(ne10 * ne11 * ne12);
+    const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
+    const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
+    const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
+    const int dst_offset = (i10/qk)*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
 
     cpy_blck(cx + x_offset, cdst + dst_offset);
 }
@@ -7135,69 +7148,82 @@ static void ggml_mul_mat_vec_nc_f16_f32_cuda(
         (vx, y, dst, ncols_x, nrows_x, row_stride_x, channel_stride_x, nchannels_y/nchannels_x);
 }
 
+
+static void ggml_cpy_f16_f32_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
+
+    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+    cpy_f32_f16<cpy_1_f16_f32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
+}
+
 static void ggml_cpy_f32_f32_cuda(
     const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
-    const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
 
     const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
     cpy_f32_f16<cpy_1_f32_f32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_f32_f16_cuda(
     const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
-    const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
 
     const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
     cpy_f32_f16<cpy_1_f32_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_f32_q8_0_cuda(
     const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
-    const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
 
     GGML_ASSERT(ne % QK8_0 == 0);
     const int num_blocks = ne / QK8_0;
     cpy_f32_q<cpy_blck_f32_q8_0, QK8_0><<<num_blocks, 1, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_f32_q4_0_cuda(
     const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
-    const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
 
     GGML_ASSERT(ne % QK4_0 == 0);
     const int num_blocks = ne / QK4_0;
     cpy_f32_q<cpy_blck_f32_q4_0, QK4_0><<<num_blocks, 1, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_f32_q4_1_cuda(
     const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
-    const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
 
     GGML_ASSERT(ne % QK4_1 == 0);
     const int num_blocks = ne / QK4_1;
     cpy_f32_q<cpy_blck_f32_q4_1, QK4_1><<<num_blocks, 1, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
 static void ggml_cpy_f16_f16_cuda(
     const char * cx, char * cdst, const int ne,
-    const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
-    const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+    const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
 
     const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
     cpy_f32_f16<cpy_1_f16_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
-        (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
 }
 
+
+
 static void scale_f32_cuda(const float * x, float * dst, const float scale, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
     scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
@@ -9941,19 +9967,25 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, gg
 
     const int64_t ne00 = src0->ne[0];
     const int64_t ne01 = src0->ne[1];
-    GGML_ASSERT(src0->ne[3] == 1);
+    const int64_t ne02 = src0->ne[2];
+    
+    //GGML_ASSERT(src0->ne[3] == 1);
 
     const int64_t nb00 = src0->nb[0];
     const int64_t nb01 = src0->nb[1];
     const int64_t nb02 = src0->nb[2];
+    const int64_t nb03 = src0->nb[3];
 
     const int64_t ne10 = src1->ne[0];
     const int64_t ne11 = src1->ne[1];
-    GGML_ASSERT(src1->ne[3] == 1);
+    const int64_t ne12 = src1->ne[2];
+
+    //GGML_ASSERT(src1->ne[3] == 1);
 
     const int64_t nb10 = src1->nb[0];
     const int64_t nb11 = src1->nb[1];
     const int64_t nb12 = src1->nb[2];
+    const int64_t nb13 = src1->nb[3];
 
     ggml_cuda_set_device(g_main_device);
     cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
@@ -9965,17 +9997,19 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, gg
     char * src1_ddc = (char *) src1_extra->data_device[g_main_device];
 
     if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
+        ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
+        ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
-        ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
+        ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
-        ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
+        ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
-        ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
+        ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12, main_stream);
+        ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
     } else {
         fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
                 ggml_type_name(src0->type), ggml_type_name(src1->type));
@@ -10978,6 +11012,9 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
                 if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
                     return true;
                 }
+                if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
+                    return true;
+                }
                 return false;
             } break;
         case GGML_OP_DUP: