#include "common.cuh"
#include "mmv.cuh"
-template <typename T, typename type_acc, int block_size>
+template <typename T, typename type_acc, int ncols_dst, int block_size>
static __global__ void mul_mat_vec(
const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
- const int64_t ncols2, const int64_t nchannels_y, const int64_t stride_row,
- const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
- const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst) {
- const int64_t row = blockIdx.x;
- const int64_t channel_dst = blockIdx.y;
- const int64_t channel_x = ids ? ids[channel_dst] : channel_dst / channel_ratio;
- const int64_t channel_y = ids ? channel_dst % nchannels_y : channel_dst;
- const int64_t sample_dst = blockIdx.z;
- const int64_t sample_x = sample_dst / sample_ratio;
- const int64_t sample_y = sample_dst;
- const int tid = threadIdx.x;
+ const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
+ const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
+ const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
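+    // Each thread block computes one row of dst for all ncols_dst destination columns,
+    // so a single load of x is reused for up to ncols_dst dot products.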
+ const int row = blockIdx.x;
+ const int channel_dst = blockIdx.y;
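+    // If ids is given (the MUL_MAT_ID path), it maps each destination channel to a source channel;
+    // otherwise the x channels are broadcast according to channel_ratio.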
+ const int channel_x = ids ? ids[channel_dst] : channel_dst / channel_ratio;
+ const int channel_y = ids ? channel_dst % nchannels_y : channel_dst;
+ const int sample_dst = blockIdx.z;
+ const int sample_x = sample_dst / sample_ratio;
+ const int sample_y = sample_dst;
+ const int tid = threadIdx.x;
+
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
- x += sample_x *stride_sample_x + channel_x *stride_channel_x + row*stride_row;
- y += sample_y *stride_sample_y + channel_y *stride_channel_y;
- dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst;
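+    // Strides and loop bounds are 32-bit, but the base-pointer offsets below are widened
+    // to 64-bit so that large tensors do not overflow the address arithmetic.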
+ x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row*stride_row;
+ y += int64_t(sample_y) *stride_sample_y + channel_y *stride_channel_y;
+ dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst;
const float2 * y2 = (const float2 *) y;
__syncthreads();
}
- float sumf = 0.0f;
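+    // Per-thread partial sums, one accumulator per destination column.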
+ float sumf[ncols_dst] = {0.0f};
if constexpr (std::is_same<T, float>::value) {
const float2 * x2 = (const float2 *) x;
- for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
+ for (int col2 = tid; col2 < ncols2; col2 += block_size) {
const float2 tmpx = x2[col2];
- const float2 tmpy = y2[col2];
- sumf += tmpx.x*tmpy.x;
- sumf += tmpx.y*tmpy.y;
+
+#pragma unroll
+ for (int j = 0; j < ncols_dst; ++j) {
+ const float2 tmpy = y2[j*stride_col_y2 + col2];
+ sumf[j] += tmpx.x*tmpy.x;
+ sumf[j] += tmpx.y*tmpy.y;
+ }
}
} else if constexpr (std::is_same<T, half>::value) {
const half2 * x2 = (const half2 *) x;
if (std::is_same<type_acc, float>::value) {
- for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
+ for (int col2 = tid; col2 < ncols2; col2 += block_size) {
const float2 tmpx = __half22float2(x2[col2]);
- const float2 tmpy = y2[col2];
- sumf += tmpx.x * tmpy.x;
- sumf += tmpx.y * tmpy.y;
+
+#pragma unroll
+ for (int j = 0; j < ncols_dst; ++j) {
+ const float2 tmpy = y2[j*stride_col_y2 + col2];
+ sumf[j] += tmpx.x * tmpy.x;
+ sumf[j] += tmpx.y * tmpy.y;
+ }
}
} else {
#ifdef FP16_AVAILABLE
-            half2 sumh2 = make_half2(0.0f, 0.0f);
-            for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
-                const float2 tmp = y2[col2];
-                sumh2 += x2[col2] * make_half2(tmp.x, tmp.y);
+            half2 sumh2[ncols_dst] = {{0.0f, 0.0f}};
+
+            for (int col2 = tid; col2 < ncols2; col2 += block_size) {
+                const half2 tmpx = x2[col2];
+#pragma unroll
+                for (int j = 0; j < ncols_dst; ++j) {
+                    const float2 tmpy = y2[j*stride_col_y2 + col2];
+                    sumh2[j] += tmpx * make_half2(tmpy.x, tmpy.y);
+                }
}
- sumf = __low2float(sumh2) + __high2float(sumh2);
+#pragma unroll
+ for (int j = 0; j < ncols_dst; ++j) {
+ sumf[j] = __low2float(sumh2[j]) + __high2float(sumh2[j]);
+ }
#else
NO_DEVICE_CODE;
#endif // FP16_AVAILABLE
}
} else if constexpr (std::is_same<T, nv_bfloat16>::value) {
const int * x2 = (const int *) x;
- for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
- const int tmpx = x2[col2];
- const float2 tmpy = y2[col2];
- sumf += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]) * tmpy.x;
- sumf += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]) * tmpy.y;
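+        // Two bf16 values of x are loaded per 32-bit int and reinterpreted below.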
+ for (int col2 = tid; col2 < ncols2; col2 += block_size) {
+ const int tmpx = x2[col2];
+#pragma unroll
+ for (int j = 0; j < ncols_dst; ++j) {
+ const float2 tmpy = y2[j*stride_col_y2 + col2];
+ sumf[j] += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]) * tmpy.x;
+ sumf[j] += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]) * tmpy.y;
+ }
}
} else {
static_assert(std::is_same<T, void>::value, "unsupported type");
}
-    sumf = warp_reduce_sum<warp_size>(sumf);
-    if (block_size > warp_size) {
-        buf_iw[tid/warp_size] = sumf;
-        __syncthreads();
-        if (tid >= warp_size) {
-            return;
+#pragma unroll
+    for (int j = 0; j < ncols_dst; ++j) {
+        sumf[j] = warp_reduce_sum<warp_size>(sumf[j]);
+        if (block_size > warp_size) {
+            buf_iw[tid/warp_size] = sumf[j];
+            __syncthreads();
+            if (tid < warp_size) {
+                sumf[j] = buf_iw[tid];
+                sumf[j] = warp_reduce_sum<warp_size>(sumf[j]);
+            }
+            if (j < ncols_dst) {
+                __syncthreads();
+            }
        }
-        sumf = buf_iw[tid];
-        sumf = warp_reduce_sum<warp_size>(sumf);
}
- if (tid != 0) {
+ if (tid >= ncols_dst) {
return;
}
- dst[row] = sumf;
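+    // The first ncols_dst threads each write the final result for one destination column.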
+ dst[tid*stride_col_dst + row] = sumf[tid];
}
-template <typename T, typename type_acc>
+template <typename T, typename type_acc, int ncols_dst>
static void launch_mul_mat_vec_cuda(
const T * x, const float * y, const int32_t * ids, float * dst,
- const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
+ const int64_t ncols, const int64_t nrows,
+ const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
+ const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
cudaStream_t stream) {
- GGML_ASSERT(ncols % 2 == 0);
- GGML_ASSERT(stride_row % 2 == 0);
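+    // x is read as 2-element vectors (float2/half2/2x bf16), so the column count and the
+    // row/column strides must be even.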
+    GGML_ASSERT(ncols        % 2 == 0);
+    GGML_ASSERT(stride_row   % 2 == 0);
+    GGML_ASSERT(stride_col_y % 2 == 0);
GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
GGML_ASSERT( nsamples_dst % nsamples_x == 0);
const int64_t channel_ratio = nchannels_dst / nchannels_x;
const dim3 block_dims(block_size_best, 1, 1);
switch (block_size_best) {
case 32: {
- mul_mat_vec<T, type_acc, 32><<<block_nums, block_dims, smem, stream>>>
- (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
- stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+ mul_mat_vec<T, type_acc, ncols_dst, 32><<<block_nums, block_dims, smem, stream>>>
+ (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 64: {
- mul_mat_vec<T, type_acc, 64><<<block_nums, block_dims, smem, stream>>>
- (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
- stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+ mul_mat_vec<T, type_acc, ncols_dst, 64><<<block_nums, block_dims, smem, stream>>>
+ (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 96: {
- mul_mat_vec<T, type_acc, 96><<<block_nums, block_dims, smem, stream>>>
- (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
- stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+ mul_mat_vec<T, type_acc, ncols_dst, 96><<<block_nums, block_dims, smem, stream>>>
+ (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 128: {
- mul_mat_vec<T, type_acc, 128><<<block_nums, block_dims, smem, stream>>>
- (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
- stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+ mul_mat_vec<T, type_acc, ncols_dst, 128><<<block_nums, block_dims, smem, stream>>>
+ (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 160: {
- mul_mat_vec<T, type_acc, 160><<<block_nums, block_dims, smem, stream>>>
- (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
- stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+ mul_mat_vec<T, type_acc, ncols_dst, 160><<<block_nums, block_dims, smem, stream>>>
+ (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 192: {
- mul_mat_vec<T, type_acc, 192><<<block_nums, block_dims, smem, stream>>>
- (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
- stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+ mul_mat_vec<T, type_acc, ncols_dst, 192><<<block_nums, block_dims, smem, stream>>>
+ (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 224: {
- mul_mat_vec<T, type_acc, 224><<<block_nums, block_dims, smem, stream>>>
- (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
- stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+ mul_mat_vec<T, type_acc, ncols_dst, 224><<<block_nums, block_dims, smem, stream>>>
+ (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 256: {
- mul_mat_vec<T, type_acc, 256><<<block_nums, block_dims, smem, stream>>>
- (x, y, ids, dst, ncols/2, nchannels_y, stride_row, channel_ratio, stride_channel_x, stride_channel_y,
- stride_channel_dst, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+ mul_mat_vec<T, type_acc, ncols_dst, 256><<<block_nums, block_dims, smem, stream>>>
+ (x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
default: {
GGML_ABORT("fatal error");
}
}
+template <typename T, typename type_acc>
+static void mul_mat_vec_cuda_switch_ncols_dst(
+ const T * x, const float * y, const int32_t * ids, float * dst,
+ const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
+ const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
+ const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
+ const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
+ const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
+ cudaStream_t stream) {
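+    // ncols_dst is only known at runtime; dispatch to an instantiation that receives it as a
+    // compile-time constant so the per-column loops in the kernel can be fully unrolled.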
+ switch (ncols_dst) {
+ case 1:
+ launch_mul_mat_vec_cuda<T, type_acc, 1>
+ (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ break;
+ case 2:
+ launch_mul_mat_vec_cuda<T, type_acc, 2>
+ (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ break;
+ case 3:
+ launch_mul_mat_vec_cuda<T, type_acc, 3>
+ (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ break;
+ case 4:
+ launch_mul_mat_vec_cuda<T, type_acc, 4>
+ (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ break;
+ case 5:
+ launch_mul_mat_vec_cuda<T, type_acc, 5>
+ (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ break;
+ case 6:
+ launch_mul_mat_vec_cuda<T, type_acc, 6>
+ (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ break;
+ case 7:
+ launch_mul_mat_vec_cuda<T, type_acc, 7>
+ (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ break;
+ case 8:
+ launch_mul_mat_vec_cuda<T, type_acc, 8>
+ (x, y, ids, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+ stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ break;
+ default:
+ GGML_ABORT("fatal error");
+ break;
+ }
+}
+
template<typename T>
static void mul_mat_vec_cuda(
const T * x, const float * y, const int32_t * ids, float * dst,
- const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
+ const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
+        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
+ const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
enum ggml_prec prec, cudaStream_t stream) {
if constexpr(std::is_same<T, half>::value) {
if (prec == GGML_PREC_DEFAULT) {
- launch_mul_mat_vec_cuda<T, half>
- (x, y, ids, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+ mul_mat_vec_cuda_switch_ncols_dst<T, half>
+ (x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
return;
}
}
- launch_mul_mat_vec_cuda<T, float>
- (x, y, ids, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
+ mul_mat_vec_cuda_switch_ncols_dst<T, float>
+ (x, y, ids, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
}
const int64_t stride_channel_dst = ids ? s1 : s2;
const int64_t stride_channel_y = ids ? s11 : s12;
- GGML_ASSERT(ncols_dst == 1);
+ GGML_ASSERT(!ids || ncols_dst == 1);
switch (src0->type) {
case GGML_TYPE_F32: {
const float * src0_d = (const float *) src0->data;
- mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, s01,
+ mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
ne03, ne3, s03, s13, s3, prec, ctx.stream());
} break;
case GGML_TYPE_F16: {
const half * src0_d = (const half *) src0->data;
- mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, s01,
+ mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
ne03, ne3, s03, s13, s3, prec, ctx.stream());
} break;
case GGML_TYPE_BF16: {
const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
- mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, s01,
+ mul_mat_vec_cuda(src0_d, src1_d, ids_d, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
ne03, ne3, s03, s13, s3, prec, ctx.stream());
} break;
GGML_ASSERT(dst->type == GGML_TYPE_F32);
const int64_t ne00 = src0->ne[0];
+ const int64_t ne10 = src1->ne[0];
+ const int64_t ne0 = dst->ne[0];
const int64_t row_diff = row_high - row_low;
- GGML_ASSERT(src1_ncols == 1);
-
- const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+ const int id = ggml_cuda_get_device();
+ const int cc = ggml_cuda_info().devices[id].cc;
const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
// ggml_cuda_op provides single, contiguous matrices
const int64_t stride_row = ne00;
+ const int64_t stride_col_y = ne10;
+ const int64_t stride_col_dst = id == ctx.device ? ne0 : row_diff; // main device has larger memory buffer
const int64_t nchannels_x = 1;
const int64_t nchannels_y = 1;
const int64_t nchannels_dst = 1;
switch (src0->type) {
case GGML_TYPE_F32: {
const float * src0_d = (const float *) src0_dd_i;
- mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, stride_row,
+ mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
} break;
case GGML_TYPE_F16: {
const half * src0_d = (const half *) src0_dd_i;
- mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, stride_row,
+ mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
} break;
case GGML_TYPE_BF16: {
const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
- mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, stride_row,
+ mul_mat_vec_cuda(src0_d, src1_ddf_i, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
} break;
-    GGML_UNUSED(src1_ncols);
GGML_UNUSED(src1_padded_row_size);
}
+
+bool ggml_cuda_should_use_mmv(enum ggml_type type, int cc, const int64_t * src0_ne, int64_t ne11) {
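+    // Heuristic for when this kernel is expected to beat the alternative matrix multiplication
+    // backends: only small batch sizes (ne11), with cutoffs that depend on the data type and
+    // the compute capability.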
+ if (src0_ne[0] % 2 != 0) {
+ return false;
+ }
+ switch (type) {
+ case GGML_TYPE_F32:
+ if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
+ if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
+ return ne11 <= 8;
+ }
+ if (cc >= GGML_CUDA_CC_TURING) {
+ return ne11 <= 4;
+ }
+ return ne11 <= 3;
+ }
+ return ne11 <= 8;
+ case GGML_TYPE_F16:
+ if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
+ const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1);
+ if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
+ return src0_small && ne11 <= 4;
+ }
+ if (fp16_mma_hardware_available(cc)) {
+ return src0_small && ne11 <= 3;
+ }
+ return ne11 <= 8;
+ }
+ return ne11 <= 8;
+ case GGML_TYPE_BF16:
+ if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
+ const bool src0_small = (src0_ne[1] <= 512 || src0_ne[2]*src0_ne[3] == 1);
+ if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
+ return src0_small && ne11 <= 4;
+ }
+ if (bf16_mma_hardware_available(cc)) {
+ return src0_small && ne11 <= 3;
+ }
+ return ne11 <= 8;
+ }
+ return ne11 <= 8;
+ default:
+ return false;
+ }
+}