using namespace ggml_cuda_mma;
#define MMF_ROWS_PER_BLOCK 32
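+// CDNA wavefronts are 64 threads wide, and rows_per_block must be a multiple of
+// the warp size (see the static_assert below), so CDNA kernels use 64 rows per block.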
+#define MMF_ROWS_PER_BLOCK_CDNA 64
+
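+// Upper bound on threads per block when picking nwarps below; the CDNA path opts
+// for larger 512-thread blocks (8 wavefronts of 64 lanes), presumably a tuning choice.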
+static __forceinline__ int64_t mmf_get_max_block_size(int cc) {
+ if (GGML_CUDA_CC_IS_CDNA(cc)) {
+ return 512;
+ } else {
+ return 256;
+ }
+}
+
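+// Per-row padding (in elements) for the shared memory tiles, presumably chosen to
+// avoid bank conflicts; CDNA gets by with less padding than the other targets.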
+static __forceinline__ int mmf_get_padding(int cc) {
+ if (GGML_CUDA_CC_IS_CDNA(cc)) {
+ return 2;
+ } else {
+ return 4;
+ }
+}
+
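+// Device-side constexpr twin of mmf_get_padding(cc) above; the two must stay in sync.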
+static constexpr __device__ int mmf_get_padding() {
+#if defined(AMD_MFMA_AVAILABLE)
+ return 2;
+#else
+ return 4;
+#endif // defined(AMD_MFMA_AVAILABLE)
+}
struct mmf_ids_data {
const int32_t * ids_src_compact = nullptr;
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
// TODO: handle this in a consistent and simpler way after AMD MFMA support has been added
-#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
+#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
#if defined(AMD_WMMA_AVAILABLE)
- // Special case for tf32, just dummy mma layout as wmma doesn't support it.
- constexpr bool is_tf32 = std::is_same_v<T, float>;
- constexpr int tile_B_I = is_tf32 ? 8 : 16;
- constexpr int tile_C_J = is_tf32 ? 8 : 16;
- constexpr data_layout ab_layout = is_tf32 ? DATA_LAYOUT_I_MAJOR : get_input_data_layout();
- typedef tile<16, 8, T, ab_layout> tile_A;
- typedef tile<tile_B_I, 8, T, ab_layout> tile_B;
- typedef tile<16, tile_C_J, float, DATA_LAYOUT_J_MAJOR> tile_C;
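+ // RDNA WMMA has no tf32 support, so the float instantiation (along with any
+ // non-default row-block size) compiles to NO_DEVICE_CODE via the guard below.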
+ if constexpr (!(std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
+ typedef tile<16, 8, T, get_input_data_layout()> tile_A;
+ typedef tile<16, 8, T, get_input_data_layout()> tile_B;
+ typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
+#elif defined(AMD_MFMA_AVAILABLE)
+ if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK_CDNA) {NO_DEVICE_CODE;} else {
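+ // The MFMA path is only instantiated for the CDNA row-block size; all other
+ // instantiations compile to NO_DEVICE_CODE via the guard above.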
+ typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_A;
+ typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_B;
+ typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
#else
#ifdef VOLTA_MMA_AVAILABLE
- if constexpr (!std::is_same_v<T, half2>) {NO_DEVICE_CODE;} else {
+ if constexpr (!std::is_same_v<T, half2> || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
typedef tile<32, 4, T, DATA_LAYOUT_I_MAJOR> tile_A;
typedef tile< 8, 4, T, DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B;
typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR> tile_C;
#else
+ if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
typedef tile<16, 8, T> tile_A;
typedef tile<8, 8, T> tile_B;
typedef tile<16, 8, float> tile_C;
}
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
- constexpr int tile_k_padded = warp_size + 4;
+ constexpr int tile_k_padded = warp_size + mmf_get_padding();
constexpr int ntA = rows_per_block / tile_A::I;
constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;
}
float * buf_iw = (float *) compute_base;
- constexpr int kiw = nwarps*rows_per_block + 4;
+ constexpr int kiw = nwarps*rows_per_block + mmf_get_padding();
if (nwarps > 1) {
__syncthreads();
return;
}
- float sum = 0.0f;
- static_assert(rows_per_block == warp_size, "need loop/check");
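+ // One partial sum per warp-sized slice of the row block. With the current dispatch
+ // every live instantiation has rows_per_block == warp_size (a single element), but
+ // the loop form below stays correct for any multiple of warp_size.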
+ float sum[rows_per_block/warp_size] = {0.0f};
+ static_assert((rows_per_block % warp_size) == 0, "rows_per_block must be a multiple of warp_size.");
#pragma unroll
for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
- const int i = i0 + threadIdx.x;
+#pragma unroll
+ for (int i1 = 0; i1 < sizeof(sum)/sizeof(sum[0]); ++i1) {
+ const int i = i0 + i1*warp_size + threadIdx.x;
- sum += buf_iw[j*kiw + i];
+ sum[i1] += buf_iw[j*kiw + i];
+ }
}
if constexpr (!has_ids) {
- dst[j*stride_col_dst + row0 + threadIdx.x] = sum;
+#pragma unroll
+ for (int i0 = 0; i0 < sizeof(sum)/sizeof(sum[0]); ++i0) {
+ dst[j*stride_col_dst + row0 + i0*warp_size + threadIdx.x] = sum[i0];
+ }
} else {
const int slot = (j < cols_per_block) ? slot_map[j] : -1;
if (slot >= 0 && (col_base + j) < ncols_dst_total) {
- dst[slot*stride_channel_dst + j*stride_col_dst + row0 + threadIdx.x] = sum;
+#pragma unroll
+ for (int i0 = 0; i0 < sizeof(sum)/sizeof(sum[0]); ++i0) {
+ dst[slot*stride_channel_dst + j*stride_col_dst + row0 + i0*warp_size + threadIdx.x] = sum[i0];
+ }
}
}
}
-#ifdef VOLTA_MMA_AVAILABLE
}
-#endif //VOLTA_MMA_AVAILABLE
#else
GGML_UNUSED_VARS(x, y, ids, dst,
ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
NO_DEVICE_CODE;
-#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
+#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
}
// This kernel is for larger batch sizes of mul_mat_id.
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
const uint3 sis1_fd, const uint3 nch_fd) {
// TODO: handle this in a consistent and simpler way after AMD MFMA support has been added
-#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
+#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
#if defined(AMD_WMMA_AVAILABLE)
- // Special case for tf32, just dummy mma layout as wmma doesn't support it.
- constexpr bool is_tf32 = std::is_same_v<T, float>;
- constexpr int tile_B_I = is_tf32 ? 8 : 16;
- constexpr int tile_C_J = is_tf32 ? 8 : 16;
- constexpr data_layout ab_layout = is_tf32 ? DATA_LAYOUT_I_MAJOR : get_input_data_layout();
- typedef tile<16, 8, T, ab_layout> tile_A;
- typedef tile<tile_B_I, 8, T, ab_layout> tile_B;
- typedef tile<16, tile_C_J, float, DATA_LAYOUT_J_MAJOR> tile_C;
+ if constexpr (!(std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
+ typedef tile<16, 8, T, get_input_data_layout()> tile_A;
+ typedef tile<16, 8, T, get_input_data_layout()> tile_B;
+ typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
+#elif defined(AMD_MFMA_AVAILABLE)
+ if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK_CDNA) {NO_DEVICE_CODE;} else {
+ typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_A;
+ typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_B;
+ typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
#else
#ifdef VOLTA_MMA_AVAILABLE
- if constexpr (!std::is_same_v<T, half2>) {NO_DEVICE_CODE;} else {
+ if constexpr (!std::is_same_v<T, half2> || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
typedef tile<32, 4, T, DATA_LAYOUT_I_MAJOR> tile_A;
typedef tile< 8, 4, T, DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B;
typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR> tile_C;
#else
+ if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
typedef tile<16, 8, T> tile_A;
typedef tile<8, 8, T> tile_B;
typedef tile<16, 8, float> tile_C;
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
- constexpr int tile_k_padded = warp_size + 4;
+ constexpr int tile_k_padded = warp_size + mmf_get_padding();
constexpr int ntA = rows_per_block / tile_A::I;
constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;
}
float * buf_iw = (float *) compute_base;
- constexpr int kiw = nwarps*rows_per_block + 4;
+ constexpr int kiw = nwarps*rows_per_block + mmf_get_padding();
if (nwarps > 1) {
__syncthreads();
return;
}
- float sum = 0.0f;
- static_assert(rows_per_block == warp_size, "need loop/check");
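+ // Same warp-sliced partial-sum reduction as in mul_mat_f above.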
+ float sum[rows_per_block/warp_size] = {0.0f};
+ static_assert((rows_per_block % warp_size) == 0, "rows_per_block must be a multiple of warp_size.");
#pragma unroll
for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
- const int i = i0 + threadIdx.x;
+#pragma unroll
+ for (int i1 = 0; i1 < sizeof(sum)/sizeof(sum[0]); ++i1) {
+ const int i = i0 + i1*warp_size + threadIdx.x;
- sum += buf_iw[j*kiw + i];
+ sum[i1] += buf_iw[j*kiw + i];
+ }
}
const int global_j = col_base + j;
const int token = (int) qrm.x;
if (token < ncols_dst_total) {
const int slot = (int) qrm.y;
- dst[slot*stride_channel_dst + token*stride_col_dst + row0 + threadIdx.x] = sum;
+#pragma unroll
+ for (int i0 = 0; i0 < sizeof(sum)/sizeof(sum[0]); ++i0) {
+ dst[slot*stride_channel_dst + token*stride_col_dst + row0 + i0*warp_size + threadIdx.x] = sum[i0];
+ }
}
}
}
-#ifdef VOLTA_MMA_AVAILABLE
}
-#endif // VOLTA_MMA_AVAILABLE
#else
GGML_UNUSED_VARS(x, y, ids_src_compact, ids_dst_compact, expert_bounds, dst,
ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd);
NO_DEVICE_CODE;
-#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE)
+#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
}
-template<typename T, int cols_per_block, int nwarps>
+template<typename T, int rows_per_block, int cols_per_block, int nwarps>
static inline void mul_mat_f_switch_ids(
const T * x, const float * y, const int32_t * ids, float * dst,
const int64_t ncols_x, const int64_t ncols_dst, const int64_t nchannels_dst,
const uint3 sis1_fd = ids_data->sis1 > 0 ? init_fastdiv_values((uint32_t) ids_data->sis1) : make_uint3(0, 0, 1);
const uint3 nch_fd = init_fastdiv_values((uint32_t) nchannels_dst);
- mul_mat_f_ids<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
+ mul_mat_f_ids<T, rows_per_block, cols_per_block, nwarps><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
(x, y, ids_data->ids_src_compact, ids_data->ids_dst_compact, ids_data->expert_bounds_dev, dst,
ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
dim3 block_nums_ids = block_nums;
block_nums_ids.y *= col_tiles;
- mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
+ mul_mat_f<T, rows_per_block, cols_per_block, nwarps, true><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
(x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} else {
- mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, false><<<block_nums, block_dims, nbytes_shared_total, stream>>>
+ mul_mat_f<T, rows_per_block, cols_per_block, nwarps, false><<<block_nums, block_dims, nbytes_shared_total, stream>>>
(x, y, ids, dst, ncols_x, cols_per_block, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
}
}
-template <typename T, int cols_per_block>
+template <typename T, int rows_per_block, int cols_per_block>
void mul_mat_f_cuda(
const T * x, const float * y, const int32_t * ids, float * dst,
const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
int64_t nwarps_best = 1;
int64_t niter_best = (ncols_x + warp_size*2 - 1) / (warp_size*2);
- int64_t max_block_size = 256;
+ int64_t max_block_size = mmf_get_max_block_size(cc);
for (int64_t nwarps = 2; nwarps <= max_block_size/warp_size; nwarps++) {
const int64_t niter = (ncols_x + nwarps*warp_size*2 - 1) / (nwarps*warp_size*2);
if (niter < niter_best) {
}
}
- constexpr int rows_per_block = MMF_ROWS_PER_BLOCK;
- const int nbytes_shared_iter = nwarps_best * (volta_mma_available(cc) ? tile_A_32::I : tile_A_16::I) * (warp_size + 4) * 4;
- const int nbytes_cols_per_block_pad = amd_wmma_available(cc) ? tile_B_16::I : tile_B_8::I;
- const int nbytes_shared_combine = GGML_PAD(cols_per_block, nbytes_cols_per_block_pad) * (nwarps_best*rows_per_block + 4) * 4;
+ const int nbytes_shared_iter = nwarps_best * (volta_mma_available(cc) ? tile_A_32::I : tile_A_16::I) * (warp_size + mmf_get_padding(cc)) * 4;
+ const int nbytes_cols_per_block_pad = (amd_wmma_available(cc) || amd_mfma_available(cc)) ? tile_B_16::I : tile_B_8::I;
+ const int nbytes_shared_combine = GGML_PAD(cols_per_block, nbytes_cols_per_block_pad) * (nwarps_best*rows_per_block + mmf_get_padding(cc)) * 4;
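+ // 4 bytes per float; the combine buffer holds kiw = nwarps*rows_per_block + padding
+ // floats per output column, matching buf_iw in the kernels above.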
const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0;
const int nbytes_shared_total = nbytes_shared + nbytes_slotmap;
switch (nwarps_best) {
case 1: {
- mul_mat_f_switch_ids<T, cols_per_block, 1>(
+ mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 1>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data);
} break;
case 2: {
- mul_mat_f_switch_ids<T, cols_per_block, 2>(
+ mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 2>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data);
} break;
case 3: {
- mul_mat_f_switch_ids<T, cols_per_block, 3>(
+ mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 3>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data);
} break;
case 4: {
- mul_mat_f_switch_ids<T, cols_per_block, 4>(
+ mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 4>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data);
} break;
case 5: {
- mul_mat_f_switch_ids<T, cols_per_block, 5>(
+ mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 5>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data);
} break;
case 6: {
- mul_mat_f_switch_ids<T, cols_per_block, 6>(
+ mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 6>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data);
} break;
case 7: {
- mul_mat_f_switch_ids<T, cols_per_block, 7>(
+ mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 7>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data);
} break;
case 8: {
- mul_mat_f_switch_ids<T, cols_per_block, 8>(
+ mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 8>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
GGML_UNUSED_VARS(nchannels_y);
}
-template <typename T>
+template <typename T, int rows_per_block>
static void mul_mat_f_switch_cols_per_block(
const T * x, const float * y, const int32_t * ids, float * dst,
const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
switch (ncols_case) {
case 1: {
- mul_mat_f_cuda<T, 1>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+ mul_mat_f_cuda<T, rows_per_block, 1>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 2: {
- mul_mat_f_cuda<T, 2>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+ mul_mat_f_cuda<T, rows_per_block, 2>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 3: {
- mul_mat_f_cuda<T, 3>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+ mul_mat_f_cuda<T, rows_per_block, 3>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 4: {
- mul_mat_f_cuda<T, 4>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+ mul_mat_f_cuda<T, rows_per_block, 4>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 5: {
- mul_mat_f_cuda<T, 5>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+ mul_mat_f_cuda<T, rows_per_block, 5>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 6: {
- mul_mat_f_cuda<T, 6>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+ mul_mat_f_cuda<T, rows_per_block, 6>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 7: {
- mul_mat_f_cuda<T, 7>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+ mul_mat_f_cuda<T, rows_per_block, 7>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 8: {
- mul_mat_f_cuda<T, 8>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+ mul_mat_f_cuda<T, rows_per_block, 8>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 9: {
- mul_mat_f_cuda<T, 9>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+ mul_mat_f_cuda<T, rows_per_block, 9>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 10: {
- mul_mat_f_cuda<T, 10>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+ mul_mat_f_cuda<T, rows_per_block, 10>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 11: {
- mul_mat_f_cuda<T, 11>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+ mul_mat_f_cuda<T, rows_per_block, 11>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 12: {
- mul_mat_f_cuda<T, 12>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+ mul_mat_f_cuda<T, rows_per_block, 12>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 13: {
- mul_mat_f_cuda<T, 13>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+ mul_mat_f_cuda<T, rows_per_block, 13>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 14: {
- mul_mat_f_cuda<T, 14>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+ mul_mat_f_cuda<T, rows_per_block, 14>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 15: {
- mul_mat_f_cuda<T, 15>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+ mul_mat_f_cuda<T, rows_per_block, 15>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case 16: {
- mul_mat_f_cuda<T, 16>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+ mul_mat_f_cuda<T, rows_per_block, 16>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
}
}
-#define DECL_MMF_CASE_HELPER(T, ncols_dst) \
- template void mul_mat_f_cuda<T, ncols_dst>( \
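+// Maps the runtime row-block size (MMF_ROWS_PER_BLOCK by default, MMF_ROWS_PER_BLOCK_CDNA
+// on CDNA) onto the compile-time rows_per_block template parameter.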
+template <typename T>
+static void mul_mat_f_switch_rows_per_block(
+ const int rows_per_block, const T * x, const float * y, const int32_t * ids, float * dst,
+ const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
+ const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
+ const int64_t stride_col_id, const int64_t stride_row_id,
+ const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
+ const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
+ const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
+ cudaStream_t stream, const mmf_ids_data * ids_data) {
+ switch (rows_per_block) {
+ case MMF_ROWS_PER_BLOCK: {
+ mul_mat_f_switch_cols_per_block<T, MMF_ROWS_PER_BLOCK>(
+ x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
+ } break;
+ case MMF_ROWS_PER_BLOCK_CDNA: {
+ mul_mat_f_switch_cols_per_block<T, MMF_ROWS_PER_BLOCK_CDNA>(
+ x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
+ } break;
+ default:
+ GGML_ABORT("unsupported rows_per_block: %i", rows_per_block);
+ }
+}
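+// Hypothetical caller sketch (not part of this patch): the launcher is expected to pick
+// the row-block size from the device's compute capability before dispatching, e.g.:
+//   const int rows_per_block = GGML_CUDA_CC_IS_CDNA(cc) ? MMF_ROWS_PER_BLOCK_CDNA
+//                                                       : MMF_ROWS_PER_BLOCK;
+//   mul_mat_f_switch_rows_per_block<half2>(rows_per_block, x, y, ids, dst, /*...*/, stream, ids_data);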
+
+#define DECL_MMF_CASE_HELPER(T, nrows_dst, ncols_dst) \
+ template void mul_mat_f_cuda<T, nrows_dst, ncols_dst>( \
const T * x, const float * y, const int32_t * ids, float * dst, \
const int64_t ncols_x, const int64_t nrows_x, int64_t ncols_dst_total, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, \
const int64_t stride_col_id, const int64_t stride_row_id, \
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, \
cudaStream_t stream, const mmf_ids_data * ids_data);
-#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+#if !defined(GGML_USE_MUSA)
#define DECL_MMF_CASE_EXTERN(ncols_dst) \
- extern DECL_MMF_CASE_HELPER(float, ncols_dst) \
- extern DECL_MMF_CASE_HELPER(half2, ncols_dst) \
- extern DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst)
+ extern DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK, ncols_dst) \
+ extern DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK, ncols_dst) \
+ extern DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK, ncols_dst) \
+ extern DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \
+ extern DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \
+ extern DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst)
#define DECL_MMF_CASE(ncols_dst) \
- DECL_MMF_CASE_HELPER(float, ncols_dst) \
- DECL_MMF_CASE_HELPER(half2, ncols_dst) \
- DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst)
+ DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK, ncols_dst) \
+ DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK, ncols_dst) \
+ DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK, ncols_dst) \
+ DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \
+ DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \
+ DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst)
DECL_MMF_CASE_EXTERN(1);
DECL_MMF_CASE_EXTERN(2);