return MMVQ_PARAMETERS_GENERIC;
}
+// Per-architecture maximum batch size for which MMVQ should be used for MUL_MAT_ID.
+// Returns a value <= MMVQ_MAX_BATCH_SIZE; types not listed below default to MMVQ_MAX_BATCH_SIZE.
+// See https://github.com/ggml-org/llama.cpp/pull/20905#issuecomment-4145835627 for details.
+
+static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_pascal_older(ggml_type type) {
+ switch (type) {
+ case GGML_TYPE_IQ1_S: return 6;
+ case GGML_TYPE_IQ1_M: return 6;
+ case GGML_TYPE_IQ2_S: return 4;
+ case GGML_TYPE_IQ2_XS: return 5;
+ case GGML_TYPE_IQ2_XXS: return 5;
+ case GGML_TYPE_IQ3_S: return 4;
+ case GGML_TYPE_IQ3_XXS: return 4;
+ case GGML_TYPE_IQ4_NL: return 6;
+ case GGML_TYPE_IQ4_XS: return 5;
+ case GGML_TYPE_MXFP4: return 4;
+ case GGML_TYPE_Q2_K: return 4;
+ case GGML_TYPE_Q3_K: return 4;
+ case GGML_TYPE_Q4_0: return 6;
+ case GGML_TYPE_Q4_1: return 6;
+ case GGML_TYPE_Q4_K: return 5;
+ case GGML_TYPE_Q5_0: return 6;
+ case GGML_TYPE_Q5_1: return 6;
+ case GGML_TYPE_Q5_K: return 5;
+ case GGML_TYPE_Q6_K: return 4;
+ case GGML_TYPE_Q8_0: return 4;
+ default: return MMVQ_MAX_BATCH_SIZE;
+ }
+}
+
+static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_turing_plus(ggml_type type) {
+ switch (type) {
+ case GGML_TYPE_IQ2_S: return 7;
+ case GGML_TYPE_IQ3_S: return 6;
+ case GGML_TYPE_IQ3_XXS: return 7;
+ case GGML_TYPE_MXFP4: return 7;
+ case GGML_TYPE_Q2_K: return 7;
+ case GGML_TYPE_Q3_K: return 5;
+ default: return MMVQ_MAX_BATCH_SIZE;
+ }
+}
+
+static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_gcn(ggml_type type) {
+ switch (type) {
+ case GGML_TYPE_IQ1_S: return 5;
+ case GGML_TYPE_IQ1_M: return 5;
+ case GGML_TYPE_IQ2_S: return 4;
+ case GGML_TYPE_IQ2_XS: return 4;
+ case GGML_TYPE_IQ2_XXS: return 4;
+ case GGML_TYPE_IQ3_S: return 4;
+ case GGML_TYPE_IQ3_XXS: return 4;
+ case GGML_TYPE_IQ4_NL: return 6;
+ case GGML_TYPE_IQ4_XS: return 4;
+ case GGML_TYPE_Q2_K: return 4;
+ case GGML_TYPE_Q3_K: return 4;
+ case GGML_TYPE_Q4_0: return 5;
+ case GGML_TYPE_Q4_1: return 5;
+ case GGML_TYPE_Q4_K: return 4;
+ case GGML_TYPE_Q5_K: return 4;
+ case GGML_TYPE_Q6_K: return 4;
+ case GGML_TYPE_Q8_0: return 4;
+ default: return MMVQ_MAX_BATCH_SIZE;
+ }
+}
+
+static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_cdna(ggml_type type) {
+ switch (type) {
+ case GGML_TYPE_IQ2_S: return 5;
+ case GGML_TYPE_IQ2_XS: return 5;
+ case GGML_TYPE_IQ2_XXS: return 5;
+ case GGML_TYPE_IQ3_S: return 4;
+ case GGML_TYPE_IQ3_XXS: return 5;
+ default: return MMVQ_MAX_BATCH_SIZE;
+ }
+}
+
+static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_rdna1_rdna2(ggml_type type) {
+ switch (type) {
+ case GGML_TYPE_IQ2_S: return 4;
+ case GGML_TYPE_IQ2_XS: return 4;
+ case GGML_TYPE_IQ2_XXS: return 4;
+ case GGML_TYPE_IQ3_S: return 4;
+ case GGML_TYPE_IQ3_XXS: return 4;
+ case GGML_TYPE_Q2_K: return 7;
+ case GGML_TYPE_Q3_K: return 4;
+ case GGML_TYPE_Q4_K: return 5;
+ case GGML_TYPE_Q5_K: return 6;
+ case GGML_TYPE_Q6_K: return 5;
+ default: return MMVQ_MAX_BATCH_SIZE;
+ }
+}
+
+static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_rdna3(ggml_type type) {
+ switch (type) {
+ case GGML_TYPE_IQ1_S: return 6;
+ case GGML_TYPE_IQ1_M: return 6;
+ case GGML_TYPE_IQ2_S: return 4;
+ case GGML_TYPE_IQ2_XS: return 4;
+ case GGML_TYPE_IQ2_XXS: return 4;
+ case GGML_TYPE_IQ3_S: return 4;
+ case GGML_TYPE_IQ3_XXS: return 4;
+ case GGML_TYPE_IQ4_NL: return 6;
+ case GGML_TYPE_IQ4_XS: return 6;
+ case GGML_TYPE_Q4_K: return 4;
+ case GGML_TYPE_Q5_K: return 4;
+ case GGML_TYPE_Q6_K: return 4;
+ default: return MMVQ_MAX_BATCH_SIZE;
+ }
+}
+
+static constexpr __host__ __device__ int get_mmvq_mmid_max_batch_rdna4(ggml_type type) {
+ switch (type) {
+ case GGML_TYPE_IQ1_S: return 7;
+ case GGML_TYPE_IQ1_M: return 7;
+ case GGML_TYPE_IQ2_S: return 4;
+ case GGML_TYPE_IQ2_XS: return 4;
+ case GGML_TYPE_IQ2_XXS: return 4;
+ case GGML_TYPE_IQ3_S: return 4;
+ case GGML_TYPE_IQ3_XXS: return 4;
+ case GGML_TYPE_IQ4_NL: return 7;
+ case GGML_TYPE_IQ4_XS: return 5;
+ case GGML_TYPE_MXFP4: return 5;
+ case GGML_TYPE_Q3_K: return 4;
+ case GGML_TYPE_Q4_0: return 7;
+ case GGML_TYPE_Q4_1: return 7;
+ case GGML_TYPE_Q4_K: return 4;
+ case GGML_TYPE_Q5_0: return 7;
+ case GGML_TYPE_Q5_1: return 7;
+ case GGML_TYPE_Q5_K: return 5;
+ case GGML_TYPE_Q6_K: return 5;
+ case GGML_TYPE_Q8_0: return 7;
+ default: return MMVQ_MAX_BATCH_SIZE;
+ }
+}
+
+// Host function: returns the max batch size for the given compute capability and type at runtime.
+int get_mmvq_mmid_max_batch(ggml_type type, int cc) {
+ // Non-NVIDIA compute capabilities are offset, so check the vendor before comparing cc thresholds.
+ if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
+ // NVIDIA: Volta and Ada Lovelace or newer always use MMVQ for MUL_MAT_ID.
+ if (cc == GGML_CUDA_CC_VOLTA || cc >= GGML_CUDA_CC_ADA_LOVELACE) {
+ return MMVQ_MAX_BATCH_SIZE;
+ }
+ if (cc >= GGML_CUDA_CC_TURING) {
+ return get_mmvq_mmid_max_batch_turing_plus(type);
+ }
+ return get_mmvq_mmid_max_batch_pascal_older(type);
+ }
+ // AMD
+ if (GGML_CUDA_CC_IS_RDNA4(cc)) {
+ return get_mmvq_mmid_max_batch_rdna4(type);
+ }
+ if (GGML_CUDA_CC_IS_RDNA3(cc)) {
+ return get_mmvq_mmid_max_batch_rdna3(type);
+ }
+ if (GGML_CUDA_CC_IS_RDNA1(cc) || GGML_CUDA_CC_IS_RDNA2(cc)) {
+ return get_mmvq_mmid_max_batch_rdna1_rdna2(type);
+ }
+ if (GGML_CUDA_CC_IS_CDNA(cc)) {
+ return get_mmvq_mmid_max_batch_cdna(type);
+ }
+ if (GGML_CUDA_CC_IS_GCN(cc)) {
+ return get_mmvq_mmid_max_batch_gcn(type);
+ }
+ return MMVQ_MAX_BATCH_SIZE;
+}
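+// Illustrative (hypothetical) use by the MUL_MAT_ID dispatch code; the actual call site and names may differ:
+//   const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+//   const bool use_mmvq = n_tokens <= get_mmvq_mmid_max_batch(src0->type, cc);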
+
+// Device constexpr: returns the max batch size for the current arch+type at compile time.
+template <ggml_type type>
+static constexpr __device__ int get_mmvq_mmid_max_batch_for_device() {
+#if defined(RDNA4)
+ return get_mmvq_mmid_max_batch_rdna4(type);
+#elif defined(RDNA3)
+ return get_mmvq_mmid_max_batch_rdna3(type);
+#elif defined(RDNA2) || defined(RDNA1)
+ return get_mmvq_mmid_max_batch_rdna1_rdna2(type);
+#elif defined(CDNA)
+ return get_mmvq_mmid_max_batch_cdna(type);
+#elif defined(GCN)
+ return get_mmvq_mmid_max_batch_gcn(type);
+#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || __CUDA_ARCH__ >= GGML_CUDA_CC_ADA_LOVELACE)
+ return MMVQ_MAX_BATCH_SIZE;
+#elif defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
+ return get_mmvq_mmid_max_batch_turing_plus(type);
+#else
+ return get_mmvq_mmid_max_batch_pascal_older(type);
+#endif
+}
+
static constexpr __host__ __device__ int calc_nwarps(ggml_type type, int ncols_dst, mmvq_parameter_table_id table_id) {
if (table_id == MMVQ_PARAMETERS_GENERIC) {
switch (ncols_dst) {
return 1;
}
-template <ggml_type type, int ncols_dst, bool has_fusion, bool is_multi_token_id = false, bool small_k = false>
+template <ggml_type type, int ncols_dst, bool has_fusion, bool small_k = false>
__launch_bounds__(calc_nwarps(type, ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
static __global__ void mul_mat_vec_q(
const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
const uint32_t channel_dst = blockIdx.y;
- uint32_t token_idx = 0;
uint32_t channel_x;
uint32_t channel_y;
uint32_t sample_dst;
- if constexpr (is_multi_token_id) {
- // Multi-token MUL_MAT_ID path, adding these in the normal path causes a perf regression for n_tokens=1 case
- token_idx = blockIdx.z;
- channel_x = ids[channel_dst + token_idx * ids_stride];
- channel_y = fastmodulo(channel_dst, nchannels_y);
- sample_dst = 0;
- } else {
- channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv(channel_dst, channel_ratio);
- channel_y = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst;
- sample_dst = blockIdx.z;
- }
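+ // Single-token MUL_MAT_ID (ncols_dst == 1 with ids set): ids selects the expert channel of x.
+ // Otherwise, use the regular broadcast mapping between dst and x/y channels.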
+ channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv(channel_dst, channel_ratio);
+ channel_y = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst;
+ sample_dst = blockIdx.z;
const uint32_t sample_x = fastdiv(sample_dst, sample_ratio);
const uint32_t sample_y = sample_dst;
float tmp_gate[ncols_dst][rows_per_cuda_block] = {{0.0f}};
const block_q8_1 * y = ((const block_q8_1 *) vy) + sample_y*stride_sample_y + channel_y*stride_channel_y;
- if constexpr (is_multi_token_id) {
- y += token_idx*stride_col_y;
- }
const int kbx_offset = sample_x*stride_sample_x + channel_x*stride_channel_x + row0*stride_row_x;
for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) {
dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst + row0;
- if constexpr (is_multi_token_id) {
- dst += token_idx*stride_col_dst;
- }
-
// sum up partial sums and write back result
#pragma unroll
for (int j = 0; j < ncols_dst; ++j) {
}
}
+// Dedicated MoE multi-token kernel.
+// Grid: (ceil(nrows_x / c_rows_per_block), nchannels_dst)
+// Block: (warp_size, ncols_dst) - each warp handles one token independently.
+// No shared memory reduction needed since each warp works alone.
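+// The launch bounds below are only valid if ncols_dst <= get_mmvq_mmid_max_batch_for_device<type>(),
+// the per-arch limit the dispatch code is expected to respect.
+// Illustrative geometry (hypothetical sizes): nrows_x = 2048, c_rows_per_block = 2, ncols_dst = 4 tokens,
+// warp_size = 32 -> grid (1024, nchannels_dst), block (32, 4): 4 warps, one per token.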
+template <ggml_type type, int c_rows_per_block>
+__launch_bounds__(get_mmvq_mmid_max_batch_for_device<type>()*ggml_cuda_get_physical_warp_size(), 1)
+static __global__ void mul_mat_vec_q_moe(
+ const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids,
+ float * __restrict__ dst,
+ const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t nrows_x,
+ const uint32_t stride_row_x, const uint32_t stride_col_y, const uint32_t stride_col_dst,
+ const uint32_t stride_channel_x, const uint32_t stride_channel_y, const uint32_t stride_channel_dst,
+ const uint32_t ncols_dst, const uint32_t ids_stride) {
+
+ constexpr int qk = ggml_cuda_type_traits<type>::qk;
+ constexpr int qi = ggml_cuda_type_traits<type>::qi;
+ constexpr int vdr = get_vdr_mmvq(type);
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+ constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
+
+ const uint32_t token_idx = threadIdx.y;
+ const int row0 = c_rows_per_block*blockIdx.x;
+ const int blocks_per_row_x = ncols_x / qk;
+ constexpr int blocks_per_iter = vdr * warp_size / qi;
+
+ const uint32_t channel_dst = blockIdx.y;
+
+ if (token_idx >= ncols_dst) {
+ return;
+ }
+
+ const uint32_t channel_x = ids[channel_dst + token_idx * ids_stride];
+ const uint32_t channel_y = fastmodulo(channel_dst, nchannels_y);
+
+ const block_q8_1 * y = ((const block_q8_1 *) vy) + channel_y*stride_channel_y + token_idx*stride_col_y;
+ const int kbx_offset = channel_x*stride_channel_x + row0*stride_row_x;
+
+ // partial sum for each thread
+ float tmp[c_rows_per_block] = {0.0f};
+
+ for (int kbx = threadIdx.x / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) {
+ const int kby = kbx * (qk/QK8_1);
+ const int kqs = vdr * (threadIdx.x % (qi/vdr));
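+ // kby: first q8_1 block of y paired with quant block kbx of x (one x block may span several QK8_1 blocks of y).
+ // kqs: offset of the quant values within the block that this thread dots.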
+
+#pragma unroll
+ for (int i = 0; i < c_rows_per_block; ++i) {
+ tmp[i] += vec_dot_q_cuda(vx, &y[kby], kbx_offset + i*stride_row_x + kbx, kqs);
+ }
+ }
+
+ // Warp-level reduction only - no shared memory needed
+#pragma unroll
+ for (int i = 0; i < c_rows_per_block; ++i) {
+ tmp[i] = warp_reduce_sum<warp_size>(tmp[i]);
+ }
+
+ // Write results
+ if (threadIdx.x < c_rows_per_block && (c_rows_per_block == 1 || uint32_t(row0 + threadIdx.x) < nrows_x)) {
+ dst[channel_dst*stride_channel_dst + token_idx*stride_col_dst + row0 + threadIdx.x] = tmp[threadIdx.x];
+ }
+}
+
template<ggml_type type>
static std::pair<dim3, dim3> calc_launch_params(
const int ncols_dst, const int nrows_x, const int nchannels_dst, const int nsamples_or_ntokens,
return {block_nums, block_dims};
}
-template<ggml_type type, int c_ncols_dst, bool is_multi_token_id = false, bool small_k = false>
+template<ggml_type type, int c_ncols_dst, bool small_k = false>
static void mul_mat_vec_q_switch_fusion(
const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
if constexpr (c_ncols_dst == 1) {
if (has_fusion) {
- mul_mat_vec_q<type, c_ncols_dst, true, is_multi_token_id, small_k><<<block_nums, block_dims, nbytes_shared, stream>>>
+ mul_mat_vec_q<type, c_ncols_dst, true, small_k><<<block_nums, block_dims, nbytes_shared, stream>>>
(vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1");
- mul_mat_vec_q<type, c_ncols_dst, false, is_multi_token_id, small_k><<<block_nums, block_dims, nbytes_shared, stream>>>
+ mul_mat_vec_q<type, c_ncols_dst, false, small_k><<<block_nums, block_dims, nbytes_shared, stream>>>
(vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
}
+template <ggml_type type>
+static void mul_mat_vec_q_moe_launch(
+ const void * vx, const void * vy, const int32_t * ids, float * dst,
+ const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t nrows_x,
+ const uint32_t stride_row_x, const uint32_t stride_col_y, const uint32_t stride_col_dst,
+ const uint32_t stride_channel_x, const uint32_t stride_channel_y, const uint32_t stride_channel_dst,
+ const uint32_t ncols_dst, const uint32_t ids_stride,
+ const int warp_size, const int nchannels_dst, cudaStream_t stream) {
+
+ constexpr int rows_per_block = 2; // 2 gives best perf based on tuning
+ const int64_t nblocks_rows = (nrows_x + rows_per_block - 1) / rows_per_block;
+ const dim3 block_nums(nblocks_rows, nchannels_dst);
+ const dim3 block_dims(warp_size, ncols_dst);
+
+ mul_mat_vec_q_moe<type, rows_per_block><<<block_nums, block_dims, 0, stream>>>(
+ vx, vy, ids, dst, ncols_x, nchannels_y, nrows_x,
+ stride_row_x, stride_col_y, stride_col_dst,
+ stride_channel_x, stride_channel_y, stride_channel_dst,
+ ncols_dst, ids_stride);
+}
+
template <ggml_type type>
static void mul_mat_vec_q_switch_ncols_dst(
const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
const uint3 sample_ratio_fd = init_fastdiv_values(nsamples_dst / nsamples_x);
const int device = ggml_cuda_get_device();
+ const int cc = ggml_cuda_info().devices[device].cc;
const int warp_size = ggml_cuda_info().devices[device].warp_size;
- const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc);
+ const mmvq_parameter_table_id table_id = get_device_table_id(cc);
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
const bool has_ids = ids != nullptr;
+ const auto should_use_small_k = [&](int c_ncols_dst) {
+ // When K is small, increase rows_per_block to match nwarps so each warp has more work to do
+ // Trigger when the full thread block covers all K blocks in a single loop iteration and few threads remain idle.
+ constexpr int qk = ggml_cuda_type_traits<type>::qk;
+ constexpr int qi = ggml_cuda_type_traits<type>::qi;
+ constexpr int vdr = get_vdr_mmvq(type);
+ const int blocks_per_row_x = ncols_x / qk;
+ const int blocks_per_iter_1warp = vdr * warp_size / qi;
+ const int nwarps = calc_nwarps(type, c_ncols_dst, table_id);
+ bool use = nwarps > 1 && blocks_per_row_x < nwarps * blocks_per_iter_1warp;
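+ // Example (illustrative, assuming Q4_0 with qk = 32, qi = 4, vdr = 2, warp_size = 32, nwarps = 4):
+ // blocks_per_iter_1warp = 16, so the small-K path is considered for ncols_x < 4*16*32 = 2048.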
+
+ constexpr std::array<ggml_type, 2> iq_slow_turing = {
+ GGML_TYPE_IQ3_XXS,
+ GGML_TYPE_IQ3_S,
+ };
+ constexpr std::array<ggml_type, 8> iq_slow_other = {
+ GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M, GGML_TYPE_IQ2_XXS, GGML_TYPE_IQ2_XS,
+ GGML_TYPE_IQ2_S, GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS,
+ };
+ constexpr std::array<ggml_type, 3> slow_pascal = {
+ GGML_TYPE_IQ3_S,
+ GGML_TYPE_Q2_K,
+ GGML_TYPE_Q3_K,
+ };
+
+ const bool is_nvidia_turing_plus = GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_TURING;
+ const bool is_nvidia_pascal_older = GGML_CUDA_CC_IS_NVIDIA(cc) && cc < GGML_CUDA_CC_VOLTA;
+
+ if (is_nvidia_turing_plus) {
+ if (ncols_dst == 1 &&
+ std::find(iq_slow_turing.begin(), iq_slow_turing.end(), type) != iq_slow_turing.end()) {
+ use = false;
+ }
+ } else if ((ncols_dst == 1 && std::find(iq_slow_other.begin(), iq_slow_other.end(), type) != iq_slow_other.end()) ||
+ (is_nvidia_pascal_older && std::find(slow_pascal.begin(), slow_pascal.end(), type) != slow_pascal.end()) ||
+ GGML_CUDA_CC_IS_RDNA(cc)) {
+ use = false;
+ }
+
+ return use;
+ };
+
if (has_ids && ncols_dst > 1) {
- // Multi-token MUL_MAT_ID path only - single-token goes through regular path below
- constexpr int c_ncols_dst = 1;
- std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, ncols_dst, warp_size, table_id);
- mul_mat_vec_q_switch_fusion<type, c_ncols_dst, true>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
- channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
- sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
- dims.first, dims.second, 0, ids_stride, stream);
+ // Multi-token MUL_MAT_ID path - dedicated MoE kernel
+ mul_mat_vec_q_moe_launch<type>(
+ vx, vy, ids, dst, ncols_x, nchannels_y_fd, nrows_x,
+ stride_row_x, stride_col_y, stride_col_dst,
+ stride_channel_x, stride_channel_y, stride_channel_dst,
+ ncols_dst, ids_stride, warp_size, nchannels_dst, stream);
return;
}
case 1: {
constexpr int c_ncols_dst = 1;
- // When K is small, increase rows_per_block to match nwarps so each warp has more work to do
- // Trigger when the full thread block covers all K blocks in a single loop iteration and few threads remain idle.
- constexpr int qk = ggml_cuda_type_traits<type>::qk;
- constexpr int qi = ggml_cuda_type_traits<type>::qi;
- constexpr int vdr = get_vdr_mmvq(type);
- const int blocks_per_row_x = ncols_x / qk;
- const int blocks_per_iter_1warp = vdr * warp_size / qi;
- const int nwarps = calc_nwarps(type, c_ncols_dst, table_id);
- const bool use_small_k = nwarps > 1 && blocks_per_row_x < nwarps * blocks_per_iter_1warp;
+ const bool use_small_k = should_use_small_k(c_ncols_dst);
+
if (use_small_k) {
- std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst,
- warp_size, table_id, true);
- mul_mat_vec_q_switch_fusion<type, c_ncols_dst, false, true>(
+ std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst,
+ nsamples_dst, warp_size, table_id, true);
+ mul_mat_vec_q_switch_fusion<type, c_ncols_dst, true>(
vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
- channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
- sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
- dims.first, dims.second, 0, ids_stride, stream);
+ channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd,
+ stride_sample_x, stride_sample_y, stride_sample_dst, dims.first, dims.second, 0, ids_stride,
+ stream);
} else {
- std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst,
- warp_size, table_id);
+ std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst,
+ nsamples_dst, warp_size, table_id);
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(
vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
- channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
- sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
- dims.first, dims.second, 0, ids_stride, stream);
+ channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, sample_ratio_fd,
+ stride_sample_x, stride_sample_y, stride_sample_dst, dims.first, dims.second, 0, ids_stride,
+ stream);
}
} break;
case 2: {