}
}
-static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
+static constexpr __host__ __device__ int get_vdr_mmvq(ggml_type type) {
switch (type) {
case GGML_TYPE_Q4_0: return VDR_Q4_0_Q8_1_MMVQ;
case GGML_TYPE_Q4_1: return VDR_Q4_1_Q8_1_MMVQ;
return 1;
}
-static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int table_id) {
+static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int table_id, bool small_k = false, int nwarps = 1) {
if (table_id == MMVQ_PARAMETERS_GENERIC || table_id == MMVQ_PARAMETERS_GCN) {
switch (ncols_dst) {
case 1:
- return 1;
+ return small_k ? nwarps : 1;
case 2:
case 3:
case 4:
return 1;
}
-template <ggml_type type, int ncols_dst, bool has_fusion, bool is_multi_token_id = false>
+template <ggml_type type, int ncols_dst, bool has_fusion, bool is_multi_token_id = false, bool small_k = false>
__launch_bounds__(calc_nwarps(type, ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
static __global__ void mul_mat_vec_q(
const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
constexpr int vdr = get_vdr_mmvq(type);
constexpr mmvq_parameter_table_id table_id = get_device_table_id();
constexpr int nwarps = calc_nwarps(type, ncols_dst, table_id);
- constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id);
+ constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id, small_k, nwarps);
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
template<ggml_type type>
static std::pair<dim3, dim3> calc_launch_params(
const int ncols_dst, const int nrows_x, const int nchannels_dst, const int nsamples_or_ntokens,
- const int warp_size, const mmvq_parameter_table_id table_id) {
- const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_dst, table_id) - 1) / calc_rows_per_block(ncols_dst, table_id);
+ const int warp_size, const mmvq_parameter_table_id table_id, const bool small_k = false) {
+ const int nwarps = calc_nwarps(type, ncols_dst, table_id);
+ const int rpb = calc_rows_per_block(ncols_dst, table_id, small_k, nwarps);
+ const int64_t nblocks = (nrows_x + rpb - 1) / rpb;
const dim3 block_nums(nblocks, nchannels_dst, nsamples_or_ntokens);
- const dim3 block_dims(warp_size, calc_nwarps(type, ncols_dst, table_id), 1);
+ const dim3 block_dims(warp_size, nwarps, 1);
return {block_nums, block_dims};
}
-template<ggml_type type, int c_ncols_dst, bool is_multi_token_id = false>
+template<ggml_type type, int c_ncols_dst, bool is_multi_token_id = false, bool small_k = false>
static void mul_mat_vec_q_switch_fusion(
const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
if constexpr (c_ncols_dst == 1) {
if (has_fusion) {
- mul_mat_vec_q<type, c_ncols_dst, true, is_multi_token_id><<<block_nums, block_dims, nbytes_shared, stream>>>
+ mul_mat_vec_q<type, c_ncols_dst, true, is_multi_token_id, small_k><<<block_nums, block_dims, nbytes_shared, stream>>>
(vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1");
- mul_mat_vec_q<type, c_ncols_dst, false, is_multi_token_id><<<block_nums, block_dims, nbytes_shared, stream>>>
+ mul_mat_vec_q<type, c_ncols_dst, false, is_multi_token_id, small_k><<<block_nums, block_dims, nbytes_shared, stream>>>
(vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
switch (ncols_dst) {
case 1: {
constexpr int c_ncols_dst = 1;
- std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
- mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
- channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
- sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
- dims.first, dims.second, 0, ids_stride, stream);
+
+ // When K is small, increase rows_per_block to match nwarps so each warp has more work to do
+ // Trigger when the full thread block covers all K blocks in a single loop iteration and few threads remain idle.
+ constexpr int qk = ggml_cuda_type_traits<type>::qk;
+ constexpr int qi = ggml_cuda_type_traits<type>::qi;
+ constexpr int vdr = get_vdr_mmvq(type);
+ const int blocks_per_row_x = ncols_x / qk;
+ const int blocks_per_iter_1warp = vdr * warp_size / qi;
+ const int nwarps = calc_nwarps(type, c_ncols_dst, table_id);
+ const bool use_small_k = nwarps > 1 && blocks_per_row_x < nwarps * blocks_per_iter_1warp;
+ if (use_small_k) {
+ std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst,
+ warp_size, table_id, true);
+ mul_mat_vec_q_switch_fusion<type, c_ncols_dst, false, true>(
+ vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
+ channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
+ dims.first, dims.second, 0, ids_stride, stream);
+ } else {
+ std::pair<dim3, dim3> dims = calc_launch_params<type>(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst,
+ warp_size, table_id);
+ mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(
+ vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
+ channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
+ dims.first, dims.second, 0, ids_stride, stream);
+ }
} break;
case 2: {
constexpr int c_ncols_dst = 2;