template <typename T, typename type_acc, int block_size>
static __global__ void mul_mat_vec(
- const T * __restrict__ x, const float * __restrict__ y, float * __restrict__ dst, const int64_t ncols2, const int64_t stride_row,
+ const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
+ const int64_t ncols2, const int64_t nchannels_y, const int64_t stride_row,
const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst) {
- const int64_t row = blockIdx.x;
- const int64_t channel = blockIdx.y;
- const int64_t sample = blockIdx.z;
- const int tid = threadIdx.x;
- constexpr int warp_size = ggml_cuda_get_physical_warp_size();
-
- x += (sample/sample_ratio)*stride_sample_x + (channel/channel_ratio)*stride_channel_x + row*stride_row;
- y += sample *stride_sample_y + channel *stride_channel_y;
- dst += sample *stride_sample_dst + channel *stride_channel_dst;
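+ // With ids set (MUL_MAT_ID), each dst channel gathers its src0 channel from ids[channel_dst]
+ // and its src1 channel from channel_dst % nchannels_y; without ids, channels are broadcast by
+ // channel_ratio. Samples are broadcast by sample_ratio in either case.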
+ const int64_t row = blockIdx.x;
+ const int64_t channel_dst = blockIdx.y;
+ const int64_t channel_x = ids ? ids[channel_dst] : channel_dst / channel_ratio;
+ const int64_t channel_y = ids ? channel_dst % nchannels_y : channel_dst;
+ const int64_t sample_dst = blockIdx.z;
+ const int64_t sample_x = sample_dst / sample_ratio;
+ const int64_t sample_y = sample_dst;
+ const int tid = threadIdx.x;
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+ x += sample_x *stride_sample_x + channel_x *stride_channel_x + row*stride_row;
+ y += sample_y *stride_sample_y + channel_y *stride_channel_y;
+ dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst;
const float2 * y2 = (const float2 *) y;
float sumf = 0.0f;
- if constexpr (std::is_same<T, half>::value) {
+ if constexpr (std::is_same<T, float>::value) {
+ const float2 * x2 = (const float2 *) x;
+
+ for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
+ const float2 tmpx = x2[col2];
+ const float2 tmpy = y2[col2];
+ sumf += tmpx.x*tmpy.x;
+ sumf += tmpx.y*tmpy.y;
+ }
+ } else if constexpr (std::is_same<T, half>::value) {
const half2 * x2 = (const half2 *) x;
if (std::is_same<type_acc, float>::value) {
- sumf = 0.0f;
-
for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
const float2 tmpx = __half22float2(x2[col2]);
const float2 tmpy = y2[col2];
}
} else if constexpr (std::is_same<T, nv_bfloat16>::value) {
const int * x2 = (const int *) x;
- sumf = 0.0f;
-
for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
const int tmpx = x2[col2];
const float2 tmpy = y2[col2];
template <typename T, typename type_acc>
static void launch_mul_mat_vec_cuda(
- const T * x, const float * y, float * dst,
- const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
+ const T * x, const float * y, const int32_t * ids, float * dst,
+ const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
- const int64_t nsamples_y, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
+ const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
cudaStream_t stream) {
GGML_ASSERT(ncols % 2 == 0);
GGML_ASSERT(stride_row % 2 == 0);
- GGML_ASSERT(nchannels_y % nchannels_x == 0);
- GGML_ASSERT(nsamples_y % nsamples_x == 0);
- const int64_t channel_ratio = nchannels_y / nchannels_x;
- const int64_t sample_ratio = nsamples_y / nsamples_x;
+ GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
+ GGML_ASSERT( nsamples_dst % nsamples_x == 0);
+ const int64_t channel_ratio = nchannels_dst / nchannels_x;
+ const int64_t sample_ratio = nsamples_dst / nsamples_x;
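+ // src0 is broadcast along channels/samples: one x channel serves channel_ratio dst channels and
+ // one x sample serves sample_ratio dst samples. When ids is set the channel mapping is taken
+ // from ids instead, hence the relaxed channel assert above.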
int device;
int warp_size;
}
const int smem = warp_size*sizeof(float);
- const dim3 block_nums(nrows, nchannels_y, nsamples_y);
+ const dim3 block_nums(nrows, nchannels_dst, nsamples_dst);
const dim3 block_dims(block_size_best, 1, 1);
switch (block_size_best) {
case 32: {
mul_mat_vec<T, type_acc, 32><<<block_nums, block_dims, smem, stream>>>
- (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+ (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
+ stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 64: {
mul_mat_vec<T, type_acc, 64><<<block_nums, block_dims, smem, stream>>>
- (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+ (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
+ stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 96: {
mul_mat_vec<T, type_acc, 96><<<block_nums, block_dims, smem, stream>>>
- (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+ (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
+ stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 128: {
mul_mat_vec<T, type_acc, 128><<<block_nums, block_dims, smem, stream>>>
- (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+ (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
+ stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 160: {
mul_mat_vec<T, type_acc, 160><<<block_nums, block_dims, smem, stream>>>
- (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+ (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
+ stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 192: {
mul_mat_vec<T, type_acc, 192><<<block_nums, block_dims, smem, stream>>>
- (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+ (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
+ stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 224: {
mul_mat_vec<T, type_acc, 224><<<block_nums, block_dims, smem, stream>>>
- (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+ (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
+ stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 256: {
mul_mat_vec<T, type_acc, 256><<<block_nums, block_dims, smem, stream>>>
- (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+ (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
+ stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
default: {
GGML_ABORT("fatal error");
template<typename T>
static void mul_mat_vec_cuda(
- const T * x, const float * y, float * dst,
- const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
+ const T * x, const float * y, const int32_t * ids, float * dst,
+ const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
- const int64_t nsamples_y, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
+ const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
enum ggml_prec prec, cudaStream_t stream) {
- switch (prec) {
- case GGML_PREC_DEFAULT: {
+ if constexpr(std::is_same<T, half>::value) {
+ if (prec == GGML_PREC_DEFAULT) {
launch_mul_mat_vec_cuda<T, half>
- (x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
- } break;
- case GGML_PREC_F32: {
- launch_mul_mat_vec_cuda<T, float>
- (x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
- } break;
+ (x, y, ids, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ return;
+ }
}
+ launch_mul_mat_vec_cuda<T, float>
+ (x, y, ids, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
}
-void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
- GGML_ASSERT(src1->type == GGML_TYPE_F32);
- GGML_ASSERT(dst->type == GGML_TYPE_F32);
+void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
+ GGML_ASSERT( src1->type == GGML_TYPE_F32);
+ GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_TENSOR_BINARY_OP_LOCALS;
const size_t ts_src1 = ggml_type_size(src1->type);
const size_t ts_dst = ggml_type_size(dst->type);
- GGML_ASSERT(ne11 == 1);
- GGML_ASSERT(ne12 == ne2);
+ GGML_ASSERT(!ids || ne12 == 1); // Implementation is only correct for batch size 1.
GGML_ASSERT(ne13 == ne3);
- GGML_ASSERT(nb00 == ts_src0);
- GGML_ASSERT(nb10 == ts_src1);
- GGML_ASSERT(nb0 == ts_dst);
+ GGML_ASSERT( nb00 == ts_src0);
+ GGML_ASSERT( nb10 == ts_src1);
+ GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));
+ GGML_ASSERT( nb0 == ts_dst);
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
- const float * src1_d = (const float *) src1->data;
- float * dst_d = (float *) dst->data;
+ const float * src1_d = (const float *) src1->data;
+ const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr;
+ float * dst_d = (float *) dst->data;
const int64_t s01 = src0->nb[1] / ts_src0;
+ const int64_t s11 = src1->nb[1] / ts_src1;
+ const int64_t s1 = dst->nb[1] / ts_dst;
const int64_t s02 = src0->nb[2] / ts_src0;
const int64_t s12 = src1->nb[2] / ts_src1;
const int64_t s2 = dst->nb[2] / ts_dst;
const int64_t s13 = src1->nb[3] / ts_src1;
const int64_t s3 = dst->nb[3] / ts_dst;
+ // For MUL_MAT_ID the memory layout is different from that of MUL_MAT:
+ const int64_t ncols_dst = ids ? ne2 : ne1;
+ const int64_t nchannels_y = ids ? ne11 : ne12;
+ const int64_t nchannels_dst = ids ? ne1 : ne2;
+ const int64_t stride_channel_dst = ids ? s1 : s2;
+ const int64_t stride_channel_y = ids ? s11 : s12;
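+ // i.e. when ids is set, the channel dimension of src1 and dst is dim 1 (strides s11/s1) instead
+ // of dim 2 (strides s12/s2), and ncols_dst is taken from dim 2 (asserted to be 1 below).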
+
+ GGML_ASSERT(ncols_dst == 1);
+
switch (src0->type) {
+ case GGML_TYPE_F32: {
+ const float * src0_d = (const float *) src0->data;
+ mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, s01,
+ ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
+ ne03, ne3, s03, s13, s3, prec, ctx.stream());
+ } break;
case GGML_TYPE_F16: {
const half * src0_d = (const half *) src0->data;
- mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, s01, ne02, ne12, s02, s12, s2, ne03, ne13, s03, s13, s3, prec, ctx.stream());
+ mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, s01,
+ ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
+ ne03, ne3, s03, s13, s3, prec, ctx.stream());
} break;
case GGML_TYPE_BF16: {
const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
- mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, s01, ne02, ne12, s02, s12, s2, ne03, ne13, s03, s13, s3, prec, ctx.stream());
+ mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, s01,
+ ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
+ ne03, ne3, s03, s13, s3, prec, ctx.stream());
} break;
default:
GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
const int64_t stride_row = ne00;
const int64_t nchannels_x = 1;
const int64_t nchannels_y = 1;
+ const int64_t nchannels_dst = 1;
const int64_t stride_channel_x = 0;
const int64_t stride_channel_y = 0;
const int64_t stride_channel_dst = 0;
const int64_t nsamples_x = 1;
- const int64_t nsamples_y = 1;
+ const int64_t nsamples_dst = 1;
const int64_t stride_sample_x = 0;
const int64_t stride_sample_y = 0;
const int64_t stride_sample_dst = 0;
switch (src0->type) {
+ case GGML_TYPE_F32: {
+ const float * src0_d = (const float *) src0_dd_i;
+ mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, stride_row,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
+ } break;
case GGML_TYPE_F16: {
const half * src0_d = (const half *) src0_dd_i;
- mul_mat_vec_cuda(src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
- nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
+ mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, stride_row,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
} break;
case GGML_TYPE_BF16: {
const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
- mul_mat_vec_cuda(src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
- nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
+ mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, stride_row,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
} break;
default:
GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
#include "mmvq.cuh"
+#include "quantize.cuh"
#include "vecdotq.cuh"
+#include <cstdint>
+
typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs);
static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) {
return MMVQ_PARAMETERS_GENERIC;
}
-static constexpr __host__ __device__ int calc_nwarps(int ncols_y, mmvq_parameter_table_id table_id) {
+static constexpr __host__ __device__ int calc_nwarps(int ncols_dst, mmvq_parameter_table_id table_id) {
if (table_id == MMVQ_PARAMETERS_GENERIC) {
- switch (ncols_y) {
+ switch (ncols_dst) {
case 1:
case 2:
case 3:
return 1;
}
} else if (table_id == MMVQ_PARAMETERS_GCN) {
- switch (ncols_y) {
+ switch (ncols_dst) {
case 1:
case 2:
case 3:
return 1;
}
-static constexpr __host__ __device__ int calc_rows_per_block(int ncols_y, int table_id) {
+static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int table_id) {
if (table_id == MMVQ_PARAMETERS_GENERIC || table_id == MMVQ_PARAMETERS_GCN) {
- switch (ncols_y) {
+ switch (ncols_dst) {
case 1:
return 1;
case 2:
return 1;
}
-template <ggml_type type, int ncols_y>
+template <ggml_type type, int ncols_dst>
// tell the compiler to use as many registers as it wants, see nwarps definition below
-__launch_bounds__(calc_nwarps(ncols_y, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
+__launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
static __global__ void mul_mat_vec_q(
- const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
- const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
+ const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, float * __restrict__ dst,
+ const int ncols_x, const int nchannels_y, const int stride_row_x, const int stride_col_y, const int stride_col_dst,
+ const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
+ const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
constexpr int qk = ggml_cuda_type_traits<type>::qk;
constexpr int qi = ggml_cuda_type_traits<type>::qi;
constexpr int vdr = get_vdr_mmvq(type);
constexpr mmvq_parameter_table_id table_id = get_device_table_id();
- constexpr int nwarps = calc_nwarps(ncols_y, table_id);
- constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_y, table_id);
+ constexpr int nwarps = calc_nwarps(ncols_dst, table_id);
+ constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_dst, table_id);
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
const int tid = warp_size*threadIdx.y + threadIdx.x;
const int row0 = rows_per_cuda_block*blockIdx.x;
const int blocks_per_row_x = ncols_x / qk;
- const int blocks_per_col_y = nrows_y / QK8_1;
constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;
+ // The MUL_MAT_ID code path with ids != nullptr is only implemented for ncols_dst == 1.
+ const int channel_dst = blockIdx.y;
+ const int channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : channel_dst / channel_ratio;
+ const int channel_y = ncols_dst == 1 && ids ? channel_dst % nchannels_y : channel_dst;
+ const int sample_dst = blockIdx.z;
+ const int sample_x = sample_dst / sample_ratio;
+ const int sample_y = sample_dst;
+
// partial sum for each thread
- float tmp[ncols_y][rows_per_cuda_block] = {{0.0f}};
+ float tmp[ncols_dst][rows_per_cuda_block] = {{0.0f}};
- const block_q8_1 * y = (const block_q8_1 *) vy;
+ const block_q8_1 * y = ((const block_q8_1 *) vy) + sample_y*stride_sample_y + channel_y*stride_channel_y;
+ const int kbx_offset = sample_x*stride_sample_x + channel_x*stride_channel_x + row0*stride_row_x;
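+ // All x strides here are in units of src0 blocks, so kbx_offset is a block index into vx as
+ // expected by vec_dot_q_cuda.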
for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) {
const int kby = kbx * (qk/QK8_1); // y block index that aligns with kbx
const int kqs = vdr * (tid % (qi/vdr));
#pragma unroll
- for (int j = 0; j < ncols_y; ++j) {
+ for (int j = 0; j < ncols_dst; ++j) {
#pragma unroll
for (int i = 0; i < rows_per_cuda_block; ++i) {
- tmp[j][i] += vec_dot_q_cuda(vx, &y[j*blocks_per_col_y + kby], (row0 + i)*blocks_per_row_x + kbx, kqs);
+ tmp[j][i] += vec_dot_q_cuda(
+ vx, &y[j*stride_col_y + kby], kbx_offset + i*stride_row_x + kbx, kqs);
}
}
}
- __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y][rows_per_cuda_block][warp_size];
+ __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size];
if (threadIdx.y > 0) {
#pragma unroll
- for (int j = 0; j < ncols_y; ++j) {
+ for (int j = 0; j < ncols_dst; ++j) {
#pragma unroll
for (int i = 0; i < rows_per_cuda_block; ++i) {
tmp_shared[threadIdx.y-1][j][i][threadIdx.x] = tmp[j][i];
return;
}
+ dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst + row0;
+
// sum up partial sums and write back result
#pragma unroll
- for (int j = 0; j < ncols_y; ++j) {
+ for (int j = 0; j < ncols_dst; ++j) {
#pragma unroll
for (int i = 0; i < rows_per_cuda_block; ++i) {
#pragma unroll
tmp[j][i] = warp_reduce_sum<warp_size>(tmp[j][i]);
}
- if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < (unsigned)nrows_dst)) {
- dst[j*nrows_dst + row0 + threadIdx.x] = tmp[j][threadIdx.x];
+ if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + int(threadIdx.x) < stride_col_dst)) {
+ dst[j*stride_col_dst + threadIdx.x] = tmp[j][threadIdx.x];
}
}
-
- GGML_UNUSED(nrows_x);
}
-static std::pair<dim3, dim3> calc_launch_params(const int ncols_y, const int nrows_x, const int warp_size, const mmvq_parameter_table_id table_id) {
- const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_y, table_id) - 1) / calc_rows_per_block(ncols_y, table_id);
- const dim3 block_nums(nblocks, 1, 1);
- const dim3 block_dims(warp_size, calc_nwarps(ncols_y, table_id), 1);
+static std::pair<dim3, dim3> calc_launch_params(
+ const int ncols_dst, const int nrows_x, const int nchannels_y, const int nsamples_y,
+ const int warp_size, const mmvq_parameter_table_id table_id) {
+ const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_dst, table_id) - 1) / calc_rows_per_block(ncols_dst, table_id);
+ const dim3 block_nums(nblocks, nchannels_y, nsamples_y);
+ const dim3 block_dims(warp_size, calc_nwarps(ncols_dst, table_id), 1);
return {block_nums, block_dims};
}
template <ggml_type type>
-static void mul_mat_vec_q_cuda(
- const void * vx, const void * vy, float * dst,
- const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
+static void mul_mat_vec_q_switch_ncols_dst(
+ const void * vx, const void * vy, const int32_t * ids, float * dst,
+ const int ncols_x, const int nrows_x, const int ncols_dst,
+ const int stride_row_x, const int stride_col_y, const int stride_col_dst,
+ const int nchannels_x, const int nchannels_y, const int nchannels_dst,
+ const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
+ const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
+ cudaStream_t stream) {
GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
- GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE);
+ GGML_ASSERT(ncols_dst <= MMVQ_MAX_BATCH_SIZE);
+
+ const int channel_ratio = nchannels_dst / nchannels_x;
+ const int sample_ratio = nsamples_dst / nsamples_x;
const int device = ggml_cuda_get_device();
const int warp_size = ggml_cuda_info().devices[device].warp_size;
const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc);
- switch (ncols_y) {
+ GGML_ASSERT(!ids || ncols_dst == 1);
+ switch (ncols_dst) {
case 1:
{
- constexpr int c_ncols_y = 1;
- std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
- mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+ constexpr int c_ncols_dst = 1;
+ std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
+ mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
+ (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
break;
}
case 2:
{
- constexpr int c_ncols_y = 2;
- std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
- mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+ constexpr int c_ncols_dst = 2;
+ std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
+ mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
+ (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
break;
}
case 3:
{
- constexpr int c_ncols_y = 3;
- std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
- mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+ constexpr int c_ncols_dst = 3;
+ std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
+ mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
+ (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
break;
}
case 4:
{
- constexpr int c_ncols_y = 4;
- std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
- mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+ constexpr int c_ncols_dst = 4;
+ std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
+ mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
+ (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
break;
}
case 5:
{
- constexpr int c_ncols_y = 5;
- std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
- mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+ constexpr int c_ncols_dst = 5;
+ std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
+ mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
+ (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
break;
}
case 6:
{
- constexpr int c_ncols_y = 6;
- std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
- mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+ constexpr int c_ncols_dst = 6;
+ std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
+ mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
+ (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
break;
}
case 7:
{
- constexpr int c_ncols_y = 7;
- std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
- mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+ constexpr int c_ncols_dst = 7;
+ std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
+ mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
+ (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
break;
}
case 8:
{
- constexpr int c_ncols_y = 8;
- std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
- mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
+ constexpr int c_ncols_dst = 8;
+ std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
+ mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
+ (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
break;
}
default:
}
}
-static void mul_mat_vec_q4_0_q8_1_cuda(
- const void * vx, const void * vy, float * dst,
- const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
-
- mul_mat_vec_q_cuda<GGML_TYPE_Q4_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
-}
-
-static void mul_mat_vec_q4_1_q8_1_cuda(
- const void * vx, const void * vy, float * dst,
- const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
-
- mul_mat_vec_q_cuda<GGML_TYPE_Q4_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
-}
-
-static void mul_mat_vec_q5_0_q8_1_cuda(
- const void * vx, const void * vy, float * dst,
- const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
-
- mul_mat_vec_q_cuda<GGML_TYPE_Q5_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
-}
-
-static void mul_mat_vec_q5_1_q8_1_cuda(
- const void * vx, const void * vy, float * dst,
- const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
-
- mul_mat_vec_q_cuda<GGML_TYPE_Q5_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
-}
-
-static void mul_mat_vec_q8_0_q8_1_cuda(
- const void * vx, const void * vy, float * dst,
- const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
-
- mul_mat_vec_q_cuda<GGML_TYPE_Q8_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
-}
-
-static void mul_mat_vec_q2_K_q8_1_cuda(
- const void * vx, const void * vy, float * dst,
- const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
-
- mul_mat_vec_q_cuda<GGML_TYPE_Q2_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
-}
-
-static void mul_mat_vec_q3_K_q8_1_cuda(
- const void * vx, const void * vy, float * dst,
- const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
-
- mul_mat_vec_q_cuda<GGML_TYPE_Q3_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
-}
-
-static void mul_mat_vec_q4_K_q8_1_cuda(
- const void * vx, const void * vy, float * dst,
- const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
-
- mul_mat_vec_q_cuda<GGML_TYPE_Q4_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
-}
-
-static void mul_mat_vec_q5_K_q8_1_cuda(
- const void * vx, const void * vy, float * dst,
- const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
-
- mul_mat_vec_q_cuda<GGML_TYPE_Q5_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
-}
-
-static void mul_mat_vec_q6_K_q8_1_cuda(
- const void * vx, const void * vy, float * dst,
- const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
-
- mul_mat_vec_q_cuda<GGML_TYPE_Q6_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
-}
-
-static void mul_mat_vec_iq2_xxs_q8_1_cuda(
- const void * vx, const void * vy, float * dst,
- const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
-
- mul_mat_vec_q_cuda<GGML_TYPE_IQ2_XXS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
-}
-
-static void mul_mat_vec_iq2_xs_q8_1_cuda(
- const void * vx, const void * vy, float * dst,
- const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
-
- mul_mat_vec_q_cuda<GGML_TYPE_IQ2_XS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
-}
-
-static void mul_mat_vec_iq2_s_q8_1_cuda(
- const void * vx, const void * vy, float * dst,
- const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
-
- mul_mat_vec_q_cuda<GGML_TYPE_IQ2_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
-}
-
-static void mul_mat_vec_iq3_xxs_q8_1_cuda(
- const void * vx, const void * vy, float * dst,
- const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
-
- mul_mat_vec_q_cuda<GGML_TYPE_IQ3_XXS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
-}
-
-static void mul_mat_vec_iq1_s_q8_1_cuda(
- const void * vx, const void * vy, float * dst,
- const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
-
- mul_mat_vec_q_cuda<GGML_TYPE_IQ1_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
-}
-
-static void mul_mat_vec_iq1_m_q8_1_cuda(
- const void * vx, const void * vy, float * dst,
- const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
-
- mul_mat_vec_q_cuda<GGML_TYPE_IQ1_M>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
-}
-
-static void mul_mat_vec_iq4_nl_q8_1_cuda(
- const void * vx, const void * vy, float * dst,
- const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
-
- mul_mat_vec_q_cuda<GGML_TYPE_IQ4_NL>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
-}
-
-static void mul_mat_vec_iq4_xs_q8_1_cuda(
- const void * vx, const void * vy, float * dst,
- const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
-
- mul_mat_vec_q_cuda<GGML_TYPE_IQ4_XS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
-}
-
-static void mul_mat_vec_iq3_s_q8_1_cuda(
- const void * vx, const void * vy, float * dst,
- const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
-
- mul_mat_vec_q_cuda<GGML_TYPE_IQ3_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
-}
-
-void ggml_cuda_op_mul_mat_vec_q(
- ggml_backend_cuda_context & ctx,
- const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
- const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
- const int64_t src1_padded_row_size, cudaStream_t stream) {
-
- const int64_t ne00 = src0->ne[0];
- const int64_t row_diff = row_high - row_low;
-
- const int64_t ne10 = src1->ne[0];
- GGML_ASSERT(ne10 % QK8_1 == 0);
-
- const int64_t ne0 = dst->ne[0];
-
- int id = ggml_cuda_get_device();
-
- // the main device has a larger memory buffer to hold the results from all GPUs
- // nrows_dst == nrows of the matrix that the kernel writes into
- const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff;
-
- switch (src0->type) {
+static void mul_mat_vec_q_switch_type(
+ const void * vx, const ggml_type type_x, const void * vy, const int32_t * ids, float * dst,
+ const int ncols_x, const int nrows_x, const int ncols_dst,
+ const int stride_row_x, const int stride_col_y, const int stride_col_dst,
+ const int nchannels_x, const int nchannels_y, const int nchannels_dst,
+ const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
+ const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
+ cudaStream_t stream) {
+ switch (type_x) {
case GGML_TYPE_Q4_0:
- mul_mat_vec_q4_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+ mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_0>
+ (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+ stream);
break;
case GGML_TYPE_Q4_1:
- mul_mat_vec_q4_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+ mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_1>
+ (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+ stream);
break;
case GGML_TYPE_Q5_0:
- mul_mat_vec_q5_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+ mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_0>
+ (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+ stream);
break;
case GGML_TYPE_Q5_1:
- mul_mat_vec_q5_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+ mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_1>
+ (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+ stream);
break;
case GGML_TYPE_Q8_0:
- mul_mat_vec_q8_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+ mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q8_0>
+ (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+ stream);
break;
case GGML_TYPE_Q2_K:
- mul_mat_vec_q2_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+ mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q2_K>
+ (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+ stream);
break;
case GGML_TYPE_Q3_K:
- mul_mat_vec_q3_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+ mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q3_K>
+ (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+ stream);
break;
case GGML_TYPE_Q4_K:
- mul_mat_vec_q4_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+ mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_K>
+ (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+ stream);
break;
case GGML_TYPE_Q5_K:
- mul_mat_vec_q5_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+ mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_K>
+ (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+ stream);
break;
case GGML_TYPE_Q6_K:
- mul_mat_vec_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+ mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q6_K>
+ (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+ stream);
break;
case GGML_TYPE_IQ2_XXS:
- mul_mat_vec_iq2_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+ mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_XXS>
+ (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+ stream);
break;
case GGML_TYPE_IQ2_XS:
- mul_mat_vec_iq2_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+ mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_XS>
+ (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+ stream);
break;
case GGML_TYPE_IQ2_S:
- mul_mat_vec_iq2_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+ mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_S>
+ (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+ stream);
break;
case GGML_TYPE_IQ3_XXS:
- mul_mat_vec_iq3_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+ mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ3_XXS>
+ (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+ stream);
break;
case GGML_TYPE_IQ1_S:
- mul_mat_vec_iq1_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+ mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ1_S>
+ (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+ stream);
break;
case GGML_TYPE_IQ1_M:
- mul_mat_vec_iq1_m_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+ mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ1_M>
+ (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+ stream);
break;
case GGML_TYPE_IQ4_NL:
- mul_mat_vec_iq4_nl_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+ mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ4_NL>
+ (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+ stream);
break;
case GGML_TYPE_IQ4_XS:
- mul_mat_vec_iq4_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+ mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ4_XS>
+ (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+ stream);
break;
case GGML_TYPE_IQ3_S:
- mul_mat_vec_iq3_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
+ mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ3_S>
+ (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+ stream);
break;
default:
GGML_ABORT("fatal error");
break;
}
+}
+
+void ggml_cuda_mul_mat_vec_q(
+ ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
+ GGML_ASSERT( src1->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+ GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32); // Optional, used for batched GGML_MUL_MAT_ID.
+
+ GGML_TENSOR_BINARY_OP_LOCALS;
+
+ cudaStream_t stream = ctx.stream();
+
+ const size_t ts_src0 = ggml_type_size(src0->type);
+ const size_t ts_src1 = ggml_type_size(src1->type);
+ const size_t ts_dst = ggml_type_size(dst->type);
+
+ GGML_ASSERT( nb00 == ts_src0);
+ GGML_ASSERT( nb10 == ts_src1);
+ GGML_ASSERT( nb0 == ts_dst);
+ GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));
+
+ GGML_ASSERT(!ids || ne12 == 1); // Implementation is only correct for batch size 1.
+
+ const float * src1_d = (const float *) src1->data;
+ const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr;
+ float * dst_d = (float *) dst->data;
+
+ const int64_t ne10_padded = GGML_PAD(ne10, MATRIX_ROW_PADDING);
+ ggml_cuda_pool_alloc<char> src1_q8_1(ctx.pool(), ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1);
+ {
+ const int64_t s11 = src1->nb[1] / ts_src1;
+ const int64_t s12 = src1->nb[2] / ts_src1;
+ const int64_t s13 = src1->nb[3] / ts_src1;
+ quantize_row_q8_1_cuda(src1_d, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded, ne11, ne12, ne13, stream);
+ }
+
+ const int64_t s01 = src0->nb[1] / ts_src0;
+ const int64_t s11 = ne10_padded / QK8_1;
+ const int64_t s1 = dst->nb[1] / ts_dst;
+ const int64_t s02 = src0->nb[2] / ts_src0;
+ const int64_t s2 = dst->nb[2] / ts_dst;
+ const int64_t s03 = src0->nb[3] / ts_src0;
+ const int64_t s3 = dst->nb[3] / ts_dst;
+
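+ // src1 was quantized into a contiguous buffer above, so its channel/sample strides are derived
+ // from the padded row size rather than from src1->nb: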
+ const int64_t s12 = ne11*s11;
+ const int64_t s13 = ne12*s12;
+
+ // For MUL_MAT_ID the memory layout is different from that of MUL_MAT:
+ const int64_t ncols_dst = ids ? ne2 : ne1;
+ const int64_t nchannels_y = ids ? ne11 : ne12;
+ const int64_t nchannels_dst = ids ? ne1 : ne2;
+ const int64_t stride_col_dst = ids ? s2 : s1;
+ const int64_t stride_col_y = ids ? s12 : s11;
+ const int64_t stride_channel_dst = ids ? s1 : s2;
+ const int64_t stride_channel_y = ids ? s11 : s12;
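+ // i.e. when ids is set, columns of dst and of the quantized src1 are indexed along dim 2 and
+ // channels along dim 1, mirroring the layout handling in ggml_cuda_mul_mat_vec.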
+
+ mul_mat_vec_q_switch_type(
+ src0->data, src0->type, src1_q8_1.get(), ids_d, dst_d, ne00,
+ ne01, ncols_dst, s01, stride_col_y, stride_col_dst,
+ ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
+ ne03, ne3, s03, s13, s3, stream);
+}
+
+void ggml_cuda_op_mul_mat_vec_q(
+ ggml_backend_cuda_context & ctx,
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
+ const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
+ const int64_t src1_padded_row_size, cudaStream_t stream) {
+
+ const int64_t ne00 = src0->ne[0];
+ const int64_t row_diff = row_high - row_low;
+
+ const int64_t ne10 = src1->ne[0];
+ GGML_ASSERT(ne10 % QK8_1 == 0);
+
+ const int64_t ne0 = dst->ne[0];
+
+ int id = ggml_cuda_get_device();
+
+ // the main device has a larger memory buffer to hold the results from all GPUs
+ // nrows_dst == nrows of the matrix that the kernel writes into
+ const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff;
+
+ const int stride_row_x = ne00 / ggml_blck_size(src0->type);
+ const int stride_col_y = src1_padded_row_size / QK8_1;
+
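+ // This split-buffer path always works on a single channel and sample, so all channel/sample
+ // counts and strides below are passed as 1.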
+ mul_mat_vec_q_switch_type(
+ src0_dd_i, src0->type, src1_ddq_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row_x, stride_col_y, nrows_dst,
+ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, stream);
GGML_UNUSED(src1);
GGML_UNUSED(dst);