]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
Optimize MOE GEMV kernel for BS > 1. (llama/20905)
authorGaurav Garg <redacted>
Sun, 29 Mar 2026 16:35:18 +0000 (22:05 +0530)
committerGeorgi Gerganov <redacted>
Wed, 1 Apr 2026 13:00:26 +0000 (16:00 +0300)
* Optimize MOE GEMV kernel for BS > 1.

The previous MOE kernel for BS > 1 had too many thread blocks (nrows_x, nchannels_dst, ncols_dst), with very little work per block. block of (32, 4) was doing inner dot product for a single row.

New mul_mat_vec_q_moe kernel is dedicated for MoE multi-token kernel with grid (ceil(nrows_x/rpb), nchannels_dst), block (warp_size, ncols_dst). Each warp handles two rows independently with warp-level reduction only (no shared memory sync).

This change doesn't increase any compilation time as a single template instance is needed per type. This also simplifies the original GEMV kernel and gets rid of `is_multi_token_id` specialization.

* Remove em-dashes

* Cherry-pick changes from @am17an PR https://github.com/ggml-org/llama.cpp/pull/20885 to enable small_k optimization only for cases where it benefits

Increase max batch size for MMVQ kernels for MUL_MAT_ID to 8

* Make the max batch size for MOE GEMV kernel configurable based on GPU arch and datatype

---------

Co-authored-by: Aman Gupta <redacted>
src/ggml-cuda/ggml-cuda.cu
src/ggml-cuda/mmvq.cu
src/ggml-cuda/mmvq.cuh

index cc80eb3ffc2a4c1830bcf1e70227f85bcf6972ae..d1239b1c5f72060b454d4221c1fccf73a5900cf8 100644 (file)
@@ -2343,7 +2343,8 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
         static_assert(MMVQ_MAX_BATCH_SIZE == MMVF_MAX_BATCH_SIZE);
         if (ne2 <= MMVQ_MAX_BATCH_SIZE) {
             if (ggml_is_quantized(src0->type)) {
-                if (ne2 <= MMVQ_MMID_MAX_BATCH_SIZE) {
+                const int mmvq_mmid_max = get_mmvq_mmid_max_batch(src0->type, cc);
+                if (ne2 <= mmvq_mmid_max) {
                     ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
                     return;
                 }
@@ -2946,14 +2947,18 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
         }
 
         // [TAG_MUL_MAT_ID_CUDA_GRAPHS]
-        if (node->op == GGML_OP_MUL_MAT_ID && (!ggml_is_quantized(node->src[0]->type) || node->ne[2] > MMVQ_MMID_MAX_BATCH_SIZE)) {
-            // under these conditions, the mul_mat_id operation will need to synchronize the stream, so we cannot use CUDA graphs
-            // TODO: figure out a way to enable for larger batch sizes, without hurting performance
-            // ref: https://github.com/ggml-org/llama.cpp/pull/18958
-            use_cuda_graph = false;
+        if (node->op == GGML_OP_MUL_MAT_ID) {
+            const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+            const int mmvq_mmid_max = get_mmvq_mmid_max_batch(node->src[0]->type, cc);
+            if (!ggml_is_quantized(node->src[0]->type) || node->ne[2] > mmvq_mmid_max) {
+                // under these conditions, the mul_mat_id operation will need to synchronize the stream, so we cannot use CUDA graphs
+                // TODO: figure out a way to enable for larger batch sizes, without hurting performance
+                // ref: https://github.com/ggml-org/llama.cpp/pull/18958
+                use_cuda_graph = false;
 #ifndef NDEBUG
-            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__);
+                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__);
 #endif
+            }
         }
 
         if (!use_cuda_graph) {
index 66bd8beeae7b2a702d1799bc3d918ef3ea3bbab7..8d80d1dd9a78413afd9c9d510f095e874100036b 100644 (file)
@@ -97,6 +97,194 @@ static __host__ mmvq_parameter_table_id get_device_table_id(int cc) {
     return MMVQ_PARAMETERS_GENERIC;
 }
 
+// Per-architecture maximum batch size for which MMVQ should be used for MUL_MAT_ID.
+// Returns a value <= MMVQ_MAX_BATCH_SIZE. Default is MMVQ_MAX_BATCH_SIZE.
+// Check https://github.com/ggml-org/llama.cpp/pull/20905#issuecomment-4145835627 for details
+
+static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_pascal_older(ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_IQ1_S:   return 6;
+        case GGML_TYPE_IQ1_M:   return 6;
+        case GGML_TYPE_IQ2_S:   return 4;
+        case GGML_TYPE_IQ2_XS:  return 5;
+        case GGML_TYPE_IQ2_XXS: return 5;
+        case GGML_TYPE_IQ3_S:   return 4;
+        case GGML_TYPE_IQ3_XXS: return 4;
+        case GGML_TYPE_IQ4_NL:  return 6;
+        case GGML_TYPE_IQ4_XS:  return 5;
+        case GGML_TYPE_MXFP4:   return 4;
+        case GGML_TYPE_Q2_K:    return 4;
+        case GGML_TYPE_Q3_K:    return 4;
+        case GGML_TYPE_Q4_0:    return 6;
+        case GGML_TYPE_Q4_1:    return 6;
+        case GGML_TYPE_Q4_K:    return 5;
+        case GGML_TYPE_Q5_0:    return 6;
+        case GGML_TYPE_Q5_1:    return 6;
+        case GGML_TYPE_Q5_K:    return 5;
+        case GGML_TYPE_Q6_K:    return 4;
+        case GGML_TYPE_Q8_0:    return 4;
+        default:                return MMVQ_MAX_BATCH_SIZE;
+    }
+}
+
+static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_turing_plus(ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_IQ2_S:   return 7;
+        case GGML_TYPE_IQ3_S:   return 6;
+        case GGML_TYPE_IQ3_XXS: return 7;
+        case GGML_TYPE_MXFP4:   return 7;
+        case GGML_TYPE_Q2_K:    return 7;
+        case GGML_TYPE_Q3_K:    return 5;
+        default:                return MMVQ_MAX_BATCH_SIZE;
+    }
+}
+
+static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_gcn(ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_IQ1_S:   return 5;
+        case GGML_TYPE_IQ1_M:   return 5;
+        case GGML_TYPE_IQ2_S:   return 4;
+        case GGML_TYPE_IQ2_XS:  return 4;
+        case GGML_TYPE_IQ2_XXS: return 4;
+        case GGML_TYPE_IQ3_S:   return 4;
+        case GGML_TYPE_IQ3_XXS: return 4;
+        case GGML_TYPE_IQ4_NL:  return 6;
+        case GGML_TYPE_IQ4_XS:  return 4;
+        case GGML_TYPE_Q2_K:    return 4;
+        case GGML_TYPE_Q3_K:    return 4;
+        case GGML_TYPE_Q4_0:    return 5;
+        case GGML_TYPE_Q4_1:    return 5;
+        case GGML_TYPE_Q4_K:    return 4;
+        case GGML_TYPE_Q5_K:    return 4;
+        case GGML_TYPE_Q6_K:    return 4;
+        case GGML_TYPE_Q8_0:    return 4;
+        default:                return MMVQ_MAX_BATCH_SIZE;
+    }
+}
+
+static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_cdna(ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_IQ2_S:   return 5;
+        case GGML_TYPE_IQ2_XS:  return 5;
+        case GGML_TYPE_IQ2_XXS: return 5;
+        case GGML_TYPE_IQ3_S:   return 4;
+        case GGML_TYPE_IQ3_XXS: return 5;
+        default:                return MMVQ_MAX_BATCH_SIZE;
+    }
+}
+
+static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_rdna1_rdna2(ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_IQ2_S:   return 4;
+        case GGML_TYPE_IQ2_XS:  return 4;
+        case GGML_TYPE_IQ2_XXS: return 4;
+        case GGML_TYPE_IQ3_S:   return 4;
+        case GGML_TYPE_IQ3_XXS: return 4;
+        case GGML_TYPE_Q2_K:    return 7;
+        case GGML_TYPE_Q3_K:    return 4;
+        case GGML_TYPE_Q4_K:    return 5;
+        case GGML_TYPE_Q5_K:    return 6;
+        case GGML_TYPE_Q6_K:    return 5;
+        default:                return MMVQ_MAX_BATCH_SIZE;
+    }
+}
+
+static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_rdna3(ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_IQ1_S:   return 6;
+        case GGML_TYPE_IQ1_M:   return 6;
+        case GGML_TYPE_IQ2_S:   return 4;
+        case GGML_TYPE_IQ2_XS:  return 4;
+        case GGML_TYPE_IQ2_XXS: return 4;
+        case GGML_TYPE_IQ3_S:   return 4;
+        case GGML_TYPE_IQ3_XXS: return 4;
+        case GGML_TYPE_IQ4_NL:  return 6;
+        case GGML_TYPE_IQ4_XS:  return 6;
+        case GGML_TYPE_Q4_K:    return 4;
+        case GGML_TYPE_Q5_K:    return 4;
+        case GGML_TYPE_Q6_K:    return 4;
+        default:                return MMVQ_MAX_BATCH_SIZE;
+    }
+}
+
+static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_rdna4(ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_IQ1_S:   return 7;
+        case GGML_TYPE_IQ1_M:   return 7;
+        case GGML_TYPE_IQ2_S:   return 4;
+        case GGML_TYPE_IQ2_XS:  return 4;
+        case GGML_TYPE_IQ2_XXS: return 4;
+        case GGML_TYPE_IQ3_S:   return 4;
+        case GGML_TYPE_IQ3_XXS: return 4;
+        case GGML_TYPE_IQ4_NL:  return 7;
+        case GGML_TYPE_IQ4_XS:  return 5;
+        case GGML_TYPE_MXFP4:   return 5;
+        case GGML_TYPE_Q3_K:    return 4;
+        case GGML_TYPE_Q4_0:    return 7;
+        case GGML_TYPE_Q4_1:    return 7;
+        case GGML_TYPE_Q4_K:    return 4;
+        case GGML_TYPE_Q5_0:    return 7;
+        case GGML_TYPE_Q5_1:    return 7;
+        case GGML_TYPE_Q5_K:    return 5;
+        case GGML_TYPE_Q6_K:    return 5;
+        case GGML_TYPE_Q8_0:    return 7;
+        default:                return MMVQ_MAX_BATCH_SIZE;
+    }
+}
+
+// Host function: returns the max batch size for the current arch+type at runtime.
+int get_mmvq_mmid_max_batch(ggml_type type, int cc) {
+    // NVIDIA: Volta, Ada Lovelace, and Blackwell always use MMVQ for MUL_MAT_ID.
+    if (cc == GGML_CUDA_CC_VOLTA || cc >= GGML_CUDA_CC_ADA_LOVELACE) {
+        return MMVQ_MAX_BATCH_SIZE;
+    }
+    if (cc >= GGML_CUDA_CC_TURING) {
+        return get_mmvq_mmid_max_batch_turing_plus(type);
+    }
+    if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
+        return get_mmvq_mmid_max_batch_pascal_older(type);
+    }
+    // AMD
+    if (GGML_CUDA_CC_IS_RDNA4(cc)) {
+        return get_mmvq_mmid_max_batch_rdna4(type);
+    }
+    if (GGML_CUDA_CC_IS_RDNA3(cc)) {
+        return get_mmvq_mmid_max_batch_rdna3(type);
+    }
+    if (GGML_CUDA_CC_IS_RDNA1(cc) || GGML_CUDA_CC_IS_RDNA2(cc)) {
+        return get_mmvq_mmid_max_batch_rdna1_rdna2(type);
+    }
+    if (GGML_CUDA_CC_IS_CDNA(cc)) {
+        return get_mmvq_mmid_max_batch_cdna(type);
+    }
+    if (GGML_CUDA_CC_IS_GCN(cc)) {
+        return get_mmvq_mmid_max_batch_gcn(type);
+    }
+    return MMVQ_MAX_BATCH_SIZE;
+}
+
+// Device constexpr: returns the max batch size for the current arch+type at compile time.
+template <ggml_type type>
+static constexpr __device__ int get_mmvq_mmid_max_batch_for_device() {
+#if defined(RDNA4)
+    return get_mmvq_mmid_max_batch_rdna4(type);
+#elif defined(RDNA3)
+    return get_mmvq_mmid_max_batch_rdna3(type);
+#elif defined(RDNA2) || defined(RDNA1)
+    return get_mmvq_mmid_max_batch_rdna1_rdna2(type);
+#elif defined(CDNA)
+    return get_mmvq_mmid_max_batch_cdna(type);
+#elif defined(GCN)
+    return get_mmvq_mmid_max_batch_gcn(type);
+#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || __CUDA_ARCH__ >= GGML_CUDA_CC_ADA_LOVELACE)
+    return MMVQ_MAX_BATCH_SIZE;
+#elif defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
+    return get_mmvq_mmid_max_batch_turing_plus(type);
+#else
+    return get_mmvq_mmid_max_batch_pascal_older(type);
+#endif
+}
+
 static constexpr __host__ __device__ int calc_nwarps(ggml_type type, int ncols_dst, mmvq_parameter_table_id table_id) {
     if (table_id == MMVQ_PARAMETERS_GENERIC) {
         switch (ncols_dst) {
@@ -195,7 +383,7 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int
     return 1;
 }
 
-template <ggml_type type, int ncols_dst, bool has_fusion, bool is_multi_token_id = false, bool small_k = false>
+template <ggml_type type, int ncols_dst, bool has_fusion, bool small_k = false>
 __launch_bounds__(calc_nwarps(type, ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
 static __global__ void mul_mat_vec_q(
         const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
@@ -222,22 +410,13 @@ static __global__ void mul_mat_vec_q(
 
     const uint32_t channel_dst = blockIdx.y;
 
-    uint32_t token_idx = 0;
     uint32_t channel_x;
     uint32_t channel_y;
     uint32_t sample_dst;
 
-    if constexpr (is_multi_token_id) {
-        // Multi-token MUL_MAT_ID path, adding these in the normal path causes a perf regression for n_tokens=1 case
-        token_idx  = blockIdx.z;
-        channel_x  = ids[channel_dst + token_idx * ids_stride];
-        channel_y  = fastmodulo(channel_dst, nchannels_y);
-        sample_dst = 0;
-    } else {
-        channel_x  = ncols_dst == 1 && ids ? ids[channel_dst]                     : fastdiv(channel_dst, channel_ratio);
-        channel_y  = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst;
-        sample_dst = blockIdx.z;
-    }
+    channel_x  = ncols_dst == 1 && ids ? ids[channel_dst]                     : fastdiv(channel_dst, channel_ratio);
+    channel_y  = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst;
+    sample_dst = blockIdx.z;
 
     const uint32_t sample_x    = fastdiv(sample_dst, sample_ratio);
     const uint32_t sample_y    = sample_dst;
@@ -294,9 +473,6 @@ static __global__ void mul_mat_vec_q(
     float tmp_gate[ncols_dst][rows_per_cuda_block] = {{0.0f}};
 
     const block_q8_1 * y = ((const block_q8_1 *) vy) + sample_y*stride_sample_y + channel_y*stride_channel_y;
-    if constexpr (is_multi_token_id) {
-        y += token_idx*stride_col_y;
-    }
     const int kbx_offset = sample_x*stride_sample_x + channel_x*stride_channel_x + row0*stride_row_x;
 
     for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) {
@@ -350,10 +526,6 @@ static __global__ void mul_mat_vec_q(
 
     dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst + row0;
 
-    if constexpr (is_multi_token_id) {
-        dst += token_idx*stride_col_dst;
-    }
-
     // sum up partial sums and write back result
 #pragma unroll
     for (int j = 0; j < ncols_dst; ++j) {
@@ -413,6 +585,69 @@ static __global__ void mul_mat_vec_q(
     }
 }
 
+// Dedicated MoE multi-token kernel.
+// Grid: (ceil(nrows_x / c_rows_per_block), nchannels_dst)
+// Block: (warp_size, ncols_dst) - each warp handles one token independently.
+// No shared memory reduction needed since each warp works alone.
+template <ggml_type type, int c_rows_per_block>
+__launch_bounds__(get_mmvq_mmid_max_batch_for_device<type>()*ggml_cuda_get_physical_warp_size(), 1)
+static __global__ void mul_mat_vec_q_moe(
+        const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids,
+        float * __restrict__ dst,
+        const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t nrows_x,
+        const uint32_t stride_row_x, const uint32_t stride_col_y, const uint32_t stride_col_dst,
+        const uint32_t stride_channel_x, const uint32_t stride_channel_y, const uint32_t stride_channel_dst,
+        const uint32_t ncols_dst, const uint32_t ids_stride) {
+
+    constexpr int qk  = ggml_cuda_type_traits<type>::qk;
+    constexpr int qi  = ggml_cuda_type_traits<type>::qi;
+    constexpr int vdr = get_vdr_mmvq(type);
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+    constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
+
+    const uint32_t token_idx   = threadIdx.y;
+    const int      row0        = c_rows_per_block*blockIdx.x;
+    const int      blocks_per_row_x = ncols_x / qk;
+    constexpr int  blocks_per_iter  = vdr * warp_size / qi;
+
+    const uint32_t channel_dst = blockIdx.y;
+
+    if (token_idx >= ncols_dst) {
+        return;
+    }
+
+    const uint32_t channel_x = ids[channel_dst + token_idx * ids_stride];
+    const uint32_t channel_y = fastmodulo(channel_dst, nchannels_y);
+
+    const block_q8_1 * y = ((const block_q8_1 *) vy) + channel_y*stride_channel_y + token_idx*stride_col_y;
+    const int kbx_offset  = channel_x*stride_channel_x + row0*stride_row_x;
+
+    // partial sum for each thread
+    float tmp[c_rows_per_block] = {0.0f};
+
+    for (int kbx = threadIdx.x / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) {
+        const int kby = kbx * (qk/QK8_1);
+        const int kqs = vdr * (threadIdx.x % (qi/vdr));
+
+#pragma unroll
+        for (int i = 0; i < c_rows_per_block; ++i) {
+            tmp[i] += vec_dot_q_cuda(vx, &y[kby], kbx_offset + i*stride_row_x + kbx, kqs);
+        }
+    }
+
+    // Warp-level reduction only - no shared memory needed
+#pragma unroll
+    for (int i = 0; i < c_rows_per_block; ++i) {
+        tmp[i] = warp_reduce_sum<warp_size>(tmp[i]);
+    }
+
+    // Write results
+    if (threadIdx.x < c_rows_per_block && (c_rows_per_block == 1 || uint32_t(row0 + threadIdx.x) < nrows_x)) {
+        dst[channel_dst*stride_channel_dst + token_idx*stride_col_dst + row0 + threadIdx.x] = tmp[threadIdx.x];
+    }
+}
+
 template<ggml_type type>
 static std::pair<dim3, dim3> calc_launch_params(
         const int ncols_dst, const int nrows_x, const int nchannels_dst, const int nsamples_or_ntokens,
@@ -425,7 +660,7 @@ static std::pair<dim3, dim3> calc_launch_params(
     return {block_nums, block_dims};
 }
 
-template<ggml_type type, int c_ncols_dst, bool is_multi_token_id = false, bool small_k = false>
+template<ggml_type type, int c_ncols_dst, bool small_k = false>
 static void mul_mat_vec_q_switch_fusion(
         const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
         const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
@@ -438,7 +673,7 @@ static void mul_mat_vec_q_switch_fusion(
     const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
     if constexpr (c_ncols_dst == 1) {
         if (has_fusion) {
-            mul_mat_vec_q<type, c_ncols_dst, true, is_multi_token_id, small_k><<<block_nums, block_dims, nbytes_shared, stream>>>
+            mul_mat_vec_q<type, c_ncols_dst, true, small_k><<<block_nums, block_dims, nbytes_shared, stream>>>
                 (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
                  channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                  sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
@@ -448,12 +683,33 @@ static void mul_mat_vec_q_switch_fusion(
 
     GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1");
 
-    mul_mat_vec_q<type, c_ncols_dst, false, is_multi_token_id, small_k><<<block_nums, block_dims, nbytes_shared, stream>>>
+    mul_mat_vec_q<type, c_ncols_dst, false, small_k><<<block_nums, block_dims, nbytes_shared, stream>>>
         (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
         channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
         sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
 }
 
+template <ggml_type type>
+static void mul_mat_vec_q_moe_launch(
+        const void * vx, const void * vy, const int32_t * ids, float * dst,
+        const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t nrows_x,
+        const uint32_t stride_row_x, const uint32_t stride_col_y, const uint32_t stride_col_dst,
+        const uint32_t stride_channel_x, const uint32_t stride_channel_y, const uint32_t stride_channel_dst,
+        const uint32_t ncols_dst, const uint32_t ids_stride,
+        const int warp_size, const int nchannels_dst, cudaStream_t stream) {
+
+    constexpr int rows_per_block = 2; // 2 gives best perf based on tuning
+    const int64_t nblocks_rows = (nrows_x + rows_per_block - 1) / rows_per_block;
+    const dim3 block_nums(nblocks_rows, nchannels_dst);
+    const dim3 block_dims(warp_size, ncols_dst);
+
+    mul_mat_vec_q_moe<type, rows_per_block><<<block_nums, block_dims, 0, stream>>>(
+        vx, vy, ids, dst, ncols_x, nchannels_y, nrows_x,
+        stride_row_x, stride_col_y, stride_col_dst,
+        stride_channel_x, stride_channel_y, stride_channel_dst,
+        ncols_dst, ids_stride);
+}
+
 template <ggml_type type>
 static void mul_mat_vec_q_switch_ncols_dst(
         const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
@@ -472,20 +728,62 @@ static void mul_mat_vec_q_switch_ncols_dst(
     const uint3 sample_ratio_fd  = init_fastdiv_values(nsamples_dst  / nsamples_x);
 
     const int device = ggml_cuda_get_device();
+    const int                     cc        = ggml_cuda_info().devices[device].cc;
     const int warp_size = ggml_cuda_info().devices[device].warp_size;
-    const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc);
+    const mmvq_parameter_table_id table_id  = get_device_table_id(cc);
 
     const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
     const bool has_ids = ids != nullptr;
 
+    const auto should_use_small_k = [&](int c_ncols_dst) {
+        // When K is small, increase rows_per_block to match nwarps so each warp has more work to do
+        // Trigger when the full thread block covers all K blocks in a single loop iteration and few threads remain idle.
+        constexpr int qk                    = ggml_cuda_type_traits<type>::qk;
+        constexpr int qi                    = ggml_cuda_type_traits<type>::qi;
+        constexpr int vdr                   = get_vdr_mmvq(type);
+        const int     blocks_per_row_x      = ncols_x / qk;
+        const int     blocks_per_iter_1warp = vdr * warp_size / qi;
+        const int     nwarps                = calc_nwarps(type, c_ncols_dst, table_id);
+        bool          use                   = nwarps > 1 && blocks_per_row_x < nwarps * blocks_per_iter_1warp;
+
+        constexpr std::array<ggml_type, 2> iq_slow_turing = {
+            GGML_TYPE_IQ3_XXS,
+            GGML_TYPE_IQ3_S,
+        };
+        constexpr std::array<ggml_type, 8> iq_slow_other = {
+            GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M,   GGML_TYPE_IQ2_XXS, GGML_TYPE_IQ2_XS,
+            GGML_TYPE_IQ2_S, GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ3_S,   GGML_TYPE_IQ4_XS,
+        };
+        constexpr std::array<ggml_type, 3> slow_pascal = {
+            GGML_TYPE_IQ3_S,
+            GGML_TYPE_Q2_K,
+            GGML_TYPE_Q3_K,
+        };
+
+        const bool is_nvidia_turing_plus  = GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_TURING;
+        const bool is_nvidia_pascal_older = GGML_CUDA_CC_IS_NVIDIA(cc) && cc < GGML_CUDA_CC_VOLTA;
+
+        if (is_nvidia_turing_plus) {
+            if (ncols_dst == 1 &&
+                    std::find(iq_slow_turing.begin(), iq_slow_turing.end(), type) != iq_slow_turing.end()) {
+                use = false;
+            }
+        } else if ((ncols_dst == 1 && std::find(iq_slow_other.begin(), iq_slow_other.end(), type) != iq_slow_other.end()) ||
+                (is_nvidia_pascal_older && std::find(slow_pascal.begin(), slow_pascal.end(), type) != slow_pascal.end()) ||
+                GGML_CUDA_CC_IS_RDNA(cc)) {
+            use = false;
+        }
+
+        return use;
+    };
+
     if (has_ids && ncols_dst > 1) {
-        // Multi-token MUL_MAT_ID path only - single-token goes through regular path below
-        constexpr int c_ncols_dst = 1;
-        std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, ncols_dst, warp_size, table_id);
-        mul_mat_vec_q_switch_fusion<type, c_ncols_dst, true>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
-             channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
-             sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
-             dims.first, dims.second, 0, ids_stride, stream);
+        // Multi-token MUL_MAT_ID path - dedicated MoE kernel
+        mul_mat_vec_q_moe_launch<type>(
+            vx, vy, ids, dst, ncols_x, nchannels_y_fd, nrows_x,
+            stride_row_x, stride_col_y, stride_col_dst,
+            stride_channel_x, stride_channel_y, stride_channel_dst,
+            ncols_dst, ids_stride, warp_size, nchannels_dst, stream);
         return;
     }
 
@@ -493,31 +791,24 @@ static void mul_mat_vec_q_switch_ncols_dst(
         case 1: {
             constexpr int c_ncols_dst = 1;
 
-            // When K is small, increase rows_per_block to match nwarps so each warp has more work to do
-            // Trigger when the full thread block covers all K blocks in a single loop iteration and few threads remain idle.
-            constexpr int qk  = ggml_cuda_type_traits<type>::qk;
-            constexpr int qi  = ggml_cuda_type_traits<type>::qi;
-            constexpr int vdr = get_vdr_mmvq(type);
-            const int blocks_per_row_x = ncols_x / qk;
-            const int blocks_per_iter_1warp = vdr * warp_size / qi;
-            const int nwarps = calc_nwarps(type, c_ncols_dst, table_id);
-            const bool use_small_k = nwarps > 1 && blocks_per_row_x < nwarps * blocks_per_iter_1warp;
+            bool use_small_k = should_use_small_k(c_ncols_dst);
+
             if (use_small_k) {
-                std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst,
-                                                                    warp_size, table_id, true);
-                mul_mat_vec_q_switch_fusion<type, c_ncols_dst, false, true>(
+                std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst,
+                                                                        nsamples_dst, warp_size, table_id, true);
+                mul_mat_vec_q_switch_fusion<type, c_ncols_dst, true>(
                     vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
-                    channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
-                    sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
-                    dims.first, dims.second, 0, ids_stride, stream);
+                    channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd,
+                    stride_sample_x, stride_sample_y, stride_sample_dst, dims.first, dims.second, 0, ids_stride,
+                    stream);
             } else {
-                std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst,
-                                                                    warp_size, table_id);
+                std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst,
+                                                                        nsamples_dst, warp_size, table_id);
                 mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(
                     vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
-                    channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
-                    sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
-                    dims.first, dims.second, 0, ids_stride, stream);
+                    channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd,
+                    stride_sample_x, stride_sample_y, stride_sample_dst, dims.first, dims.second, 0, ids_stride,
+                    stream);
             }
         } break;
         case 2: {
index 8a154631f69ec40fb0d45f55e4045330b520d3b7..6bf0a8e8677ded84736af9fcf2ac9adcafd81c8b 100644 (file)
@@ -1,7 +1,10 @@
 #include "common.cuh"
 
 #define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels.
-#define MMVQ_MMID_MAX_BATCH_SIZE 4 // Max. batch size for which to use MMVQ kernels for MUL_MAT_ID
+
+// Returns the maximum batch size for which MMVQ should be used for MUL_MAT_ID,
+// based on the quantization type and GPU architecture (compute capability).
+int get_mmvq_mmid_max_batch(ggml_type type, int cc);
 
 void ggml_cuda_mul_mat_vec_q(ggml_backend_cuda_context & ctx,
     const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, const ggml_cuda_mm_fusion_args_host * fusion = nullptr);