static __global__ void mul_mat_vec_f(
const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
- const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
- const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
+ const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
+ const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
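+    // channel_ratio and sample_ratio arrive as uint3 fastdiv constants,
+    // precomputed host-side by init_fastdiv_values, so the divisions below
+    // reduce to a multiply-high and a shift instead of an integer division.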
const int row = blockIdx.x;
const int channel_dst = blockIdx.y;
- const int channel_x = ids ? ids[channel_dst] : channel_dst / channel_ratio;
+ const int channel_x = ids ? ids[channel_dst] : fastdiv((uint32_t) channel_dst, channel_ratio);
const int channel_y = ids ? channel_dst % nchannels_y : channel_dst;
const int sample_dst = blockIdx.z;
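+    // broadcast mapping: each dst sample reads from src sample sample_dst / sample_ratio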
- const int sample_x = sample_dst / sample_ratio;
+ const int sample_x = fastdiv((uint32_t) sample_dst, sample_ratio);
const int sample_y = sample_dst;
const int tid = threadIdx.x;
#pragma unroll
for (int j = 0; j < ncols_dst; ++j) {
const float2 tmpy = y2[j*stride_col_y2 + col2];
- sumf[j] += tmpx.x*tmpy.x;
- sumf[j] += tmpx.y*tmpy.y;
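+            // accumulate through ggml_cuda_mad so the compiler can emit a fused multiply-add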
+ ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
+ ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
}
}
} else if constexpr (std::is_same_v<T, half>) {
#pragma unroll
for (int j = 0; j < ncols_dst; ++j) {
const float2 tmpy = y2[j*stride_col_y2 + col2];
- sumf[j] += tmpx.x * tmpy.x;
- sumf[j] += tmpx.y * tmpy.y;
+ ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
+ ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
}
}
} else {
#endif // FP16_AVAILABLE
}
} else if constexpr (std::is_same_v<T, nv_bfloat16>) {
+// TODO: add ggml_cuda_mad support for hip_bfloat162
+#if defined(GGML_USE_HIP)
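+        // HIP path: no ggml_cuda_mad overload for the bf16 pair yet (see TODO),
+        // so widen each half to float and accumulate in float.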
const int * x2 = (const int *) x;
for (int col2 = tid; col2 < ncols2; col2 += block_size) {
const int tmpx = x2[col2];
#pragma unroll
for (int j = 0; j < ncols_dst; ++j) {
const float2 tmpy = y2[j*stride_col_y2 + col2];
- sumf[j] += ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]) * tmpy.x;
- sumf[j] += ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]) * tmpy.y;
+ const float tmpx0 = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[0]);
+ const float tmpx1 = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]);
+ ggml_cuda_mad(sumf[j], tmpx0, tmpy.x);
+ ggml_cuda_mad(sumf[j], tmpx1, tmpy.y);
}
}
+#else
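+        // CUDA path: load the bf16 pair directly and accumulate through ggml_cuda_mad.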
+ const nv_bfloat162 * x2 = (const nv_bfloat162 *) x;
+ for (int col2 = tid; col2 < ncols2; col2 += block_size) {
+ const nv_bfloat162 tmpx = x2[col2];
+#pragma unroll
+ for (int j = 0; j < ncols_dst; ++j) {
+ const float2 tmpy = y2[j*stride_col_y2 + col2];
+ ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
+ ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
+ }
+ }
+#endif
} else {
static_assert(std::is_same_v<T, void>, "unsupported type");
}
GGML_ASSERT(stride_col_y % 2 == 0);
GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
GGML_ASSERT( nsamples_dst % nsamples_x == 0);
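+    // precompute the fastdiv constants once on the host; when ids is set the
+    // kernel never divides by the channel ratio, so a zeroed placeholder suffices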
- const int64_t channel_ratio = nchannels_dst / nchannels_x;
- const int64_t sample_ratio = nsamples_dst / nsamples_x;
+ const uint3 channel_ratio_fd = ids ? make_uint3(0, 0, 0) : init_fastdiv_values(nchannels_dst / nchannels_x);
+ const uint3 sample_ratio_fd = init_fastdiv_values(nsamples_dst / nsamples_x);
const int device = ggml_cuda_get_device();
const int warp_size = ggml_cuda_info().devices[device].warp_size;
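+    // each case below instantiates mul_mat_vec_f with its block size as a
+    // compile-time constant, forwarding the fastdiv values unchanged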
case 32: {
mul_mat_vec_f<T, type_acc, ncols_dst, 32><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+ channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 64: {
mul_mat_vec_f<T, type_acc, ncols_dst, 64><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+ channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 96: {
mul_mat_vec_f<T, type_acc, ncols_dst, 96><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+ channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 128: {
mul_mat_vec_f<T, type_acc, ncols_dst, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+ channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 160: {
mul_mat_vec_f<T, type_acc, ncols_dst, 160><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+ channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 192: {
mul_mat_vec_f<T, type_acc, ncols_dst, 192><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+ channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 224: {
mul_mat_vec_f<T, type_acc, ncols_dst, 224><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+ channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 256: {
mul_mat_vec_f<T, type_acc, ncols_dst, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+ channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
default: {
GGML_ABORT("fatal error");