CUDA: mul_mat_v support for batch sizes > 1 (#14262)
author    Johannes Gäßler <redacted>
Mon, 23 Jun 2025 11:11:31 +0000 (13:11 +0200)
committer GitHub <redacted>
Mon, 23 Jun 2025 11:11:31 +0000 (13:11 +0200)
* CUDA: mul_mat_v support for batch sizes > 1

* use 64 bit math for initial offset calculation
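
The second bullet refers to the pointer setup at the top of the mul_mat_vec kernel (see the mmv.cu diff below): per-block indices and strides are now 32-bit, so the leading factor of the offset is promoted to int64_t to keep the product from overflowing. A minimal sketch of the pattern, using a hypothetical helper name that is not part of the commit:

    // Hypothetical helper, for illustration only: computing a flat element
    // offset from 32-bit indices without overflowing 32-bit arithmetic.
    __device__ __forceinline__ int64_t flat_offset(
            int sample, int channel, int row,
            int stride_sample, int stride_channel, int stride_row) {
        // Casting the first factor promotes each product to 64-bit before the
        // additions, so offsets past 2^31 elements stay representable.
        return int64_t(sample)*stride_sample
             + int64_t(channel)*stride_channel
             + int64_t(row)*stride_row;
    }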

ggml/src/ggml-cuda/common.cuh
ggml/src/ggml-cuda/ggml-cuda.cu
ggml/src/ggml-cuda/mmv.cu
ggml/src/ggml-cuda/mmv.cuh

diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index 86c4d29a5d254a8832a996e9051c9f3ca31521a4..1369bc2d9e5e3bd04e054501b47718918765781f 100644
@@ -262,6 +262,10 @@ static bool fp16_mma_hardware_available(const int cc) {
         GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc);
 }
 
+static bool bf16_mma_hardware_available(const int cc) {
+    return GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_AMPERE;
+}
+
 // Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
 static bool new_mma_available(const int cc) {
     return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 462db71e1a610a963938f3c1b32057d988779a46..b3e6833c396fdee87b4bb324ebf667228339a116 100644
@@ -1943,16 +1943,14 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
         && ggml_nbytes(src0) != ggml_backend_buffer_get_alloc_size(src0->buffer, src0) && src0->view_src;
 
     bool use_mul_mat_vec   = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16)
-        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
-        && src0->ne[0] % 2 == 0 && src1->ne[1] == 1;
+        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
     bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
         && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
     bool use_mul_mat_q     = ggml_is_quantized(src0->type) && !bad_padding_clear
         && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
 
-    bool any_gpus_with_slow_fp16   = false;
-    bool any_gpus_without_fp16_mma = false;
+    bool any_gpus_with_slow_fp16 = false;
 
     if (split) {
         ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context;
@@ -1963,16 +1961,16 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
                 continue;
             }
 
-            const int cc              = ggml_cuda_info().devices[id].cc;
-            use_mul_mat_q             = use_mul_mat_q             && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
-            any_gpus_with_slow_fp16   = any_gpus_with_slow_fp16   || !fast_fp16_hardware_available(cc);
-            any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc);
+            const int cc            = ggml_cuda_info().devices[id].cc;
+            use_mul_mat_q           = use_mul_mat_q             && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
+            use_mul_mat_vec         = use_mul_mat_vec           && ggml_cuda_should_use_mmv(src0->type, cc, src0->ne, src1->ne[1]);
+            any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16   || !fast_fp16_hardware_available(cc);
         }
     } else {
-        const int cc              = ggml_cuda_info().devices[ctx.device].cc;
-        use_mul_mat_q             = use_mul_mat_q             && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
-        any_gpus_with_slow_fp16   = any_gpus_with_slow_fp16   || !fast_fp16_hardware_available(cc);
-        any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc);
+        const int cc            = ggml_cuda_info().devices[ctx.device].cc;
+        use_mul_mat_q           = use_mul_mat_q             && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
+        use_mul_mat_vec         = use_mul_mat_vec           && ggml_cuda_should_use_mmv(src0->type, cc, src0->ne, src1->ne[1]);
+        any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16   || !fast_fp16_hardware_available(cc);
     }
 
     // debug helpers
@@ -1983,7 +1981,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
     //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
 
-    if (!split && use_mul_mat_vec && (src0->ne[1] <= MMV_MAX_ROWS || any_gpus_without_fp16_mma)) {
+    if (!split && use_mul_mat_vec) {
         // the custom F16 vector kernel can be used over batched cuBLAS GEMM
         // but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
         ggml_cuda_mul_mat_vec(ctx, src0, src1, nullptr, dst);
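
Taken together, the ggml-cuda.cu changes replace the hard-coded src1->ne[1] == 1 and MMV_MAX_ROWS gates with a per-device call to the new ggml_cuda_should_use_mmv heuristic, which also makes the any_gpus_without_fp16_mma flag unnecessary. A simplified restatement of the resulting non-split selection path, using only names from the diff above (other code paths and checks omitted):

    // Simplified restatement, not a verbatim excerpt.
    const int cc = ggml_cuda_info().devices[ctx.device].cc;

    bool use_mul_mat_vec =
           (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16)
        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
        && ggml_cuda_should_use_mmv(src0->type, cc, src0->ne, src1->ne[1]); // per-arch batch-size cap

    if (!split && use_mul_mat_vec) {
        // custom vector kernel instead of batched cuBLAS GEMM
        ggml_cuda_mul_mat_vec(ctx, src0, src1, nullptr, dst);
    }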
diff --git a/ggml/src/ggml-cuda/mmv.cu b/ggml/src/ggml-cuda/mmv.cu
index d8c385e2399aea4bbd6354d89d06f680eda6f58f..1502e9d942fbcc700843d338f93deadd8e8a7d8b 100644
@@ -2,25 +2,26 @@
 #include "common.cuh"
 #include "mmv.cuh"
 
-template <typename T, typename type_acc, int block_size>
+template <typename T, typename type_acc, int ncols_dst, int block_size>
 static __global__ void mul_mat_vec(
         const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
-        const int64_t ncols2, const int64_t nchannels_y, const int64_t stride_row,
-        const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
-        const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst) {
-    const int64_t row         = blockIdx.x;
-    const int64_t channel_dst = blockIdx.y;
-    const int64_t channel_x   = ids ? ids[channel_dst]          : channel_dst / channel_ratio;
-    const int64_t channel_y   = ids ? channel_dst % nchannels_y : channel_dst;
-    const int64_t sample_dst  = blockIdx.z;
-    const int64_t sample_x    = sample_dst / sample_ratio;
-    const int64_t sample_y    = sample_dst;
-    const int     tid         = threadIdx.x;
+        const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
+        const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
+        const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
+    const int row         = blockIdx.x;
+    const int channel_dst = blockIdx.y;
+    const int channel_x   = ids ? ids[channel_dst]          : channel_dst / channel_ratio;
+    const int channel_y   = ids ? channel_dst % nchannels_y : channel_dst;
+    const int sample_dst  = blockIdx.z;
+    const int sample_x    = sample_dst / sample_ratio;
+    const int sample_y    = sample_dst;
+    const int tid         = threadIdx.x;
+
     constexpr int warp_size   = ggml_cuda_get_physical_warp_size();
 
-    x   += sample_x  *stride_sample_x   + channel_x  *stride_channel_x   + row*stride_row;
-    y   += sample_y  *stride_sample_y   + channel_y  *stride_channel_y;
-    dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst;
+    x   += int64_t(sample_x)  *stride_sample_x   + channel_x  *stride_channel_x   + row*stride_row;
+    y   += int64_t(sample_y)  *stride_sample_y   + channel_y  *stride_channel_y;
+    dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst;
 
     const float2 * y2 = (const float2 *) y;
 
@@ -34,81 +35,108 @@ static __global__ void mul_mat_vec(
         __syncthreads();
     }
 
-    float sumf = 0.0f;
+    float sumf[ncols_dst] = {0.0f};
 
     if constexpr (std::is_same<T, float>::value) {
         const float2 * x2 = (const float2 *) x;
 
-        for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
+        for (int col2 = tid; col2 < ncols2; col2 += block_size) {
             const float2 tmpx = x2[col2];
-            const float2 tmpy = y2[col2];
-            sumf += tmpx.x*tmpy.x;
-            sumf += tmpx.y*tmpy.y;
+
+#pragma unroll
+            for (int j = 0; j < ncols_dst; ++j) {
+                const float2 tmpy = y2[j*stride_col_y2 + col2];
+                sumf[j] += tmpx.x*tmpy.x;
+                sumf[j] += tmpx.y*tmpy.y;
+            }
         }
     } else if constexpr (std::is_same<T, half>::value) {
         const half2 * x2 = (const half2 *) x;
 
         if (std::is_same<type_acc, float>::value) {
-            for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
+            for (int col2 = tid; col2 < ncols2; col2 += block_size) {
                 const float2 tmpx = __half22float2(x2[col2]);
-                const float2 tmpy = y2[col2];
-                sumf += tmpx.x * tmpy.x;
-                sumf += tmpx.y * tmpy.y;
+
+#pragma unroll
+                for (int j = 0; j < ncols_dst; ++j) {
+                    const float2 tmpy = y2[j*stride_col_y2 + col2];
+                    sumf[j] += tmpx.x * tmpy.x;
+                    sumf[j] += tmpx.y * tmpy.y;
+                }
             }
         } else {
 #ifdef FP16_AVAILABLE
-            half2 sumh2 = make_half2(0.0f, 0.0f);
+            half2 sumh2[ncols_dst] = {{0.0f, 0.0f}};
+
+            for (int col2 = tid; col2 < ncols2; col2 += block_size) {
+                const half2 tmpx = x2[col2];
 
-            for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
-                const float2 tmp = y2[col2];
-                sumh2 += x2[col2] * make_half2(tmp.x, tmp.y);
+#pragma unroll
+                for (int j = 0; j < ncols_dst; ++j) {
+                    const float2 tmpy = y2[j*stride_col_y2 + col2];
+                    sumh2[j] += tmpx * make_half2(tmpy.x, tmpy.y);
+                }
             }
 
-            sumf = __low2float(sumh2) + __high2float(sumh2);
+#pragma unroll
+            for (int j = 0; j < ncols_dst; ++j) {
+                sumf[j] = __low2float(sumh2[j]) + __high2float(sumh2[j]);
+            }
 #else
             NO_DEVICE_CODE;
 #endif // FP16_AVAILABLE
         }
     } else if constexpr (std::is_same<T, nv_bfloat16>::value) {
         const int * x2 = (const int *) x;
-        for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
-            const int    tmpx = x2[col2];
-            const float2 tmpy = y2[col2];
-            sumf += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]) * tmpy.x;
-            sumf += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]) * tmpy.y;
+        for (int col2 = tid; col2 < ncols2; col2 += block_size) {
+            const int tmpx = x2[col2];
+#pragma unroll
+            for (int j = 0; j < ncols_dst; ++j) {
+                const float2 tmpy = y2[j*stride_col_y2 + col2];
+                sumf[j] += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]) * tmpy.x;
+                sumf[j] += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]) * tmpy.y;
+            }
         }
     } else {
         static_assert(std::is_same<T, void>::value, "unsupported type");
     }
 
-    sumf = warp_reduce_sum<warp_size>(sumf);
+#pragma unroll
+    for (int j = 0; j < ncols_dst; ++j) {
+        sumf[j] = warp_reduce_sum<warp_size>(sumf[j]);
 
-    if (block_size > warp_size) {
-        buf_iw[tid/warp_size] = sumf;
-        __syncthreads();
-        if (tid >= warp_size) {
-            return;
+        if (block_size > warp_size) {
+            buf_iw[tid/warp_size] = sumf[j];
+            __syncthreads();
+            if (tid < warp_size) {
+                sumf[j] = buf_iw[tid];
+                sumf[j] = warp_reduce_sum<warp_size>(sumf[j]);
+            }
+            if (j < ncols_dst) {
+                __syncthreads();
+            }
         }
-        sumf = buf_iw[tid];
-        sumf = warp_reduce_sum<warp_size>(sumf);
     }
 
-    if (tid != 0) {
+    if (tid >= ncols_dst) {
         return;
     }
 
-    dst[row] = sumf;
+    dst[tid*stride_col_dst + row] = sumf[tid];
 }
 
-template <typename T, typename type_acc>
+template <typename T, typename type_acc, int ncols_dst>
 static void launch_mul_mat_vec_cuda(
         const T * x, const float * y, const int32_t * ids, float * dst,
-        const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
+        const int64_t ncols, const int64_t nrows,
+        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
+        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
         const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
         const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
         cudaStream_t stream) {
-    GGML_ASSERT(ncols      % 2 == 0);
-    GGML_ASSERT(stride_row % 2 == 0);
+    GGML_ASSERT(ncols        % 2 == 0);
+    GGML_ASSERT(stride_row   % 2 == 0);
+    GGML_ASSERT(stride_col_y % 2 == 0);
     GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
     GGML_ASSERT(       nsamples_dst  % nsamples_x  == 0);
     const int64_t channel_ratio = nchannels_dst / nchannels_x;
@@ -138,44 +166,52 @@ static void launch_mul_mat_vec_cuda(
     const dim3 block_dims(block_size_best, 1, 1);
     switch (block_size_best) {
         case   32: {
-            mul_mat_vec<T, type_acc,  32><<<block_nums, block_dims, smem, stream>>>
-                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+            mul_mat_vec<T, type_acc, ncols_dst,  32><<<block_nums, block_dims, smem, stream>>>
+                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case   64: {
-            mul_mat_vec<T, type_acc,  64><<<block_nums, block_dims, smem, stream>>>
-                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+            mul_mat_vec<T, type_acc, ncols_dst,  64><<<block_nums, block_dims, smem, stream>>>
+                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case   96: {
-            mul_mat_vec<T, type_acc,  96><<<block_nums, block_dims, smem, stream>>>
-                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+            mul_mat_vec<T, type_acc, ncols_dst,  96><<<block_nums, block_dims, smem, stream>>>
+                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case  128: {
-            mul_mat_vec<T, type_acc, 128><<<block_nums, block_dims, smem, stream>>>
-                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+            mul_mat_vec<T, type_acc, ncols_dst, 128><<<block_nums, block_dims, smem, stream>>>
+                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case  160: {
-            mul_mat_vec<T, type_acc, 160><<<block_nums, block_dims, smem, stream>>>
-                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+            mul_mat_vec<T, type_acc, ncols_dst, 160><<<block_nums, block_dims, smem, stream>>>
+                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case  192: {
-            mul_mat_vec<T, type_acc, 192><<<block_nums, block_dims, smem, stream>>>
-                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+            mul_mat_vec<T, type_acc, ncols_dst, 192><<<block_nums, block_dims, smem, stream>>>
+                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case  224: {
-            mul_mat_vec<T, type_acc, 224><<<block_nums, block_dims, smem, stream>>>
-                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+            mul_mat_vec<T, type_acc, ncols_dst, 224><<<block_nums, block_dims, smem, stream>>>
+                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         case  256: {
-            mul_mat_vec<T, type_acc, 256><<<block_nums, block_dims, smem, stream>>>
-                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+            mul_mat_vec<T, type_acc, ncols_dst, 256><<<block_nums, block_dims, smem, stream>>>
+                (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
         } break;
         default: {
             GGML_ABORT("fatal error");
@@ -183,23 +219,91 @@ static void launch_mul_mat_vec_cuda(
     }
 }
 
+template <typename T, typename type_acc>
+static void mul_mat_vec_cuda_switch_ncols_dst(
+        const T * x, const float * y, const int32_t * ids, float * dst,
+        const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
+        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
+        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
+        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
+        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
+        cudaStream_t stream) {
+    switch (ncols_dst) {
+        case 1:
+            launch_mul_mat_vec_cuda<T, type_acc, 1>
+                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+            break;
+        case 2:
+            launch_mul_mat_vec_cuda<T, type_acc, 2>
+                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+            break;
+        case 3:
+            launch_mul_mat_vec_cuda<T, type_acc, 3>
+                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+            break;
+        case 4:
+            launch_mul_mat_vec_cuda<T, type_acc, 4>
+                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+            break;
+        case 5:
+            launch_mul_mat_vec_cuda<T, type_acc, 5>
+                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+            break;
+        case 6:
+            launch_mul_mat_vec_cuda<T, type_acc, 6>
+                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+            break;
+        case 7:
+            launch_mul_mat_vec_cuda<T, type_acc, 7>
+                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+            break;
+        case 8:
+            launch_mul_mat_vec_cuda<T, type_acc, 8>
+                (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+            break;
+        default:
+            GGML_ABORT("fatal error");
+            break;
+    }
+}
+
 template<typename T>
 static void mul_mat_vec_cuda(
         const T * x, const float * y, const int32_t * ids, float * dst,
-        const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
+        const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
+        const int64_t stride_row, const int64_t stride_col_y, const int stride_col_dst,
+        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
         const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
         const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
         enum ggml_prec prec, cudaStream_t stream) {
     if constexpr(std::is_same<T, half>::value) {
         if (prec == GGML_PREC_DEFAULT) {
-            launch_mul_mat_vec_cuda<T, half>
-                (x, y, ids, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+            mul_mat_vec_cuda_switch_ncols_dst<T, half>
+                (x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                  stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
             return;
         }
     }
-    launch_mul_mat_vec_cuda<T, float>
-        (x, y, ids, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+    mul_mat_vec_cuda_switch_ncols_dst<T, float>
+        (x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+         nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
          stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
 }
 
@@ -246,24 +350,24 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
     const int64_t stride_channel_dst = ids ? s1   : s2;
     const int64_t stride_channel_y   = ids ? s11  : s12;
 
-    GGML_ASSERT(ncols_dst == 1);
+    GGML_ASSERT(!ids || ncols_dst == 1);
 
     switch (src0->type) {
         case GGML_TYPE_F32: {
             const float * src0_d = (const float *) src0->data;
-            mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, s01,
+            mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
                 ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
                 ne03,              ne3,           s03, s13,              s3,                 prec, ctx.stream());
         } break;
         case GGML_TYPE_F16: {
             const half * src0_d = (const half *) src0->data;
-            mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, s01,
+            mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
                 ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
                 ne03,              ne3,           s03, s13,              s3,                 prec, ctx.stream());
         } break;
         case GGML_TYPE_BF16: {
             const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
-            mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, s01,
+            mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
                 ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
                 ne03,              ne3,           s03, s13,              s3,                 prec, ctx.stream());
         } break;
@@ -282,16 +386,19 @@ void ggml_cuda_op_mul_mat_vec(
     GGML_ASSERT(dst->type  == GGML_TYPE_F32);
 
     const int64_t ne00 = src0->ne[0];
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne0  =  dst->ne[0];
     const int64_t row_diff = row_high - row_low;
 
-    GGML_ASSERT(src1_ncols == 1);
-
-    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+    const int id = ggml_cuda_get_device();
+    const int cc = ggml_cuda_info().devices[id].cc;
     const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
 
 
     // ggml_cuda_op provides single, contiguous matrices
     const int64_t stride_row         = ne00;
+    const int64_t stride_col_y       = ne10;
+    const int64_t stride_col_dst     = id == ctx.device ? ne0 : row_diff; // main device has larger memory buffer
     const int64_t nchannels_x        = 1;
     const int64_t nchannels_y        = 1;
     const int64_t nchannels_dst      = 1;
@@ -307,19 +414,19 @@ void ggml_cuda_op_mul_mat_vec(
     switch (src0->type) {
         case GGML_TYPE_F32: {
             const float * src0_d = (const float *) src0_dd_i;
-            mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, stride_row,
+            mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
         } break;
         case GGML_TYPE_F16: {
             const half * src0_d = (const half *) src0_dd_i;
-            mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, stride_row,
+            mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
         } break;
         case GGML_TYPE_BF16: {
             const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
-            mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, stride_row,
+            mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
         } break;
@@ -334,3 +441,48 @@ void ggml_cuda_op_mul_mat_vec(
     GGML_UNUSED(src1_ncols);
     GGML_UNUSED(src1_padded_row_size);
 }
+
+bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11) {
+    if (src0_ne[0] % 2 != 0) {
+        return false;
+    }
+    switch (type) {
+        case GGML_TYPE_F32:
+            if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
+                if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
+                    return ne11 <= 8;
+                }
+                if (cc >= GGML_CUDA_CC_TURING) {
+                    return ne11 <= 4;
+                }
+                return ne11 <= 3;
+            }
+            return ne11 <= 8;
+        case GGML_TYPE_F16:
+            if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
+                const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1);
+                if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
+                    return src0_small && ne11 <= 4;
+                }
+                if (fp16_mma_hardware_available(cc)) {
+                    return src0_small && ne11 <= 3;
+                }
+                return ne11 <= 8;
+            }
+            return ne11 <= 8;
+        case GGML_TYPE_BF16:
+            if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
+                const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1);
+                if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
+                    return src0_small && ne11 <= 4;
+                }
+                if (bf16_mma_hardware_available(cc)) {
+                    return src0_small && ne11 <= 3;
+                }
+                return ne11 <= 8;
+            }
+            return ne11 <= 8;
+        default:
+            return false;
+    }
+}
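
The new heuristic caps the supported batch size (ne11, the number of src1 columns) per architecture, and for F16/BF16 additionally requires a thin src0 when MMA hardware is available, since batched cuBLAS with tensor cores wins otherwise. A usage sketch with made-up shapes (a 4096x4096 F16 weight and three src1 columns), not taken from the repository:

    // Usage sketch with made-up shapes.
    const int64_t src0_ne[4] = {4096, 4096, 1, 1};        // weight matrix ne[0..3]
    const int     cc         = ggml_cuda_info().devices[0].cc;

    // On Ampere, fp16 MMA is available, so the thin-src0 condition applies:
    // ne[1] > 512 but ne[2]*ne[3] == 1, hence src0 counts as small and the
    // batch limit is 3 -> the call returns true for ne11 == 3.
    const bool use_mmv = ggml_cuda_should_use_mmv(GGML_TYPE_F16, cc, src0_ne, /*ne11=*/3);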
diff --git a/ggml/src/ggml-cuda/mmv.cuh b/ggml/src/ggml-cuda/mmv.cuh
index 756e7e1cc7fc36cd5874aeb5719157f5dd50c851..1330bcb6a88602b18ad8641b26cf217349e2f33c 100644
@@ -1,8 +1,5 @@
 #include "common.cuh"
 
-// maximum number of src0 rows with which to use mul_mat_vec over cuBLAS if FP16 tensor cores are available
-#define MMV_MAX_ROWS 512
-
 void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
 
 void ggml_cuda_op_mul_mat_vec(
@@ -10,3 +7,5 @@ void ggml_cuda_op_mul_mat_vec(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
     const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
     const int64_t src1_padded_row_size, cudaStream_t stream);
+
+bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11);