#include "mmvf.cuh"
#include "convert.cuh"
-template <typename T, typename type_acc, int ncols_dst, int block_size, bool has_fusion = false>
+template <typename T, typename type_acc, int ncols_dst, int block_size, bool has_fusion = false, bool is_multi_token_id = false>
static __global__ void mul_mat_vec_f(
const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
- const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
+ const int ncols2, const uint3 nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
- const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
+ const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
+ const int ids_stride) {
const int row = blockIdx.x;
+ // for MUL_MAT_ID - blockIdx.y = n_expert_used, blockIdx.z = ncols_dst (tokens)
const int channel_dst = blockIdx.y;
- const int channel_x = ids ? ids[channel_dst] : fastdiv((uint32_t) channel_dst, channel_ratio);
- const int channel_y = ids ? channel_dst % nchannels_y : channel_dst;
- const int sample_dst = blockIdx.z;
+ const int tid = threadIdx.x;
+
+ int token_idx;
+ int channel_x;
+ int channel_y;
+ int sample_dst;
+
+ if constexpr (is_multi_token_id) {
+ // Multi-token MUL_MAT_ID path, adding these in the normal path causes a perf regression for n_tokens=1 case
+ token_idx = blockIdx.z;
+ channel_x = ids[channel_dst + token_idx * ids_stride];
+ channel_y = fastmodulo(channel_dst, nchannels_y);
+ sample_dst = 0;
+ } else {
+ token_idx = ids ? blockIdx.z : 0;
+ channel_x = ids ? ids[blockIdx.y + token_idx * ids_stride] : fastdiv((uint32_t) channel_dst, channel_ratio);
+ channel_y = ids ? fastmodulo(blockIdx.y, nchannels_y) : channel_dst;
+ sample_dst = ids ? 0 : blockIdx.z;
+ }
+
const int sample_x = fastdiv((uint32_t) sample_dst, sample_ratio);
const int sample_y = sample_dst;
- const int tid = threadIdx.x;
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row*stride_row;
y += int64_t(sample_y) *stride_sample_y + channel_y *stride_channel_y;
dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst;
+ if constexpr (is_multi_token_id) {
+ y += token_idx*stride_col_y2*2;
+ dst += token_idx*stride_col_dst;
+ }
bool use_gate = false;
bool use_bias = false;
if (use_gate) {
gate_x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row*stride_row;
}
+
+ const int channel_bias = ids ? channel_x : channel_dst;
+
if constexpr (has_fusion) {
- const int channel_bias = ids ? channel_x : channel_dst;
if (use_bias) {
x_bias += int64_t(sample_dst)*stride_sample_dst + channel_bias*stride_channel_dst;
}
}
}
-template<typename T, typename type_acc, int ncols_dst, int block_size>
+template<typename T, typename type_acc, int ncols_dst, int block_size, bool is_multi_token_id = false>
static void mul_mat_vec_f_switch_fusion(
const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
- const int64_t ncols, const int64_t nrows,
+ const int64_t ncols, const uint3 nchannels_y,
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
- const dim3 & block_dims, const dim3 & block_nums, const int nbytes_shared, const cudaStream_t stream) {
+ const dim3 & block_dims, const dim3 & block_nums, const int nbytes_shared, const int ids_stride, const cudaStream_t stream) {
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
if constexpr (ncols_dst == 1) {
if (has_fusion) {
- mul_mat_vec_f<T, type_acc, ncols_dst, block_size, true><<<block_nums, block_dims, nbytes_shared, stream>>>
- (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+ mul_mat_vec_f<T, type_acc, ncols_dst, block_size, true, is_multi_token_id><<<block_nums, block_dims, nbytes_shared, stream>>>
+ (x, y, ids, fusion, dst, ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
return;
}
}
GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1");
- mul_mat_vec_f<T, type_acc, ncols_dst, block_size><<<block_nums, block_dims, nbytes_shared, stream>>>
- (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+ mul_mat_vec_f<T, type_acc, ncols_dst, block_size, false, is_multi_token_id><<<block_nums, block_dims, nbytes_shared, stream>>>
+ (x, y, ids, fusion, dst, ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
}
-template <typename T, typename type_acc, int ncols_dst>
+template <typename T, typename type_acc, int ncols_dst, bool is_multi_token_id = false>
void launch_mul_mat_vec_f_cuda(
const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
const int64_t ncols, const int64_t nrows,
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
- cudaStream_t stream) {
+ const int64_t nsamples_or_ntokens, const int64_t ids_stride, cudaStream_t stream) {
GGML_ASSERT(ncols % 2 == 0);
GGML_ASSERT(stride_row % 2 == 0);
GGML_ASSERT(stride_col_y % 2 == 0);
GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
GGML_ASSERT( nsamples_dst % nsamples_x == 0);
+ const uint3 nchannels_y_fd = ids ? init_fastdiv_values(nchannels_y) : make_uint3(0, 0, 0);
const uint3 channel_ratio_fd = ids ? make_uint3(0, 0, 0) : init_fastdiv_values(nchannels_dst / nchannels_x);
const uint3 sample_ratio_fd = init_fastdiv_values(nsamples_dst / nsamples_x);
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
const int nbytes_shared = warp_size*sizeof(float) + (has_fusion ? warp_size*sizeof(float) : 0);
- const dim3 block_nums(nrows, nchannels_dst, nsamples_dst);
+ const dim3 block_nums(nrows, nchannels_dst, nsamples_or_ntokens);
const dim3 block_dims(block_size_best, 1, 1);
switch (block_size_best) {
case 32: {
- mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 32>
- (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+ mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 32, is_multi_token_id>
+ (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
- sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
} break;
case 64: {
- mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 64>
- (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+ mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 64, is_multi_token_id>
+ (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
- sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
} break;
case 96: {
- mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 96>
- (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+ mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 96, is_multi_token_id>
+ (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
- sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
} break;
case 128: {
- mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 128>
- (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+ mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 128, is_multi_token_id>
+ (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
- sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
} break;
case 160: {
- mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 160>
- (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+ mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 160, is_multi_token_id>
+ (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
- sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
} break;
case 192: {
- mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 192>
- (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+ mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 192, is_multi_token_id>
+ (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
- sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
} break;
case 224: {
- mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 224>
- (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+ mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 224, is_multi_token_id>
+ (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
- sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
} break;
case 256: {
- mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 256>
- (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+ mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 256, is_multi_token_id>
+ (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
- sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
} break;
default: {
GGML_ABORT("fatal error");
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
- cudaStream_t stream) {
+ const int64_t ids_stride, cudaStream_t stream) {
+
+ const bool has_ids = ids != nullptr;
+
+ if (has_ids && ncols_dst > 1) {
+ // Multi-token MUL_MAT_ID path only - single-token goes through regular path below
+ constexpr int c_ncols_dst = 1;
+ launch_mul_mat_vec_f_cuda<T, type_acc, c_ncols_dst, true>
+ (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+ ncols_dst, ids_stride, stream);
+ return;
+ }
+
+ if (has_ids) {
+ // Single-token MUL_MAT_ID path
+ constexpr int c_ncols_dst = 1;
+ launch_mul_mat_vec_f_cuda<T, type_acc, c_ncols_dst>
+ (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+ ncols_dst, ids_stride, stream);
+ return;
+ }
+
switch (ncols_dst) {
case 1:
launch_mul_mat_vec_f_cuda<T, type_acc, 1>
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
- stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+ nsamples_dst, ids_stride, stream);
break;
case 2:
launch_mul_mat_vec_f_cuda<T, type_acc, 2>
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
- stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+ nsamples_dst, ids_stride, stream);
break;
case 3:
launch_mul_mat_vec_f_cuda<T, type_acc, 3>
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
- stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+ nsamples_dst, ids_stride, stream);
break;
case 4:
launch_mul_mat_vec_f_cuda<T, type_acc, 4>
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
- stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+ nsamples_dst, ids_stride, stream);
break;
case 5:
launch_mul_mat_vec_f_cuda<T, type_acc, 5>
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
- stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+ nsamples_dst, ids_stride, stream);
break;
case 6:
launch_mul_mat_vec_f_cuda<T, type_acc, 6>
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
- stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+ nsamples_dst, ids_stride, stream);
break;
case 7:
launch_mul_mat_vec_f_cuda<T, type_acc, 7>
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
- stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+ nsamples_dst, ids_stride, stream);
break;
case 8:
launch_mul_mat_vec_f_cuda<T, type_acc, 8>
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
- stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+ nsamples_dst, ids_stride, stream);
break;
default:
GGML_ABORT("fatal error");
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
- enum ggml_prec prec, cudaStream_t stream) {
+ const int64_t ids_stride, enum ggml_prec prec, cudaStream_t stream) {
if constexpr(std::is_same_v<T, half>) {
if (prec == GGML_PREC_DEFAULT) {
mul_mat_vec_f_cuda_switch_ncols_dst<T, half>
(x, y, ids, fusion, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
- stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
return;
}
}
mul_mat_vec_f_cuda_switch_ncols_dst<T, float>
(x, y, ids, fusion, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
- stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
}
void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst,
const size_t ts_src1 = ggml_type_size(src1->type);
const size_t ts_dst = ggml_type_size(dst->type);
- GGML_ASSERT(!ids || ne12 == 1); // Implementation is only correct for batch size 1.
+ GGML_ASSERT(!ids || ne12 <= MMVF_MAX_BATCH_SIZE);
GGML_ASSERT(ne13 == ne3);
GGML_ASSERT( nb00 == ts_src0);
const int64_t ncols_dst = ids ? ne2 : ne1;
const int64_t nchannels_y = ids ? ne11 : ne12;
const int64_t nchannels_dst = ids ? ne1 : ne2;
+ const int64_t stride_col_dst = ids ? s2 : s1;
+ const int64_t stride_col_y = ids ? s12 : s11;
const int64_t stride_channel_dst = ids ? s1 : s2;
const int64_t stride_channel_y = ids ? s11 : s12;
- GGML_ASSERT(!ids || ncols_dst == 1);
+ const int64_t ids_stride = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0;
switch (src0->type) {
case GGML_TYPE_F32: {
const float * src0_d = (const float *) src0->data;
- mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
+ mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, stride_col_y, stride_col_dst,
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
- ne03, ne3, s03, s13, s3, prec, ctx.stream());
+ ne03, ne3, s03, s13, s3, ids_stride, prec, ctx.stream());
} break;
case GGML_TYPE_F16: {
const half * src0_d = (const half *) src0->data;
- mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
+ mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, stride_col_y, stride_col_dst,
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
- ne03, ne3, s03, s13, s3, prec, ctx.stream());
+ ne03, ne3, s03, s13, s3, ids_stride, prec, ctx.stream());
} break;
case GGML_TYPE_BF16: {
const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
- mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
+ mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, stride_col_y, stride_col_dst,
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
- ne03, ne3, s03, s13, s3, prec, ctx.stream());
+ ne03, ne3, s03, s13, s3, ids_stride, prec, ctx.stream());
} break;
default:
GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
const float * src0_d = (const float *) src0_dd_i;
mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, 0, prec, stream);
} break;
case GGML_TYPE_F16: {
const half * src0_d = (const half *) src0_dd_i;
mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, 0, prec, stream);
} break;
case GGML_TYPE_BF16: {
const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, 0, prec, stream);
} break;
default:
GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
return 1;
}
-// tell the compiler to use as many registers as it wants, see nwarps definition below
-template <ggml_type type, int ncols_dst, bool has_fusion>
+template <ggml_type type, int ncols_dst, bool has_fusion, bool is_multi_token_id = false>
__launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
static __global__ void mul_mat_vec_q(
const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x,
const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio,
- const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst) {
+ const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst,
+ const uint32_t ids_stride) {
constexpr int qk = ggml_cuda_type_traits<type>::qk;
constexpr int qi = ggml_cuda_type_traits<type>::qi;
const int blocks_per_row_x = ncols_x / qk;
constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;
- // The MUL_MAT_ID code path with ids != nullptr is only implemented for ncols_dst == 1.
const uint32_t channel_dst = blockIdx.y;
- const uint32_t channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv(channel_dst, channel_ratio);
- const uint32_t channel_y = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst;
- const uint32_t sample_dst = blockIdx.z;
+
+ uint32_t token_idx = 0;
+ uint32_t channel_x;
+ uint32_t channel_y;
+ uint32_t sample_dst;
+
+ if constexpr (is_multi_token_id) {
+ // Multi-token MUL_MAT_ID path, adding these in the normal path causes a perf regression for n_tokens=1 case
+ token_idx = blockIdx.z;
+ channel_x = ids[channel_dst + token_idx * ids_stride];
+ channel_y = fastmodulo(channel_dst, nchannels_y);
+ sample_dst = 0;
+ } else {
+ channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv(channel_dst, channel_ratio);
+ channel_y = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst;
+ sample_dst = blockIdx.z;
+ }
+
const uint32_t sample_x = fastdiv(sample_dst, sample_ratio);
const uint32_t sample_y = sample_dst;
active_glu = fusion.glu_op;
}
- const uint32_t channel_bias = ids ? channel_x : channel_dst;
float x_biases[ncols_dst] = { 0.0f };
float gate_biases[ncols_dst] = { 0.0f };
if constexpr (has_fusion) {
+ const uint32_t channel_bias = ids ? channel_x : channel_dst;
if (use_bias) {
x_bias = x_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0;
// 1. Hide latency by prefetching bias and gate here
float tmp_gate[ncols_dst][rows_per_cuda_block] = {{0.0f}};
const block_q8_1 * y = ((const block_q8_1 *) vy) + sample_y*stride_sample_y + channel_y*stride_channel_y;
+ if constexpr (is_multi_token_id) {
+ y += token_idx*stride_col_y;
+ }
const int kbx_offset = sample_x*stride_sample_x + channel_x*stride_channel_x + row0*stride_row_x;
for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) {
dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst + row0;
+ if constexpr (is_multi_token_id) {
+ dst += token_idx*stride_col_dst;
+ }
+
// sum up partial sums and write back result
#pragma unroll
for (int j = 0; j < ncols_dst; ++j) {
}
static std::pair<dim3, dim3> calc_launch_params(
- const int ncols_dst, const int nrows_x, const int nchannels_y, const int nsamples_y,
+ const int ncols_dst, const int nrows_x, const int nchannels_dst, const int nsamples_or_ntokens,
const int warp_size, const mmvq_parameter_table_id table_id) {
const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_dst, table_id) - 1) / calc_rows_per_block(ncols_dst, table_id);
- const dim3 block_nums(nblocks, nchannels_y, nsamples_y);
+ const dim3 block_nums(nblocks, nchannels_dst, nsamples_or_ntokens);
const dim3 block_dims(warp_size, calc_nwarps(ncols_dst, table_id), 1);
return {block_nums, block_dims};
}
-template<ggml_type type, int c_ncols_dst>
+template<ggml_type type, int c_ncols_dst, bool is_multi_token_id = false>
static void mul_mat_vec_q_switch_fusion(
const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x,
const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio,
const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst,
- const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared, cudaStream_t stream) {
+ const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared,
+ const uint32_t ids_stride, cudaStream_t stream) {
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
if constexpr (c_ncols_dst == 1) {
if (has_fusion) {
- mul_mat_vec_q<type, c_ncols_dst, true><<<block_nums, block_dims, nbytes_shared, stream>>>
+ mul_mat_vec_q<type, c_ncols_dst, true, is_multi_token_id><<<block_nums, block_dims, nbytes_shared, stream>>>
(vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
return;
}
}
GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1");
- mul_mat_vec_q<type, c_ncols_dst, false><<<block_nums, block_dims, nbytes_shared, stream>>>
+ mul_mat_vec_q<type, c_ncols_dst, false, is_multi_token_id><<<block_nums, block_dims, nbytes_shared, stream>>>
(vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
}
template <ggml_type type>
const int nchannels_x, const int nchannels_y, const int nchannels_dst,
const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
- cudaStream_t stream) {
+ const int ids_stride, cudaStream_t stream) {
GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
GGML_ASSERT(ncols_dst <= MMVQ_MAX_BATCH_SIZE);
const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc);
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
+ const bool has_ids = ids != nullptr;
+
+ if (has_ids && ncols_dst > 1) {
+ // Multi-token MUL_MAT_ID path only - single-token goes through regular path below
+ constexpr int c_ncols_dst = 1;
+ std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, ncols_dst, warp_size, table_id);
+ mul_mat_vec_q_switch_fusion<type, c_ncols_dst, true>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
+ channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
+ dims.first, dims.second, 0, ids_stride, stream);
+ return;
+ }
- GGML_ASSERT(!ids || ncols_dst == 1);
switch (ncols_dst) {
case 1: {
constexpr int c_ncols_dst = 1;
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
- dims.first, dims.second, 0, stream);
+ dims.first, dims.second, 0, ids_stride, stream);
} break;
case 2: {
constexpr int c_ncols_dst = 2;
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
- dims.first, dims.second, 0, stream);
+ dims.first, dims.second, 0, ids_stride, stream);
} break;
case 3: {
constexpr int c_ncols_dst = 3;
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
- dims.first, dims.second, 0, stream);
+ dims.first, dims.second, 0, ids_stride, stream);
} break;
case 4: {
constexpr int c_ncols_dst = 4;
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
- dims.first, dims.second, 0, stream);
+ dims.first, dims.second, 0, ids_stride, stream);
} break;
case 5: {
constexpr int c_ncols_dst = 5;
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
- dims.first, dims.second, 0, stream);
+ dims.first, dims.second, 0, ids_stride, stream);
} break;
case 6: {
constexpr int c_ncols_dst = 6;
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
- dims.first, dims.second, 0, stream);
+ dims.first, dims.second, 0, ids_stride, stream);
} break;
case 7: {
constexpr int c_ncols_dst = 7;
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
- dims.first, dims.second, 0, stream);
+ dims.first, dims.second, 0, ids_stride, stream);
} break;
case 8: {
constexpr int c_ncols_dst = 8;
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
- dims.first, dims.second, 0, stream);
+ dims.first, dims.second, 0, ids_stride, stream);
} break;
default:
GGML_ABORT("fatal error");
const int nchannels_x, const int nchannels_y, const int nchannels_dst,
const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
- cudaStream_t stream) {
+ const int ids_stride, cudaStream_t stream) {
switch (type_x) {
case GGML_TYPE_Q4_0:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_0>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break;
case GGML_TYPE_Q4_1:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_1>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break;
case GGML_TYPE_Q5_0:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_0>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break;
case GGML_TYPE_Q5_1:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_1>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break;
case GGML_TYPE_Q8_0:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q8_0>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break;
case GGML_TYPE_MXFP4:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_MXFP4>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break;
case GGML_TYPE_Q2_K:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q2_K>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break;
case GGML_TYPE_Q3_K:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q3_K>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break;
case GGML_TYPE_Q4_K:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_K>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break;
case GGML_TYPE_Q5_K:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_K>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break;
case GGML_TYPE_Q6_K:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q6_K>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break;
case GGML_TYPE_IQ2_XXS:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_XXS>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break;
case GGML_TYPE_IQ2_XS:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_XS>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break;
case GGML_TYPE_IQ2_S:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_S>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break;
case GGML_TYPE_IQ3_XXS:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ3_XXS>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break;
case GGML_TYPE_IQ1_S:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ1_S>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break;
case GGML_TYPE_IQ1_M:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ1_M>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break;
case GGML_TYPE_IQ4_NL:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ4_NL>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break;
case GGML_TYPE_IQ4_XS:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ4_XS>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break;
case GGML_TYPE_IQ3_S:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ3_S>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break;
default:
GGML_ABORT("fatal error");
GGML_ASSERT( nb0 == ts_dst);
GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));
- GGML_ASSERT(!ids || ne12 == 1); // Implementation is only correct for batch size 1.
+ GGML_ASSERT(!ids || ne12 <= MMVQ_MAX_BATCH_SIZE);
const float * src1_d = (const float *) src1->data;
const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr;
const int64_t stride_channel_dst = ids ? s1 : s2;
const int64_t stride_channel_y = ids ? s11 : s12;
+ const int64_t ids_stride = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0;
+
mul_mat_vec_q_switch_type(
src0->data, src0->type, src1_q8_1.get(), ids_d, fusion_local, dst_d, ne00,
ne01, ncols_dst, s01, stride_col_y, stride_col_dst,
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
- ne03, ne3, s03, s13, s3, stream);
+ ne03, ne3, s03, s13, s3, ids_stride, stream);
}
void ggml_cuda_op_mul_mat_vec_q(
ggml_cuda_mm_fusion_args_device fusion_local{};
mul_mat_vec_q_switch_type(
src0_dd_i, src0->type, src1_ddq_i, nullptr, fusion_local, dst_dd_i, ne00, row_diff, src1_ncols, stride_row_x, stride_col_y, nrows_dst,
- 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, stream);
+ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, stream);
GGML_UNUSED_VARS(src1, dst, src1_ddf_i, src1_ncols, src1_padded_row_size);
}