]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CUDA: use mmvq for mul-mat-id for small batch sizes (#18958)
authorAman Gupta <redacted>
Tue, 3 Feb 2026 15:31:23 +0000 (23:31 +0800)
committerGitHub <redacted>
Tue, 3 Feb 2026 15:31:23 +0000 (23:31 +0800)
* CUDA: use mmvq for mul-mat-id for small batch sizes

* add mmvq too

* Fix perf issue on ampere. Use mmvf mm-id only for non-nvidia GPUs

* templatize multi_token_path

ggml/src/ggml-cuda/ggml-cuda.cu
ggml/src/ggml-cuda/mmvf.cu
ggml/src/ggml-cuda/mmvf.cuh
ggml/src/ggml-cuda/mmvq.cu

index 1bcd1ab1f8f0eed73017fb9eca1c195ceaea0191..eeb8625dbebaea1870027a9336d49f742b09f6e2 100644 (file)
@@ -2279,13 +2279,19 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
 
     if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-        if (ne2 == 1) {
+        static_assert(MMVQ_MAX_BATCH_SIZE == MMVF_MAX_BATCH_SIZE);
+        if (ne2 <= MMVQ_MAX_BATCH_SIZE) {
             if (ggml_is_quantized(src0->type)) {
-                ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
+                if (ne2 <= 4) {
+                    ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
+                    return;
+                }
             } else {
-                ggml_cuda_mul_mat_vec_f(ctx, src0, src1, ids, dst);
+                if (GGML_CUDA_CC_IS_AMD(cc)) {
+                    ggml_cuda_mul_mat_vec_f(ctx, src0, src1, ids, dst);
+                    return;
+                }
             }
-            return;
         }
 
         if (ggml_cuda_should_use_mmq(src0->type, cc, ne12, /*n_experts=*/ne02)) {
index 32948e4d7a188ec21d13803cf80fa1cbcfd37042..d91472024296a980aa2bb1399a3a252d0dec26d5 100644 (file)
@@ -4,26 +4,48 @@
 #include "mmvf.cuh"
 #include "convert.cuh"
 
-template <typename T, typename type_acc, int ncols_dst, int block_size, bool has_fusion = false>
+template <typename T, typename type_acc, int ncols_dst, int block_size, bool has_fusion = false, bool is_multi_token_id = false>
 static __global__ void mul_mat_vec_f(
         const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
-        const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
+        const int ncols2, const uint3 nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
         const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
-        const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
+        const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
+        const int ids_stride) {
     const int row         = blockIdx.x;
+    // for MUL_MAT_ID - blockIdx.y = n_expert_used, blockIdx.z = ncols_dst (tokens)
     const int channel_dst = blockIdx.y;
-    const int channel_x   = ids ? ids[channel_dst]          : fastdiv((uint32_t) channel_dst, channel_ratio);
-    const int channel_y   = ids ? channel_dst % nchannels_y : channel_dst;
-    const int sample_dst  = blockIdx.z;
+    const int tid         = threadIdx.x;
+
+    int token_idx;
+    int channel_x;
+    int channel_y;
+    int sample_dst;
+
+    if constexpr (is_multi_token_id) {
+        // Multi-token MUL_MAT_ID path, adding these in the normal path causes a perf regression for n_tokens=1 case
+        token_idx  = blockIdx.z;
+        channel_x  = ids[channel_dst + token_idx * ids_stride];
+        channel_y  = fastmodulo(channel_dst, nchannels_y);
+        sample_dst = 0;
+    } else {
+        token_idx  = ids ? blockIdx.z                                          : 0;
+        channel_x  = ids ? ids[blockIdx.y + token_idx * ids_stride]            : fastdiv((uint32_t) channel_dst, channel_ratio);
+        channel_y  = ids ? fastmodulo(blockIdx.y, nchannels_y)                 : channel_dst;
+        sample_dst = ids ? 0                                                   : blockIdx.z;
+    }
+
     const int sample_x    = fastdiv((uint32_t) sample_dst, sample_ratio);
     const int sample_y    = sample_dst;
-    const int tid         = threadIdx.x;
 
     constexpr int warp_size   = ggml_cuda_get_physical_warp_size();
 
     x   += int64_t(sample_x)  *stride_sample_x   + channel_x  *stride_channel_x   + row*stride_row;
     y   += int64_t(sample_y)  *stride_sample_y   + channel_y  *stride_channel_y;
     dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst;
+    if constexpr (is_multi_token_id) {
+        y   += token_idx*stride_col_y2*2;
+        dst += token_idx*stride_col_dst;
+    }
 
     bool use_gate = false;
     bool use_bias = false;
@@ -56,8 +78,10 @@ static __global__ void mul_mat_vec_f(
     if (use_gate) {
         gate_x += int64_t(sample_x)  *stride_sample_x   + channel_x  *stride_channel_x   + row*stride_row;
     }
+
+    const int channel_bias = ids ? channel_x : channel_dst;
+
     if constexpr (has_fusion) {
-        const int channel_bias = ids ? channel_x : channel_dst;
         if (use_bias) {
             x_bias += int64_t(sample_dst)*stride_sample_dst + channel_bias*stride_channel_dst;
         }
@@ -349,36 +373,36 @@ static __global__ void mul_mat_vec_f(
     }
 }
 
-template<typename T, typename type_acc, int ncols_dst, int block_size>
+template<typename T, typename type_acc, int ncols_dst, int block_size, bool is_multi_token_id = false>
 static void mul_mat_vec_f_switch_fusion(
         const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
-        const int64_t ncols, const int64_t nrows,
+        const int64_t ncols, const uint3 nchannels_y,
         const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
         const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
         const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
-        const dim3 & block_dims, const dim3 & block_nums, const int nbytes_shared, const cudaStream_t stream) {
+        const dim3 & block_dims, const dim3 & block_nums, const int nbytes_shared, const int ids_stride, const cudaStream_t stream) {
 
     const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
     if constexpr (ncols_dst == 1) {
         if (has_fusion) {
-            mul_mat_vec_f<T, type_acc, ncols_dst, block_size, true><<<block_nums, block_dims, nbytes_shared, stream>>>
-                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+            mul_mat_vec_f<T, type_acc, ncols_dst, block_size, true, is_multi_token_id><<<block_nums, block_dims, nbytes_shared, stream>>>
+                (x, y, ids, fusion, dst, ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst,
                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
             return;
        }
     }
 
     GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1");
 
-    mul_mat_vec_f<T, type_acc, ncols_dst, block_size><<<block_nums, block_dims, nbytes_shared, stream>>>
-        (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+    mul_mat_vec_f<T, type_acc, ncols_dst, block_size, false, is_multi_token_id><<<block_nums, block_dims, nbytes_shared, stream>>>
+        (x, y, ids, fusion, dst, ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst,
         channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-        sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
 
 }
 
-template <typename T, typename type_acc, int ncols_dst>
+template <typename T, typename type_acc, int ncols_dst, bool is_multi_token_id = false>
 void launch_mul_mat_vec_f_cuda(
         const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
         const int64_t ncols, const int64_t nrows,
@@ -386,12 +410,13 @@ void launch_mul_mat_vec_f_cuda(
         const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
         const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
         const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
-        cudaStream_t stream) {
+        const int64_t nsamples_or_ntokens, const int64_t ids_stride, cudaStream_t stream) {
     GGML_ASSERT(ncols        % 2 == 0);
     GGML_ASSERT(stride_row   % 2 == 0);
     GGML_ASSERT(stride_col_y % 2 == 0);
     GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
     GGML_ASSERT(       nsamples_dst  % nsamples_x  == 0);
+    const uint3 nchannels_y_fd   = ids ? init_fastdiv_values(nchannels_y) : make_uint3(0, 0, 0);
     const uint3 channel_ratio_fd = ids ? make_uint3(0, 0, 0) : init_fastdiv_values(nchannels_dst / nchannels_x);
     const uint3 sample_ratio_fd  = init_fastdiv_values(nsamples_dst  / nsamples_x);
 
@@ -415,56 +440,56 @@ void launch_mul_mat_vec_f_cuda(
     const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
 
     const int nbytes_shared = warp_size*sizeof(float) + (has_fusion ? warp_size*sizeof(float) : 0);
-    const dim3 block_nums(nrows, nchannels_dst, nsamples_dst);
+    const dim3 block_nums(nrows, nchannels_dst, nsamples_or_ntokens);
     const dim3 block_dims(block_size_best, 1, 1);
     switch (block_size_best) {
         case   32: {
-            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 32>
-                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 32, is_multi_token_id>
+                (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
         } break;
         case   64: {
-            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 64>
-                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 64, is_multi_token_id>
+                (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
         } break;
         case   96: {
-            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 96>
-                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 96, is_multi_token_id>
+                (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
         } break;
         case  128: {
-            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 128>
-                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 128, is_multi_token_id>
+                (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
         } break;
         case  160: {
-            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 160>
-                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 160, is_multi_token_id>
+                (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
         } break;
         case  192: {
-            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 192>
-                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 192, is_multi_token_id>
+                (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
         } break;
         case  224: {
-            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 224>
-                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 224, is_multi_token_id>
+                (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
         } break;
         case  256: {
-            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 256>
-                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 256, is_multi_token_id>
+                (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
+                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
         } break;
         default: {
             GGML_ABORT("fatal error");
@@ -480,55 +505,88 @@ static void mul_mat_vec_f_cuda_switch_ncols_dst(
         const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
         const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
         const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
-        cudaStream_t stream) {
+        const int64_t ids_stride, cudaStream_t stream) {
+
+    const bool has_ids = ids != nullptr;
+
+    if (has_ids && ncols_dst > 1) {
+        // Multi-token MUL_MAT_ID path only - single-token goes through regular path below
+        constexpr int c_ncols_dst = 1;
+        launch_mul_mat_vec_f_cuda<T, type_acc, c_ncols_dst, true>
+            (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+             nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+             stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+             ncols_dst, ids_stride, stream);
+        return;
+    }
+
+    if (has_ids) {
+        // Single-token MUL_MAT_ID path
+        constexpr int c_ncols_dst = 1;
+        launch_mul_mat_vec_f_cuda<T, type_acc, c_ncols_dst>
+            (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+             nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+             stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+             ncols_dst, ids_stride, stream);
+        return;
+    }
+
     switch (ncols_dst) {
         case 1:
             launch_mul_mat_vec_f_cuda<T, type_acc, 1>
                 (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 nsamples_dst, ids_stride, stream);
             break;
         case 2:
             launch_mul_mat_vec_f_cuda<T, type_acc, 2>
                 (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 nsamples_dst, ids_stride, stream);
             break;
         case 3:
             launch_mul_mat_vec_f_cuda<T, type_acc, 3>
                 (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 nsamples_dst, ids_stride, stream);
             break;
         case 4:
             launch_mul_mat_vec_f_cuda<T, type_acc, 4>
                 (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 nsamples_dst, ids_stride, stream);
             break;
         case 5:
             launch_mul_mat_vec_f_cuda<T, type_acc, 5>
                 (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 nsamples_dst, ids_stride, stream);
             break;
         case 6:
             launch_mul_mat_vec_f_cuda<T, type_acc, 6>
                 (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 nsamples_dst, ids_stride, stream);
             break;
         case 7:
             launch_mul_mat_vec_f_cuda<T, type_acc, 7>
                 (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 nsamples_dst, ids_stride, stream);
             break;
         case 8:
             launch_mul_mat_vec_f_cuda<T, type_acc, 8>
                 (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
-                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 nsamples_dst, ids_stride, stream);
             break;
         default:
             GGML_ABORT("fatal error");
@@ -544,21 +602,21 @@ static void mul_mat_vec_f_cuda(
         const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
         const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
         const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
-        enum ggml_prec prec, cudaStream_t stream) {
+        const int64_t ids_stride, enum ggml_prec prec, cudaStream_t stream) {
 
     if constexpr(std::is_same_v<T, half>) {
         if (prec == GGML_PREC_DEFAULT) {
             mul_mat_vec_f_cuda_switch_ncols_dst<T, half>
                 (x, y, ids, fusion, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
-                stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             return;
         }
     }
     mul_mat_vec_f_cuda_switch_ncols_dst<T, float>
         (x, y, ids, fusion, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
         nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
-        stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+        stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
 }
 
 void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst,
@@ -573,7 +631,7 @@ void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor
     const size_t ts_src1 = ggml_type_size(src1->type);
     const size_t ts_dst  = ggml_type_size(dst->type);
 
-    GGML_ASSERT(!ids || ne12 == 1); // Implementation is only correct for  batch size 1.
+    GGML_ASSERT(!ids || ne12 <= MMVF_MAX_BATCH_SIZE);
     GGML_ASSERT(ne13 == ne3);
 
     GGML_ASSERT(        nb00       == ts_src0);
@@ -626,29 +684,31 @@ void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor
     const int64_t ncols_dst          = ids ? ne2  : ne1;
     const int64_t nchannels_y        = ids ? ne11 : ne12;
     const int64_t nchannels_dst      = ids ? ne1  : ne2;
+    const int64_t stride_col_dst     = ids ? s2   : s1;
+    const int64_t stride_col_y       = ids ? s12  : s11;
     const int64_t stride_channel_dst = ids ? s1   : s2;
     const int64_t stride_channel_y   = ids ? s11  : s12;
 
-    GGML_ASSERT(!ids || ncols_dst == 1);
+    const int64_t ids_stride = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0;
 
     switch (src0->type) {
         case GGML_TYPE_F32: {
             const float * src0_d = (const float *) src0->data;
-            mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
+            mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, stride_col_y, stride_col_dst,
                 ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
-                ne03,              ne3,           s03, s13,              s3,                 prec, ctx.stream());
+                ne03,              ne3,           s03, s13,              s3,                 ids_stride, prec, ctx.stream());
         } break;
         case GGML_TYPE_F16: {
             const half * src0_d = (const half *) src0->data;
-            mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
+            mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, stride_col_y, stride_col_dst,
                 ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
-                ne03,              ne3,           s03, s13,              s3,                 prec, ctx.stream());
+                ne03,              ne3,           s03, s13,              s3,                 ids_stride, prec, ctx.stream());
         } break;
         case GGML_TYPE_BF16: {
             const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
-            mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
+            mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, stride_col_y, stride_col_dst,
                 ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
-                ne03,              ne3,           s03, s13,              s3,                 prec, ctx.stream());
+                ne03,              ne3,           s03, s13,              s3,                 ids_stride, prec, ctx.stream());
         } break;
         default:
             GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
@@ -695,19 +755,19 @@ void ggml_cuda_op_mul_mat_vec_f(
             const float * src0_d = (const float *) src0_dd_i;
             mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, 0, prec, stream);
         } break;
         case GGML_TYPE_F16: {
             const half * src0_d = (const half *) src0_dd_i;
             mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, 0, prec, stream);
         } break;
         case GGML_TYPE_BF16: {
             const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
             mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, 0, prec, stream);
         } break;
         default:
             GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
index a09fbdc72022e68fc275db9afe8e29fd5835de3b..a50f7c02180132079f81baa5a63d886b795f52c1 100644 (file)
@@ -1,5 +1,7 @@
 #include "common.cuh"
 
+#define MMVF_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVF kernels.
+
 void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst,
     const ggml_cuda_mm_fusion_args_host * fusion = nullptr);
 
index d671551c17103b9a77c8a6645345912e6d63f28d..ce25ccf427cf0fa897a9ad892b2e0481e505eeb4 100644 (file)
@@ -137,15 +137,15 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int
     return 1;
 }
 
-// tell the compiler to use as many registers as it wants, see nwarps definition below
-template <ggml_type type, int ncols_dst, bool has_fusion>
+template <ggml_type type, int ncols_dst, bool has_fusion, bool is_multi_token_id = false>
 __launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
 static __global__ void mul_mat_vec_q(
         const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
         const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
         const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x,
         const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio,
-        const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst) {
+        const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst,
+        const uint32_t ids_stride) {
 
     constexpr int qk  = ggml_cuda_type_traits<type>::qk;
     constexpr int qi  = ggml_cuda_type_traits<type>::qi;
@@ -162,11 +162,25 @@ static __global__ void mul_mat_vec_q(
     const     int blocks_per_row_x = ncols_x / qk;
     constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;
 
-    // The MUL_MAT_ID code path with ids != nullptr is only implemented for ncols_dst == 1.
     const uint32_t channel_dst = blockIdx.y;
-    const uint32_t channel_x   = ncols_dst == 1 && ids ? ids[channel_dst]                     : fastdiv(channel_dst, channel_ratio);
-    const uint32_t channel_y   = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst;
-    const uint32_t sample_dst  = blockIdx.z;
+
+    uint32_t token_idx = 0;
+    uint32_t channel_x;
+    uint32_t channel_y;
+    uint32_t sample_dst;
+
+    if constexpr (is_multi_token_id) {
+        // Multi-token MUL_MAT_ID path, adding these in the normal path causes a perf regression for n_tokens=1 case
+        token_idx  = blockIdx.z;
+        channel_x  = ids[channel_dst + token_idx * ids_stride];
+        channel_y  = fastmodulo(channel_dst, nchannels_y);
+        sample_dst = 0;
+    } else {
+        channel_x  = ncols_dst == 1 && ids ? ids[channel_dst]                     : fastdiv(channel_dst, channel_ratio);
+        channel_y  = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst;
+        sample_dst = blockIdx.z;
+    }
+
     const uint32_t sample_x    = fastdiv(sample_dst, sample_ratio);
     const uint32_t sample_y    = sample_dst;
 
@@ -188,11 +202,11 @@ static __global__ void mul_mat_vec_q(
         active_glu    = fusion.glu_op;
     }
 
-    const uint32_t channel_bias = ids ? channel_x : channel_dst;
 
     float x_biases[ncols_dst]    = { 0.0f };
     float gate_biases[ncols_dst] = { 0.0f };
     if constexpr (has_fusion) {
+        const uint32_t channel_bias = ids ? channel_x : channel_dst;
         if (use_bias) {
             x_bias = x_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0;
             // 1. Hide latency by prefetching bias and gate here
@@ -222,6 +236,9 @@ static __global__ void mul_mat_vec_q(
     float tmp_gate[ncols_dst][rows_per_cuda_block] = {{0.0f}};
 
     const block_q8_1 * y = ((const block_q8_1 *) vy) + sample_y*stride_sample_y + channel_y*stride_channel_y;
+    if constexpr (is_multi_token_id) {
+        y += token_idx*stride_col_y;
+    }
     const int kbx_offset = sample_x*stride_sample_x + channel_x*stride_channel_x + row0*stride_row_x;
 
     for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) {
@@ -275,6 +292,10 @@ static __global__ void mul_mat_vec_q(
 
     dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst + row0;
 
+    if constexpr (is_multi_token_id) {
+        dst += token_idx*stride_col_dst;
+    }
+
     // sum up partial sums and write back result
 #pragma unroll
     for (int j = 0; j < ncols_dst; ++j) {
@@ -335,40 +356,41 @@ static __global__ void mul_mat_vec_q(
 }
 
 static std::pair<dim3, dim3> calc_launch_params(
-        const int ncols_dst, const int nrows_x, const int nchannels_y, const int nsamples_y,
+        const int ncols_dst, const int nrows_x, const int nchannels_dst, const int nsamples_or_ntokens,
         const int warp_size, const mmvq_parameter_table_id table_id) {
     const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_dst, table_id) - 1) / calc_rows_per_block(ncols_dst, table_id);
-    const dim3 block_nums(nblocks, nchannels_y, nsamples_y);
+    const dim3 block_nums(nblocks, nchannels_dst, nsamples_or_ntokens);
     const dim3 block_dims(warp_size, calc_nwarps(ncols_dst, table_id), 1);
     return {block_nums, block_dims};
 }
 
-template<ggml_type type, int c_ncols_dst>
+template<ggml_type type, int c_ncols_dst, bool is_multi_token_id = false>
 static void mul_mat_vec_q_switch_fusion(
         const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
         const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
         const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x,
         const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio,
         const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst,
-        const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared, cudaStream_t stream) {
+        const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared,
+        const uint32_t ids_stride, cudaStream_t stream) {
 
     const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
     if constexpr (c_ncols_dst == 1) {
         if (has_fusion) {
-            mul_mat_vec_q<type, c_ncols_dst, true><<<block_nums, block_dims, nbytes_shared, stream>>>
+            mul_mat_vec_q<type, c_ncols_dst, true, is_multi_token_id><<<block_nums, block_dims, nbytes_shared, stream>>>
                 (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
                  channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
             return;
         }
     }
 
     GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1");
 
-    mul_mat_vec_q<type, c_ncols_dst, false><<<block_nums, block_dims, nbytes_shared, stream>>>
+    mul_mat_vec_q<type, c_ncols_dst, false, is_multi_token_id><<<block_nums, block_dims, nbytes_shared, stream>>>
         (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
         channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-        sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+        sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
 }
 
 template <ggml_type type>
@@ -379,7 +401,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
         const int nchannels_x, const int nchannels_y, const int nchannels_dst,
         const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
         const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
-        cudaStream_t stream) {
+        const int ids_stride, cudaStream_t stream) {
 
     GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
     GGML_ASSERT(ncols_dst <= MMVQ_MAX_BATCH_SIZE);
@@ -393,8 +415,19 @@ static void mul_mat_vec_q_switch_ncols_dst(
     const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc);
 
     const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
+    const bool has_ids = ids != nullptr;
+
+    if (has_ids && ncols_dst > 1) {
+        // Multi-token MUL_MAT_ID path only - single-token goes through regular path below
+        constexpr int c_ncols_dst = 1;
+        std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, ncols_dst, warp_size, table_id);
+        mul_mat_vec_q_switch_fusion<type, c_ncols_dst, true>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
+             channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+             sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
+             dims.first, dims.second, 0, ids_stride, stream);
+        return;
+    }
 
-    GGML_ASSERT(!ids || ncols_dst == 1);
     switch (ncols_dst) {
         case 1: {
             constexpr int c_ncols_dst = 1;
@@ -402,7 +435,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
             mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                  sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
-                 dims.first, dims.second, 0, stream);
+                 dims.first, dims.second, 0, ids_stride, stream);
         } break;
         case 2: {
             constexpr int c_ncols_dst = 2;
@@ -410,7 +443,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
             mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                  sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
-                 dims.first, dims.second, 0, stream);
+                 dims.first, dims.second, 0, ids_stride, stream);
         } break;
         case 3: {
             constexpr int c_ncols_dst = 3;
@@ -418,7 +451,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
             mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                  sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
-                 dims.first, dims.second, 0, stream);
+                 dims.first, dims.second, 0, ids_stride, stream);
         } break;
         case 4: {
             constexpr int c_ncols_dst = 4;
@@ -426,7 +459,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
             mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                  sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
-                 dims.first, dims.second, 0, stream);
+                 dims.first, dims.second, 0, ids_stride, stream);
         } break;
         case 5: {
             constexpr int c_ncols_dst = 5;
@@ -434,7 +467,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
             mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                  sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
-                 dims.first, dims.second, 0, stream);
+                 dims.first, dims.second, 0, ids_stride, stream);
         } break;
         case 6: {
             constexpr int c_ncols_dst = 6;
@@ -442,7 +475,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
             mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                  sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
-                 dims.first, dims.second, 0, stream);
+                 dims.first, dims.second, 0, ids_stride, stream);
         } break;
         case 7: {
             constexpr int c_ncols_dst = 7;
@@ -450,7 +483,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
             mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                  sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
-                 dims.first, dims.second, 0, stream);
+                 dims.first, dims.second, 0, ids_stride, stream);
         } break;
         case 8: {
             constexpr int c_ncols_dst = 8;
@@ -458,7 +491,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
             mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                  sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
-                 dims.first, dims.second, 0, stream);
+                 dims.first, dims.second, 0, ids_stride, stream);
         } break;
         default:
             GGML_ABORT("fatal error");
@@ -474,127 +507,127 @@ static void mul_mat_vec_q_switch_type(
         const int nchannels_x, const int nchannels_y, const int nchannels_dst,
         const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
         const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
-        cudaStream_t stream) {
+        const int ids_stride, cudaStream_t stream) {
     switch (type_x) {
         case GGML_TYPE_Q4_0:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_0>
                 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             break;
         case GGML_TYPE_Q4_1:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_1>
                 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             break;
         case GGML_TYPE_Q5_0:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_0>
                 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             break;
         case GGML_TYPE_Q5_1:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_1>
                 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             break;
         case GGML_TYPE_Q8_0:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q8_0>
                 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             break;
         case GGML_TYPE_MXFP4:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_MXFP4>
                 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             break;
         case GGML_TYPE_Q2_K:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q2_K>
                 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             break;
         case GGML_TYPE_Q3_K:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q3_K>
                 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             break;
         case GGML_TYPE_Q4_K:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_K>
                 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             break;
         case GGML_TYPE_Q5_K:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_K>
                 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             break;
         case GGML_TYPE_Q6_K:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q6_K>
                 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             break;
         case GGML_TYPE_IQ2_XXS:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_XXS>
                 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             break;
         case GGML_TYPE_IQ2_XS:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_XS>
                 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             break;
         case GGML_TYPE_IQ2_S:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_S>
                 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             break;
         case GGML_TYPE_IQ3_XXS:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ3_XXS>
                 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             break;
         case GGML_TYPE_IQ1_S:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ1_S>
                 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             break;
         case GGML_TYPE_IQ1_M:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ1_M>
                 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             break;
         case GGML_TYPE_IQ4_NL:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ4_NL>
                 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             break;
         case GGML_TYPE_IQ4_XS:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ4_XS>
                 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             break;
         case GGML_TYPE_IQ3_S:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ3_S>
                 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             break;
         default:
             GGML_ABORT("fatal error");
@@ -622,7 +655,7 @@ void ggml_cuda_mul_mat_vec_q(
     GGML_ASSERT(        nb0        == ts_dst);
     GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));
 
-    GGML_ASSERT(!ids || ne12 == 1); // Implementation is only correct for batch size 1.
+    GGML_ASSERT(!ids || ne12 <= MMVQ_MAX_BATCH_SIZE);
 
     const float   * src1_d =       (const float   *) src1->data;
     const int32_t *  ids_d = ids ? (const int32_t *)  ids->data : nullptr;
@@ -693,11 +726,13 @@ void ggml_cuda_mul_mat_vec_q(
     const int64_t stride_channel_dst = ids ? s1   : s2;
     const int64_t stride_channel_y   = ids ? s11  : s12;
 
+    const int64_t ids_stride = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0;
+
     mul_mat_vec_q_switch_type(
         src0->data, src0->type, src1_q8_1.get(), ids_d, fusion_local, dst_d, ne00,
         ne01,              ncols_dst,     s01, stride_col_y,     stride_col_dst,
         ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
-        ne03,              ne3,           s03, s13,              s3,               stream);
+        ne03,              ne3,           s03, s13,              s3,               ids_stride, stream);
 }
 
 void ggml_cuda_op_mul_mat_vec_q(
@@ -726,7 +761,7 @@ void ggml_cuda_op_mul_mat_vec_q(
     ggml_cuda_mm_fusion_args_device fusion_local{};
     mul_mat_vec_q_switch_type(
         src0_dd_i, src0->type, src1_ddq_i, nullptr, fusion_local, dst_dd_i, ne00, row_diff, src1_ncols, stride_row_x, stride_col_y, nrows_dst,
-        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, stream);
+        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, stream);
 
     GGML_UNUSED_VARS(src1, dst, src1_ddf_i, src1_ncols, src1_padded_row_size);
 }