#include "common.cuh"
#include "mmv.cuh"
// Matrix-vector multiplication kernel.
// Each thread accumulates a partial dot product of x and y in sumf, walking column pairs
// from tid in steps of block_size. The partial sum is passed to warp_reduce_sum and the
// result is written to dst[row].
//   T          - element type of x; half and nv_bfloat16 are handled, any other type is
//                rejected by the static_assert in the final else branch.
//   type_acc   - for T == half: float selects float accumulation, otherwise the half2 path
//                is used. It is not consulted in the nv_bfloat16 branch, which always
//                accumulates in float.
//   block_size - step of the per-thread column loop.
//   ncols2     - loop bound counted in column pairs (x and y are both read two values at a time).
// row is taken from blockIdx.x and channel from blockIdx.z.
// NOTE(review): tid, any offsetting of x, and any use of data_mmv are not visible in this
// excerpt - confirm against the full file.
-template <typename type_acc, int block_size>
+template <typename T, typename type_acc, int block_size>
static __global__ void mul_mat_vec(
- const half * __restrict__ x, const float * __restrict__ y, float * __restrict__ dst, const int64_t ncols2, const int64_t stride_row,
+ const T * __restrict__ x, const float * __restrict__ y, float * __restrict__ dst, const int64_t ncols2, const int64_t stride_row,
const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst) {
const int64_t row = blockIdx.x;
const int64_t channel = blockIdx.z;
// Advance y and dst to the start of the current channel.
y += channel *stride_channel_y;
dst += channel *stride_channel_dst;
- const half2 * x2 = (const half2 *) x;
// y is read as float2, i.e. two consecutive floats per loop iteration.
const float2 * y2 = (const float2 *) y;
extern __shared__ char data_mmv[];
float sumf;
- if (std::is_same<type_acc, float>::value) {
- sumf = 0.0f;
// The element type of x is resolved at compile time; only the matching branch is instantiated.
+ if constexpr (std::is_same<T, half>::value) {
+ const half2 * x2 = (const half2 *) x;
- for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
- const float2 tmpx = __half22float2(x2[col2]);
- const float2 tmpy = y2[col2];
- sumf += tmpx.x * tmpy.x;
- sumf += tmpx.y * tmpy.y;
- }
- } else {
// float accumulation: each half2 of x is widened to float2 before multiplying.
+ if (std::is_same<type_acc, float>::value) {
+ sumf = 0.0f;
+
+ for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
+ const float2 tmpx = __half22float2(x2[col2]);
+ const float2 tmpy = y2[col2];
+ sumf += tmpx.x * tmpy.x;
+ sumf += tmpx.y * tmpy.y;
+ }
+ } else {
// half2 accumulation: y is narrowed to half2 and the products are summed in half2;
// the two lanes are combined into a float only after the loop.
// Without FP16_AVAILABLE this path is replaced by NO_DEVICE_CODE.
#ifdef FP16_AVAILABLE
- half2 sumh2 = make_half2(0.0f, 0.0f);
+ half2 sumh2 = make_half2(0.0f, 0.0f);
- for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
- const float2 tmp = y2[col2];
- sumh2 += x2[col2] * make_half2(tmp.x, tmp.y);
- }
+ for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
+ const float2 tmp = y2[col2];
+ sumh2 += x2[col2] * make_half2(tmp.x, tmp.y);
+ }
- sumf = __low2float(sumh2) + __high2float(sumh2);
+ sumf = __low2float(sumh2) + __high2float(sumh2);
#else
- NO_DEVICE_CODE;
+ NO_DEVICE_CODE;
#endif // FP16_AVAILABLE
+ }
// nv_bfloat16 input: x is loaded one int at a time and the two nv_bfloat16 values inside
// each int are extracted via reinterpret_cast; accumulation is always done in float.
+ } else if constexpr (std::is_same<T, nv_bfloat16>::value) {
+ const int * x2 = (const int *) x;
+ sumf = 0.0f;
+
+ for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) {
+ const int tmpx = x2[col2];
+ const float2 tmpy = y2[col2];
+ sumf += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]) * tmpy.x;
+ sumf += float(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]) * tmpy.y;
+ }
+ } else {
// The condition depends on T so the assertion only fires when this branch is instantiated,
// i.e. for an element type that has no dedicated branch above.
+ static_assert(std::is_same<T, void>::value, "unsupported type");
}
sumf = warp_reduce_sum(sumf);
dst[row] = sumf;
}
// Host-side launcher for mul_mat_vec.
// Maps the runtime value block_size_best to the kernel instantiation with the matching
// compile-time block_size (32 to 256 in steps of 32) and launches it on stream.
// The kernel is given ncols/2 as its column count; the remaining stride arguments are
// forwarded unchanged.
//   T        - element type of x, forwarded to the kernel.
//   type_acc - accumulator type, forwarded to the kernel.
// NOTE(review): block_size_best, block_nums, smem and channel_ratio are computed in lines
// not visible in this excerpt - confirm against the full file.
-template <typename type_acc>
+template <typename T, typename type_acc>
static void launch_mul_mat_vec_cuda(
- const half * x, const float * y, float * dst,
+ const T * x, const float * y, float * dst,
const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
cudaStream_t stream) {
const dim3 block_dims(block_size_best, 1, 1);
// block_size is a template argument of the kernel, so every supported runtime value
// needs its own explicit instantiation below.
switch (block_size_best) {
case 32: {
- mul_mat_vec<type_acc, 32><<<block_nums, block_dims, smem, stream>>>
+ mul_mat_vec<T, type_acc, 32><<<block_nums, block_dims, smem, stream>>>
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
} break;
case 64: {
- mul_mat_vec<type_acc, 64><<<block_nums, block_dims, smem, stream>>>
+ mul_mat_vec<T, type_acc, 64><<<block_nums, block_dims, smem, stream>>>
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
} break;
case 96: {
- mul_mat_vec<type_acc, 96><<<block_nums, block_dims, smem, stream>>>
+ mul_mat_vec<T, type_acc, 96><<<block_nums, block_dims, smem, stream>>>
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
} break;
case 128: {
- mul_mat_vec<type_acc, 128><<<block_nums, block_dims, smem, stream>>>
+ mul_mat_vec<T, type_acc, 128><<<block_nums, block_dims, smem, stream>>>
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
} break;
case 160: {
- mul_mat_vec<type_acc, 160><<<block_nums, block_dims, smem, stream>>>
+ mul_mat_vec<T, type_acc, 160><<<block_nums, block_dims, smem, stream>>>
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
} break;
case 192: {
- mul_mat_vec<type_acc, 192><<<block_nums, block_dims, smem, stream>>>
+ mul_mat_vec<T, type_acc, 192><<<block_nums, block_dims, smem, stream>>>
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
} break;
case 224: {
- mul_mat_vec<type_acc, 224><<<block_nums, block_dims, smem, stream>>>
+ mul_mat_vec<T, type_acc, 224><<<block_nums, block_dims, smem, stream>>>
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
} break;
case 256: {
- mul_mat_vec<type_acc, 256><<<block_nums, block_dims, smem, stream>>>
+ mul_mat_vec<T, type_acc, 256><<<block_nums, block_dims, smem, stream>>>
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
} break;
// NOTE(review): the body of the default case is not visible in this excerpt.
default: {
}
}
// Chooses the accumulator type from prec and forwards everything else to
// launch_mul_mat_vec_cuda: GGML_PREC_DEFAULT selects half, GGML_PREC_F32 selects float.
//   T - element type of x, passed through to the launcher unchanged.
+template<typename T>
static void mul_mat_vec_cuda(
- const half * x, const float * y, float * dst,
+ const T * x, const float * y, float * dst,
const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
enum ggml_prec prec, cudaStream_t stream) {
switch (prec) {
case GGML_PREC_DEFAULT: {
- launch_mul_mat_vec_cuda<half>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
+ launch_mul_mat_vec_cuda<T, half>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
stride_channel_x, stride_channel_y, stride_channel_dst, stream);
} break;
case GGML_PREC_F32: {
- launch_mul_mat_vec_cuda<float>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
+ launch_mul_mat_vec_cuda<T, float>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
stride_channel_x, stride_channel_y, stride_channel_dst, stream);
} break;
}
}
// Tensor-level entry point for the matrix-vector product.
// Requires src1 and dst to be F32, then dispatches on src0->type: GGML_TYPE_F16 runs the
// half instantiation, GGML_TYPE_BF16 the nv_bfloat16 instantiation, and any other type
// aborts with an "unsupported type" message. The work is issued on ctx.stream().
// NOTE(review): ne00, ne01, ne02, ne12, stride_row and channel_stride_x are defined in
// lines not visible in this excerpt - confirm against the full file.
void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
- GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
// The precision stored in dst->op_params[0] is only used when fast_fp16_available(cc)
// holds; otherwise GGML_PREC_F32 is forced.
const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
- const half * src0_d = (const half *) src0->data;
const float * src1_d = (const float *) src1->data;
float * dst_d = (float *) dst->data;
// nb[2] divided by ggml_type_size(type) - presumably a byte stride converted to a stride
// in elements; verify against the ggml tensor layout.
const int64_t channel_stride_y = src1->nb[2] / ggml_type_size(src1->type);
const int64_t channel_stride_dst = dst->nb[2] / ggml_type_size( dst->type);
- mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12, channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream());
// The typed src0 pointer selects the template argument T of mul_mat_vec_cuda.
+ switch (src0->type) {
+ case GGML_TYPE_F16: {
+ const half * src0_d = (const half *) src0->data;
+ mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12,
+ channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream());
+ } break;
+ case GGML_TYPE_BF16: {
+ const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
+ mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12,
+ channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream());
+ } break;
+ default:
+ GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
+ }
}
void ggml_cuda_op_mul_mat_vec(
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
const int64_t src1_padded_row_size, cudaStream_t stream) {
- GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
const int64_t channel_stride_y = 0;
const int64_t channel_stride_dst = 0;
- mul_mat_vec_cuda((const half *) src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
- nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream);
+ switch (src0->type) {
+ case GGML_TYPE_F16: {
+ const half * src0_d = (const half *) src0_dd_i;
+ mul_mat_vec_cuda(src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
+ nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream);
+ } break;
+ case GGML_TYPE_BF16: {
+ const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
+ mul_mat_vec_cuda(src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
+ nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream);
+ } break;
+ default:
+ GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
+ }
GGML_UNUSED(ctx);
GGML_UNUSED(src1);