]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
CUDA: fix non-cont. inputs for batched mat mul (llama/13155)
authorJohannes Gäßler <redacted>
Tue, 29 Apr 2025 14:00:27 +0000 (16:00 +0200)
committerGeorgi Gerganov <redacted>
Thu, 1 May 2025 10:29:02 +0000 (13:29 +0300)
ggml/src/ggml-cuda/convert.cu
ggml/src/ggml-cuda/convert.cuh
ggml/src/ggml-cuda/ggml-cuda.cu

index a224ec0e12ddb1bdfa89de4da382f677c8f87998..c6dec4276b36d0444147479c578c8d81cfc8bac5 100644 (file)
@@ -1,6 +1,8 @@
 #include "convert.cuh"
 #include "dequantize.cuh"
 
+#include <cstdint>
+
 #define CUDA_Q8_0_NE_ALIGN 2048
 
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
@@ -570,30 +572,46 @@ static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int64_t
 }
 
 template <typename src_t, typename dst_t>
-static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) {
-    const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
+static __global__ void convert_unary(
+        const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, const int64_t ne02,
+        const int64_t s01, const int64_t s02, const int64_t s03) {
+    const int64_t i00 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
 
-    if (i >= k) {
+    if (i00 >= ne00) {
         return;
     }
 
+    const int64_t i01 = blockIdx.y;
+    const int64_t i02 = blockIdx.z % ne02;
+    const int64_t i03 = blockIdx.z / ne02;
+
     const src_t * x = (const src_t *) vx;
 
-    y[i] = float(x[i]);
+    const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00;
+    const int64_t iy = ((i03*ne02 + i02)*ne01 + i01)*ne00 + i00;
+    y[iy] = float(x[ix]);
 }
 
 template <typename src_t, typename dst_t>
-static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) {
-    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
-    convert_unary<src_t><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
+static void convert_unary_cuda(const void * vx, dst_t * y,
+        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+        const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) {
+    const dim3 num_blocks((ne00 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE, ne01, ne02*ne03);
+    convert_unary<src_t><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>
+        (vx, y, ne00, ne01, ne02, s01, s02, s03);
+}
+
+template <typename src_t, typename dst_t>
+static void convert_unary_cont_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+    convert_unary_cuda<src_t>(vx, y, k, 1, 1, 1, k, k, k, stream);
 }
 
 to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type) {
     switch (type) {
         case GGML_TYPE_F32:
-            return convert_unary_cuda<float>;
+            return convert_unary_cont_cuda<float>;
         case GGML_TYPE_F16:
-            return convert_unary_cuda<half>;
+            return convert_unary_cont_cuda<half>;
         default:
             return nullptr;
     }
@@ -643,9 +661,9 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
         case GGML_TYPE_IQ3_S:
             return dequantize_row_iq3_s_cuda;
         case GGML_TYPE_F32:
-            return convert_unary_cuda<float>;
+            return convert_unary_cont_cuda<float>;
         case GGML_TYPE_BF16:
-            return convert_unary_cuda<nv_bfloat16>;
+            return convert_unary_cont_cuda<nv_bfloat16>;
         default:
             return nullptr;
     }
@@ -692,7 +710,18 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
         case GGML_TYPE_IQ3_S:
             return dequantize_row_iq3_s_cuda;
         case GGML_TYPE_F16:
-            return convert_unary_cuda<half>;
+            return convert_unary_cont_cuda<half>;
+        case GGML_TYPE_BF16:
+            return convert_unary_cont_cuda<nv_bfloat16>;
+        default:
+            return nullptr;
+    }
+}
+
+to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_F32:
+            return convert_unary_cuda<float>;
         case GGML_TYPE_BF16:
             return convert_unary_cuda<nv_bfloat16>;
         default:
index 411a13cf12667f8e846c420343c595d4cd15ba98..b65b98e08e7e2e3868761bf1c31dcb9e638afab5 100644 (file)
@@ -3,7 +3,7 @@
 #define CUDA_DEQUANTIZE_BLOCK_SIZE 256
 
 template<typename T>
-using to_t_cuda_t = void (*)(const void * __restrict__ x, T * __restrict__ y, int64_t k, cudaStream_t stream);
+using to_t_cuda_t = void (*)(const void * x, T * y, int64_t k, cudaStream_t stream);
 
 typedef to_t_cuda_t<float> to_fp32_cuda_t;
 typedef to_t_cuda_t<half> to_fp16_cuda_t;
@@ -14,3 +14,13 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type);
 to_bf16_cuda_t ggml_get_to_bf16_cuda(ggml_type type);
 
 to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type);
+
+// TODO more general support for non-contiguous inputs
+
+template<typename T>
+using to_t_nc_cuda_t = void (*)(const void * x, T * y,
+    int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03,
+    int64_t s01, int64_t s02, int64_t s03, cudaStream_t stream);
+
+typedef to_t_nc_cuda_t<half> to_fp16_nc_cuda_t;
+to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type);
index 19b9ce7231aa29086b10077961cbe54baa7f0aca..fba8cb6565bae5c8d4f74e09a2471993b4cddf67 100644 (file)
@@ -1720,15 +1720,15 @@ static __global__ void k_compute_batched_ptrs(
         size_t  nb12, size_t  nb13,
         size_t  nbd2, size_t  nbd3,
         int64_t r2,   int64_t r3) {
-    int64_t i13 = blockIdx.x * blockDim.x + threadIdx.x;
-    int64_t i12 = blockIdx.y * blockDim.y + threadIdx.y;
+    const int64_t i13 = blockIdx.x * blockDim.x + threadIdx.x;
+    const int64_t i12 = blockIdx.y * blockDim.y + threadIdx.y;
 
     if (i13 >= ne13 || i12 >= ne12) {
         return;
     }
 
-    int64_t i03 = i13 / r3;
-    int64_t i02 = i12 / r2;
+    const int64_t i03 = i13 / r3;
+    const int64_t i02 = i12 / r2;
 
     ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03;
     ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12 + i13*nb13;
@@ -1742,6 +1742,10 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
     GGML_ASSERT(ggml_backend_buffer_is_cuda(src0->buffer));
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
 
+    // Byte offsets and tensor dimensions are currently used in an inconsistent way for dst.
+    // As long as dst is contiguous this does not matter though.
+    GGML_ASSERT(ggml_is_contiguous(dst));
+
     GGML_TENSOR_BINARY_OP_LOCALS
 
     const int64_t ne_dst = ggml_nelements(dst);
@@ -1750,21 +1754,31 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
 
     CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(), main_stream));
 
-    void * src0_ddq = src0->data;
-    half * src0_f16 = (half *) src0_ddq;
-    float * src1_ddf = (float *) src1->data;
-    float * dst_ddf  = (float *) dst->data;
+    const half * src0_f16 = (const half *) src0->data;
+    float * dst_ddf = (float *) dst->data;
 
-    // convert src1 to fp16
+    const half * src1_f16 = (const half *) src1->data;
+    const size_t ts_src1 = ggml_type_size(src1->type);
+    GGML_ASSERT(nb10 == ts_src1);
+    int64_t s11 = nb11 / ts_src1;
+    int64_t s12 = nb12 / ts_src1;
+    int64_t s13 = nb13 / ts_src1;
     ggml_cuda_pool_alloc<half> src1_f16_alloc(ctx.pool());
+
+    // convert src1 to fp16
     if (src1->type != GGML_TYPE_F16) {
-        const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
+        const to_fp16_nc_cuda_t to_fp16_cuda = ggml_get_to_fp16_nc_cuda(src1->type);
         const int64_t ne_src1 = ggml_nelements(src1);
         src1_f16_alloc.alloc(ne_src1);
         GGML_ASSERT(to_fp16_cuda != nullptr);
-        to_fp16_cuda(src1_ddf, src1_f16_alloc.get(), ne_src1, main_stream);
+
+        to_fp16_cuda(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, main_stream);
+
+        src1_f16 = src1_f16_alloc.get();
+        s11 = ne10;
+        s12 = ne11*s11;
+        s13 = ne12*s12;
     }
-    half * src1_f16 = src1->type == GGML_TYPE_F16 ? (half *) src1_ddf : src1_f16_alloc.get();
 
     ggml_cuda_pool_alloc<half> dst_f16(ctx.pool());
     char * dst_t;
@@ -1824,13 +1838,13 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
                 int i02 = i12 / r2;
 
                 CUBLAS_CHECK(
-                        cublasGemmEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
-                            ne01, ne11, ne10,
-                            alpha, (const char *) src0_as_f16 + i02*src0->nb[2]   + i03*src0->nb[3]  , CUDA_R_16F,   nb01/sizeof(half),
-                                   (const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F,   nb11/sizeof(float),
-                            beta,  (      char *)       dst_t + i12*nbd2          + i13*nbd3,          cu_data_type, ne01,
-                            cu_compute_type,
-                            CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+                cublasGemmEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
+                    ne01, ne11, ne10,
+                    alpha, (const char *) src0_f16 + i03*nb03 + i02*nb02, CUDA_R_16F,   nb01/sizeof(half),
+                                          src1_f16 + i13*s13  + i12*s12,  CUDA_R_16F,   s11,
+                    beta,  (      char *)    dst_t + i13*nbd3 + i12*nbd2, cu_data_type, ne0,
+                    cu_compute_type,
+                    CUBLAS_GEMM_DEFAULT_TENSOR_OP));
             }
         }
     }
@@ -1841,15 +1855,15 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
         CUBLAS_CHECK(
         cublasGemmStridedBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
-                alpha, (const char *) src0_f16, CUDA_R_16F,   nb01/nb00, nb02/nb00,  // strideA
-                       (const char *) src1_f16, CUDA_R_16F,   nb11/nb10, nb12/nb10,  // strideB
-                beta,  (      char *)    dst_t, cu_data_type, ne01,       nb2/nb0,   // strideC
+                alpha, src0_f16, CUDA_R_16F,   nb01/nb00, nb02/nb00, // strideA
+                       src1_f16, CUDA_R_16F,   s11,       s12,       // strideB
+                beta,     dst_t, cu_data_type, ne0,       ne1*ne0,   // strideC
                 ne12*ne13,
                 cu_compute_type,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
     } else {
         // use cublasGemmBatchedEx
-        const int ne23 = ne12*ne13;
+        const int64_t ne23 = ne12*ne13;
 
         ggml_cuda_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
         ggml_cuda_pool_alloc<      void *> ptrs_dst(ctx.pool(), 1*ne23);
@@ -1861,8 +1875,8 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
                 ne12, ne13,
                 ne23,
                 nb02, nb03,
-                src1->type == GGML_TYPE_F16 ? nb12 : nb12/2,
-                src1->type == GGML_TYPE_F16 ? nb13 : nb13/2,
+                src1->type == GGML_TYPE_F16 ? nb12 : s12*sizeof(half),
+                src1->type == GGML_TYPE_F16 ? nb13 : s13*sizeof(half),
                 nbd2, nbd3,
                 r2, r3);
         CUDA_CHECK(cudaGetLastError());
@@ -1871,8 +1885,8 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
         cublasGemmBatchedEx(ctx.cublas_handle(), CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
                 alpha, (const void **) (ptrs_src.get() + 0*ne23), CUDA_R_16F,   nb01/nb00,
-                       (const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F,   nb11/nb10,
-                beta,  (      void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne01,
+                       (const void **) (ptrs_src.get() + 1*ne23), CUDA_R_16F,   s11,
+                beta,  (      void **) (ptrs_dst.get() + 0*ne23), cu_data_type, ne0,
                 ne23,
                 cu_compute_type,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
@@ -1936,7 +1950,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     } else if (!split && use_mul_mat_vec_q) {
         ggml_cuda_mul_mat_vec_q(ctx, src0, src1, nullptr, dst);
     } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) &&
-            dst->op_params[0] == GGML_PREC_DEFAULT && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
+            !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
         // general KQ + KQV multi-batch without FlashAttention
         ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
     } else if (use_mul_mat_vec) {