]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CUDA: support for mat. mul. with ne03 != ne13 (#11656)
authorJohannes Gäßler <redacted>
Wed, 5 Feb 2025 07:58:31 +0000 (08:58 +0100)
committerGitHub <redacted>
Wed, 5 Feb 2025 07:58:31 +0000 (08:58 +0100)
ggml/src/ggml-cuda/ggml-cuda.cu
ggml/src/ggml-cuda/mmv.cu

index 70a5980998c4ab4b0e87254d34197a30b5529e0a..4dbaefdbafdf483b59f97c8c6098af96a3b3bda9 100644 (file)
@@ -1366,8 +1366,6 @@ static void ggml_cuda_op_mul_mat(
     const int64_t ne13 = src1->ne[3];
     const int64_t nrows1 = ggml_nrows(src1);
 
-    GGML_ASSERT(ne03 == ne13);
-
     const int64_t ne0 = dst->ne[0];
     const int64_t ne1 = dst->ne[1];
 
@@ -1381,9 +1379,11 @@ static void ggml_cuda_op_mul_mat(
 
     GGML_ASSERT(src1->type == GGML_TYPE_F32 || (src1->ne[2] == 1 && src1->ne[3] == 1));
 
-    GGML_ASSERT(ne12 >= ne02 && ne12 % ne02 == 0);
+    GGML_ASSERT(ne12 % ne02 == 0);
+    GGML_ASSERT(ne13 % ne03 == 0);
 
     const int64_t i02_divisor = ne12 / ne02;
+    const int64_t i03_divisor = ne13 / ne03;
 
     const size_t src0_ts = ggml_type_size(src0->type);
     const size_t src0_bs = ggml_blck_size(src0->type);
@@ -1399,6 +1399,7 @@ static void ggml_cuda_op_mul_mat(
     GGML_ASSERT(!(split && ne02 > 1));
     GGML_ASSERT(!(split && ne03 > 1));
     GGML_ASSERT(!(split && ne02 < ne12));
+    GGML_ASSERT(!(split && ne03 < ne13));
 
     ggml_tensor_extra_gpu * src0_extra = split ? (ggml_tensor_extra_gpu *) src0->extra : nullptr;
 
@@ -1562,7 +1563,8 @@ static void ggml_cuda_op_mul_mat(
                 }
 
                 // for split tensors the data begins at i0 == i0_offset_low
-                char  *  src0_dd_i =  dev[id].src0_dd + (i0/i02_divisor) * (ne01*ne00*src0_ts)/src0_bs;
+                const size_t nbytes_src0_matrix = ne01*ne00*src0_ts / src0_bs;
+                char  *  src0_dd_i =  dev[id].src0_dd + ((i03/i03_divisor)*ne02 + (i02/i02_divisor)) * nbytes_src0_matrix;
                 float * src1_ddf_i = dev[id].src1_ddf + (i0*ne11 + src1_col_0) * ne10;
                 char  * src1_ddq_i = dev[id].src1_ddq +  src1_ddq_i_offset;
                 float *   dst_dd_i =   dev[id].dst_dd + (i0*ne1  + src1_col_0) * (dst_on_device ? ne0 : row_diff);
@@ -1606,8 +1608,9 @@ static void ggml_cuda_op_mul_mat(
                     CUDA_CHECK(cudaGetLastError());
                 }
 
-                if (src1_col_0 == 0 && !src0_is_contiguous && i02 % i02_divisor == 0) {
-                    CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_dd_i, src0, i03, i02/i02_divisor, dev[id].row_low, dev[id].row_high, stream));
+                if (src1_col_0 == 0 && !src0_is_contiguous && i03 % i03_divisor == 0 && i02 % i02_divisor == 0) {
+                    CUDA_CHECK(ggml_cuda_cpy_tensor_2d(
+                        src0_dd_i, src0, i03/i03_divisor, i02/i02_divisor, dev[id].row_low, dev[id].row_high, stream));
                 }
 
                 // do the computation
@@ -1882,7 +1885,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
     //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
 
-    if (!split && use_mul_mat_vec && dst->ne[3] == 1 && (src0->ne[1] < MMV_MAX_ROWS || any_gpus_without_fp16_mma)) {
+    if (!split && use_mul_mat_vec && (src0->ne[1] < MMV_MAX_ROWS || any_gpus_without_fp16_mma)) {
         // the custom F16 vector kernel can be used over batched cuBLAS GEMM
         // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
         ggml_cuda_mul_mat_vec(ctx, src0, src1, dst);
@@ -2216,12 +2219,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
             ggml_cuda_op_rms_norm_back(ctx, dst);
             break;
         case GGML_OP_MUL_MAT:
-            if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
-                GGML_LOG_ERROR("%s: cannot compute %s: src0->ne[3] = %" PRId64 ", src1->ne[3] = %" PRId64 " - fallback to CPU\n", __func__, dst->name, dst->src[0]->ne[3], dst->src[1]->ne[3]);
-                return false;
-            } else {
-                ggml_cuda_mul_mat(ctx, dst->src[0], dst->src[1], dst);
-            }
+            ggml_cuda_mul_mat(ctx, dst->src[0], dst->src[1], dst);
             break;
         case GGML_OP_MUL_MAT_ID:
             ggml_cuda_mul_mat_id(ctx, dst);
@@ -2998,9 +2996,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16) {
                     return false;
                 }
-                if (op->op == GGML_OP_MUL_MAT && a->ne[3] != b->ne[3]) {
-                    return false;
-                }
 #ifdef GGML_USE_MUSA
                 if (b->type == GGML_TYPE_F16 && b->ne[2]*b->ne[3] > 1 &&
                     !ggml_is_transposed(a) && !ggml_is_transposed(b)) {
index 5a9ddd9580ad406c9cdc79184c7f3ca5eb57ced2..f89ed03b578d1ba2a7f209065f0b4664b0de3636 100644 (file)
@@ -1,18 +1,21 @@
+#include "ggml.h"
 #include "common.cuh"
 #include "mmv.cuh"
 
 template <typename T, typename type_acc, int block_size>
 static __global__ void mul_mat_vec(
         const T * __restrict__ x, const float * __restrict__ y, float * __restrict__ dst, const int64_t ncols2, const int64_t stride_row,
-        const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst) {
+        const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
+        const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst) {
     const int64_t row       = blockIdx.x;
-    const int64_t channel   = blockIdx.z;
+    const int64_t channel   = blockIdx.y;
+    const int64_t sample    = blockIdx.z;
     const int     tid       = threadIdx.x;
     constexpr int warp_size = ggml_cuda_get_physical_warp_size();
 
-    x   += (channel/channel_ratio)*stride_channel_x + row*stride_row;
-    y   +=  channel               *stride_channel_y;
-    dst +=  channel               *stride_channel_dst;
+    x   +=  (sample/sample_ratio)*stride_sample_x   + (channel/channel_ratio)*stride_channel_x + row*stride_row;
+    y   +=   sample              *stride_sample_y   +  channel               *stride_channel_y;
+    dst +=   sample              *stride_sample_dst +  channel               *stride_channel_dst;
 
     const float2 * y2 = (const float2 *) y;
 
@@ -91,12 +94,15 @@ template <typename T, typename type_acc>
 static void launch_mul_mat_vec_cuda(
         const T * x, const float * y, float * dst,
         const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
-        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
+        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
+        const int64_t nsamples_y, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
         cudaStream_t stream) {
     GGML_ASSERT(ncols      % 2 == 0);
     GGML_ASSERT(stride_row % 2 == 0);
     GGML_ASSERT(nchannels_y % nchannels_x == 0);
+    GGML_ASSERT(nsamples_y  % nsamples_x  == 0);
     const int64_t channel_ratio = nchannels_y / nchannels_x;
+    const int64_t sample_ratio  = nsamples_y  / nsamples_x;
     int device;
     int warp_size;
 
@@ -118,40 +124,48 @@ static void launch_mul_mat_vec_cuda(
     }
 
     const int smem = warp_size*sizeof(float);
-    const dim3 block_nums(nrows, 1, nchannels_y);
+    const dim3 block_nums(nrows, nchannels_y, nsamples_y);
     const dim3 block_dims(block_size_best, 1, 1);
     switch (block_size_best) {
         case   32: {
             mul_mat_vec<T, type_acc,  32><<<block_nums, block_dims, smem, stream>>>
-                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case   64: {
             mul_mat_vec<T, type_acc,  64><<<block_nums, block_dims, smem, stream>>>
-                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case   96: {
             mul_mat_vec<T, type_acc,  96><<<block_nums, block_dims, smem, stream>>>
-                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case  128: {
             mul_mat_vec<T, type_acc, 128><<<block_nums, block_dims, smem, stream>>>
-                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case  160: {
             mul_mat_vec<T, type_acc, 160><<<block_nums, block_dims, smem, stream>>>
-                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case  192: {
             mul_mat_vec<T, type_acc, 192><<<block_nums, block_dims, smem, stream>>>
-                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case  224: {
             mul_mat_vec<T, type_acc, 224><<<block_nums, block_dims, smem, stream>>>
-                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case  256: {
             mul_mat_vec<T, type_acc, 256><<<block_nums, block_dims, smem, stream>>>
-                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
+                (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         default: {
             GGML_ABORT("fatal error");
@@ -163,16 +177,19 @@ template<typename T>
 static void mul_mat_vec_cuda(
         const T * x, const float * y, float * dst,
         const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
-        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
+        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
+        const int64_t nsamples_y, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
         enum ggml_prec prec, cudaStream_t stream) {
     switch (prec) {
         case GGML_PREC_DEFAULT: {
-            launch_mul_mat_vec_cuda<T, half>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
-                stride_channel_x, stride_channel_y, stride_channel_dst, stream);
+            launch_mul_mat_vec_cuda<T, half>
+                (x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
         } break;
         case GGML_PREC_F32: {
-            launch_mul_mat_vec_cuda<T, float>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
-                stride_channel_x, stride_channel_y, stride_channel_dst, stream);
+            launch_mul_mat_vec_cuda<T, float>
+                (x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
         } break;
     }
 }
@@ -181,10 +198,19 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
     GGML_ASSERT(src1->type == GGML_TYPE_F32);
     GGML_ASSERT(dst->type  == GGML_TYPE_F32);
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
+    GGML_TENSOR_BINARY_OP_LOCALS;
+
+    const size_t ts_src0 = ggml_type_size(src0->type);
+    const size_t ts_src1 = ggml_type_size(src1->type);
+    const size_t ts_dst  = ggml_type_size(dst->type);
+
+    GGML_ASSERT(ne11 == 1);
+    GGML_ASSERT(ne12 == ne2);
+    GGML_ASSERT(ne13 == ne3);
 
-    GGML_ASSERT(src1->ne[1] == 1);
+    GGML_ASSERT(nb00 == ts_src0);
+    GGML_ASSERT(nb10 == ts_src1);
+    GGML_ASSERT(nb0  == ts_dst);
 
     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
@@ -192,29 +218,22 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
     const float * src1_d = (const float *) src1->data;
     float       *  dst_d = (float       *)  dst->data;
 
-    const int64_t ne02 = src0->ne[2];
-    const int64_t ne12 = src1->ne[2];
-    GGML_ASSERT(dst->ne[2] == ne12);
-
-    GGML_ASSERT(src0->ne[3] == 1);
-    GGML_ASSERT(src1->ne[3] == 1);
-    GGML_ASSERT( dst->ne[3] == 1);
-
-    const int64_t stride_row         = src0->nb[1] / ggml_type_size(src0->type);
-    const int64_t channel_stride_x   = src0->nb[2] / ggml_type_size(src0->type);
-    const int64_t channel_stride_y   = src1->nb[2] / ggml_type_size(src1->type);
-    const int64_t channel_stride_dst =  dst->nb[2] / ggml_type_size( dst->type);
+    const int64_t s01 = src0->nb[1] / ts_src0;
+    const int64_t s02 = src0->nb[2] / ts_src0;
+    const int64_t s12 = src1->nb[2] / ts_src1;
+    const int64_t s2  =  dst->nb[2] / ts_dst;
+    const int64_t s03 = src0->nb[3] / ts_src0;
+    const int64_t s13 = src1->nb[3] / ts_src1;
+    const int64_t s3  =  dst->nb[3] / ts_dst;
 
     switch (src0->type) {
         case GGML_TYPE_F16: {
             const half * src0_d = (const half *) src0->data;
-            mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12,
-                channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream());
+            mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, s01, ne02, ne12, s02, s12, s2, ne03, ne13, s03, s13, s3, prec, ctx.stream());
         } break;
         case GGML_TYPE_BF16: {
             const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
-            mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12,
-                channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream());
+            mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, s01, ne02, ne12, s02, s12, s2, ne03, ne13, s03, s13, s3, prec, ctx.stream());
         } break;
         default:
             GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
@@ -243,20 +262,27 @@ void ggml_cuda_op_mul_mat_vec(
     const int64_t stride_row         = ne00;
     const int64_t nchannels_x        = 1;
     const int64_t nchannels_y        = 1;
-    const int64_t channel_stride_x   = 0;
-    const int64_t channel_stride_y   = 0;
-    const int64_t channel_stride_dst = 0;
+    const int64_t stride_channel_x   = 0;
+    const int64_t stride_channel_y   = 0;
+    const int64_t stride_channel_dst = 0;
+    const int64_t nsamples_x         = 1;
+    const int64_t nsamples_y         = 1;
+    const int64_t stride_sample_x    = 0;
+    const int64_t stride_sample_y    = 0;
+    const int64_t stride_sample_dst  = 0;
 
     switch (src0->type) {
         case GGML_TYPE_F16: {
             const half * src0_d = (const half *) src0_dd_i;
             mul_mat_vec_cuda(src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
-                nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream);
+                nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
         } break;
         case GGML_TYPE_BF16: {
             const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
             mul_mat_vec_cuda(src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
-                nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream);
+                nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
         } break;
         default:
             GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));