#include "ggml.h"
-#include "common.cuh"
-#include "mma.cuh"
#include "mmf.cuh"
-using namespace ggml_cuda_mma;
-
-#define MMF_ROWS_PER_BLOCK 32
-
-template <typename T, int rows_per_block, int cols_per_block, int nwarps>
-__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
-static __global__ void mul_mat_f(
- const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
- const int ncols, const int nchannels_y, const int stride_row, const int stride_col_y, const int stride_col_dst,
- const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
- const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
-#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
- typedef tile<16, 8, T> tile_A;
- typedef tile< 8, 8, T> tile_B;
- typedef tile<16, 8, float> tile_C;
-
- constexpr int warp_size = ggml_cuda_get_physical_warp_size();
- constexpr int tile_k_padded = warp_size + 4;
- constexpr int ntA = rows_per_block / tile_A::I;
- constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;
-
- const int row0 = blockIdx.x * rows_per_block;
- const int channel_dst = blockIdx.y;
- const int channel_x = channel_dst / channel_ratio;
- const int channel_y = channel_dst;
- const int sample_dst = blockIdx.z;
- const int sample_x = sample_dst / sample_ratio;
- const int sample_y = sample_dst;
-
- x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row0*stride_row ;
- y += int64_t(sample_y) *stride_sample_y + channel_y *stride_channel_y;
- dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst;
-
- const float2 * y2 = (const float2 *) y;
-
- extern __shared__ char data_mmv[];
-
- tile_C C[ntA][ntB];
-
- T * tile_xy = (T *) data_mmv + threadIdx.y*(tile_A::I * tile_k_padded);
-
- for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) {
- tile_A A[ntA][warp_size / tile_A::J];
-#pragma unroll
- for (int itA = 0; itA < ntA; ++itA) {
-#pragma unroll
- for (int i = 0; i < tile_A::I; ++i) {
- tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row + col];
- }
-#pragma unroll
- for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) {
- load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded);
- }
- }
-
-#pragma unroll
- for (int itB = 0; itB < ntB; ++itB) {
- if constexpr (std::is_same_v<T, float>) {
-#pragma unroll
- for (int j0 = 0; j0 < tile_B::I; ++j0) {
- const int j = j0 + itB*tile_B::I;
-
- tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[j*stride_col_y + col] : 0.0f;
- }
- } else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
-#pragma unroll
- for (int j0 = 0; j0 < tile_B::I; ++j0) {
- const int j = j0 + itB*tile_B::I;
-
- const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f);
- tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
- }
- } else {
- static_assert(std::is_same_v<T, void>, "unsupported type");
- }
-#pragma unroll
- for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
- tile_B B;
- load_ldmatrix(B, tile_xy + k0, tile_k_padded);
-#pragma unroll
- for (int itA = 0; itA < ntA; ++itA) {
- mma(C[itA][itB], A[itA][k0/tile_B::J], B);
- }
- }
- }
- }
-
- float * buf_iw = (float *) data_mmv;
- constexpr int kiw = nwarps*rows_per_block + 4;
-
- if (nwarps > 1) {
- __syncthreads();
- }
-#pragma unroll
- for (int itB = 0; itB < ntB; ++itB) {
-#pragma unroll
- for (int itA = 0; itA < ntA; ++itA) {
-#pragma unroll
- for (int l = 0; l < tile_C::ne; ++l) {
- const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l);
- const int j = itB*tile_C::J + tile_C::get_j(l);
- buf_iw[j*kiw + i] = C[itA][itB].x[l];
- }
- }
- }
-
- if (nwarps > 1) {
- __syncthreads();
- }
-
-#pragma unroll
- for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
- const int j = j0 + threadIdx.y;
-
- if (j0 + nwarps > cols_per_block && j >= cols_per_block) {
- return;
- }
-
- float sum = 0.0f;
- static_assert(rows_per_block == warp_size, "need loop/check");
-#pragma unroll
- for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
- const int i = i0 + threadIdx.x;
-
- sum += buf_iw[j*kiw + i];
- }
- dst[j*stride_col_dst + row0 + threadIdx.x] = sum;
- }
-#else
- GGML_UNUSED_VARS(x, y, ids, dst,
- ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst,
- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
- NO_DEVICE_CODE;
-#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
-}
-
-template <typename T, int cols_per_block>
-static void mul_mat_f_cuda(
- const T * x, const float * y, const int32_t * ids, float * dst,
- const int64_t ncols_x, const int64_t nrows_x,
- const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
- const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
- const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
- const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
- cudaStream_t stream) {
- typedef tile<16, 8, T> tile_A;
- typedef tile< 8, 8, T> tile_B;
-
- GGML_ASSERT(!ids && "mul_mat_id not implemented");
-
- GGML_ASSERT(ncols_x % 2 == 0);
- GGML_ASSERT(stride_row % 2 == 0);
- GGML_ASSERT(stride_col_y % 2 == 0);
- GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
- GGML_ASSERT( nsamples_dst % nsamples_x == 0);
- const int64_t channel_ratio = nchannels_dst / nchannels_x;
- const int64_t sample_ratio = nsamples_dst / nsamples_x;
-
- const int device = ggml_cuda_get_device();
- const int warp_size = ggml_cuda_info().devices[device].warp_size;
-
- int64_t nwarps_best = 1;
- int64_t niter_best = (ncols_x + warp_size*2 - 1) / (warp_size*2);
- int64_t max_block_size = 256;
- for (int64_t nwarps = 2; nwarps <= max_block_size/warp_size; nwarps++) {
- const int64_t niter = (ncols_x + nwarps*warp_size*2 - 1) / (nwarps*warp_size*2);
- if (niter < niter_best) {
- niter_best = niter;
- nwarps_best = nwarps;
- }
- }
-
- constexpr int rows_per_block = MMF_ROWS_PER_BLOCK;
- const int nbytes_shared_iter = nwarps_best * tile_A::I * (warp_size + 4) * 4;
- const int nbytes_shared_combine = GGML_PAD(cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4) * 4;
- const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
- const dim3 block_nums(nrows_x/rows_per_block, nchannels_dst, nsamples_dst);
- const dim3 block_dims(warp_size, nwarps_best, 1);
- switch (nwarps_best) {
- case 1: {
- mul_mat_f<T, rows_per_block, cols_per_block, 1><<<block_nums, block_dims, nbytes_shared, stream>>>
- (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
- } break;
- case 2: {
- mul_mat_f<T, rows_per_block, cols_per_block, 2><<<block_nums, block_dims, nbytes_shared, stream>>>
- (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
- } break;
- case 3: {
- mul_mat_f<T, rows_per_block, cols_per_block, 3><<<block_nums, block_dims, nbytes_shared, stream>>>
- (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
- } break;
- case 4: {
- mul_mat_f<T, rows_per_block, cols_per_block, 4><<<block_nums, block_dims, nbytes_shared, stream>>>
- (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
- } break;
- case 5: {
- mul_mat_f<T, rows_per_block, cols_per_block, 5><<<block_nums, block_dims, nbytes_shared, stream>>>
- (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
- } break;
- case 6: {
- mul_mat_f<T, rows_per_block, cols_per_block, 6><<<block_nums, block_dims, nbytes_shared, stream>>>
- (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
- } break;
- case 7: {
- mul_mat_f<T, rows_per_block, cols_per_block, 7><<<block_nums, block_dims, nbytes_shared, stream>>>
- (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
- } break;
- case 8: {
- mul_mat_f<T, rows_per_block, cols_per_block, 8><<<block_nums, block_dims, nbytes_shared, stream>>>
- (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
- } break;
- default: {
- GGML_ABORT("fatal error");
- } break;
- }
-}
-
-template <typename T>
-static void mul_mat_f_switch_cols_per_block(
- const T * x, const float * y, const int32_t * ids, float * dst,
- const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
- const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
- const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
- const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
- const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
- cudaStream_t stream) {
- switch (ncols_dst) {
- case 1: {
- mul_mat_f_cuda<T, 1>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
- nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
- } break;
- case 2: {
- mul_mat_f_cuda<T, 2>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
- nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
- } break;
- case 3: {
- mul_mat_f_cuda<T, 3>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
- nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
- } break;
- case 4: {
- mul_mat_f_cuda<T, 4>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
- nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
- } break;
- case 5: {
- mul_mat_f_cuda<T, 5>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
- nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
- } break;
- case 6: {
- mul_mat_f_cuda<T, 6>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
- nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
- } break;
- case 7: {
- mul_mat_f_cuda<T, 7>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
- nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
- } break;
- case 8: {
- mul_mat_f_cuda<T, 8>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
- nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
- } break;
- case 9: {
- mul_mat_f_cuda<T, 9>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
- nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
- } break;
- case 10: {
- mul_mat_f_cuda<T, 10>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
- nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
- } break;
- case 11: {
- mul_mat_f_cuda<T, 11>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
- nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
- } break;
- case 12: {
- mul_mat_f_cuda<T, 12>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
- nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
- } break;
- case 13: {
- mul_mat_f_cuda<T, 13>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
- nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
- } break;
- case 14: {
- mul_mat_f_cuda<T, 14>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
- nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
- } break;
- case 15: {
- mul_mat_f_cuda<T, 15>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
- nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
- } break;
- case 16: {
- mul_mat_f_cuda<T, 16>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
- nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
- nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
- } break;
- default: {
- GGML_ABORT("fatal error");
- } break;
- }
-}
-
void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
GGML_ASSERT( src1->type == GGML_TYPE_F32);
GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
GGML_TENSOR_BINARY_OP_LOCALS;
const size_t ts_src0 = ggml_type_size(src0->type);
const int64_t s13 = src1->nb[3] / ts_src1;
const int64_t s3 = dst->nb[3] / ts_dst;
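+ // element strides of the expert id matrix (mul_mat_id only)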
+ const int64_t ids_s0 = ids ? ids->nb[0] / ggml_type_size(ids->type) : 0;
+ const int64_t ids_s1 = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0;
+
// For MUL_MAT_ID the memory layout is different than for MUL_MAT:
const int64_t ncols_dst = ids ? ne2 : ne1;
- const int64_t nchannels_y = ids ? ne11 : ne12;
- const int64_t nchannels_dst = ids ? ne1 : ne2;
- const int64_t stride_channel_dst = ids ? s1 : s2;
- const int64_t stride_channel_y = ids ? s11 : s12;
- GGML_ASSERT(!ids || ncols_dst == 1);
+ const int64_t nchannels_dst = ids ? ne1 : ne2;
+
+ const int64_t stride_col_dst = ids ? s2 : s1;
+ const int64_t stride_col_y = ids ? s12 : s11;
+ const int64_t stride_channel_dst = ids ? s1 : s2;
+ int64_t stride_channel_y = ids ? s11 : s12;
+ int64_t nchannels_y = ids ? ne11 : ne12;
+
+ // mul_mat_id: a src1 with a single channel is broadcast across all experts (zero channel stride)
+ if (ids && nchannels_y == 1) {
+ stride_channel_y = 0;
+ nchannels_y = ids->ne[0];
+ }
switch (src0->type) {
case GGML_TYPE_F32: {
const float * src0_d = (const float *) src0->data;
constexpr int vals_per_T = 1;
mul_mat_f_switch_cols_per_block(
- src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1,
- ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
- ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
+ src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
+ ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
+ ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
} break;
case GGML_TYPE_F16: {
const half2 * src0_d = (const half2 *) src0->data;
constexpr int vals_per_T = 2;
mul_mat_f_switch_cols_per_block(
- src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1,
- ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
- ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
+ src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
+ ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
+ ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
} break;
case GGML_TYPE_BF16: {
const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data;
constexpr int vals_per_T = 2;
mul_mat_f_switch_cols_per_block(
- src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1,
- ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
- ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
+ src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
+ ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
+ ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
} break;
default:
GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
}
}
-bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, int64_t ne11) {
+bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, const int src1_ncols) {
+
+ if (ggml_is_quantized(type)) {
+ return false;
+ }
+
if (src0_ne[0] % (warp_size * (4/ggml_type_size(type))) != 0) {
return false;
}
if (src0_ne[1] % MMF_ROWS_PER_BLOCK != 0) {
return false;
}
- if (ne11 > 16) {
+ if (src1_ncols > 16) {
return false;
}
+
switch (type) {
case GGML_TYPE_F32:
return ampere_mma_available(cc);
+#pragma once
+
+#include "mma.cuh"
#include "common.cuh"
+using namespace ggml_cuda_mma;
+
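+// number of rows of the x matrix processed per CUDA block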
+#define MMF_ROWS_PER_BLOCK 32
+
void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
-bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, int64_t ne11);
+bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, const int src1_ncols);
+
+template <typename T, int rows_per_block, int cols_per_block, int nwarps, bool has_ids>
+__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
+static __global__ void mul_mat_f(
+ const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
+ const int ncols, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst,
+ const int stride_col_id, const int stride_row_id,
+ const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
+ const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+ typedef tile<16, 8, T> tile_A;
+ typedef tile< 8, 8, T> tile_B;
+ typedef tile<16, 8, float> tile_C;
+
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
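+ // pad shared memory rows by 4 elements to reduce bank conflicts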
+ constexpr int tile_k_padded = warp_size + 4;
+ constexpr int ntA = rows_per_block / tile_A::I;
+ constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;
+
+ const int row0 = blockIdx.x * rows_per_block;
+
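+ // for mul_mat_id, blockIdx.y iterates over the experts instead of the dst channels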
+ const int expert_idx = has_ids ? blockIdx.y : 0;
+ const int channel_dst = has_ids ? 0 : blockIdx.y;
+
+ const int channel_x = has_ids ? expert_idx : (channel_dst / channel_ratio);
+ const int channel_y = channel_dst;
+ const int sample_dst = blockIdx.z;
+ const int sample_x = sample_dst / sample_ratio;
+ const int sample_y = sample_dst;
+
+ x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row0*stride_row ;
+ y += int64_t(sample_y) *stride_sample_y + (has_ids ? 0 : channel_y *stride_channel_y);
+ dst += int64_t(sample_dst)*stride_sample_dst + (has_ids ? 0 : channel_dst*stride_channel_dst);
+
+ const float2 * y2 = (const float2 *) y;
+
+ extern __shared__ char data_mmv[];
+
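+ // shared memory layout: [optional slot map (has_ids only) | tile/reduction buffer]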
+ char * shmem_base = data_mmv;
+ int * slot_map = (int *) shmem_base;
+ char * compute_base = has_ids ? (shmem_base + GGML_PAD(cols_per_block, 16) * sizeof(int)) : shmem_base;
+
+ tile_C C[ntA][ntB];
+
+ T * tile_xy = (T *) compute_base + threadIdx.y*(tile_A::I * tile_k_padded);
+
+ if constexpr (has_ids) {
+ __shared__ int has_any;
+ if (threadIdx.y == 0) {
+ int local_has_any = 0;
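+ // warp 0 scans the ids matrix: for each column of this tile, find the
+ // dst channel (slot) that routed to the expert handled by this block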
+ for (int j = threadIdx.x; j < cols_per_block; j += warp_size) {
+ int slot = -1;
+ for (int k = 0; k < nchannels_dst; ++k) {
+ const int idv = ids[j*stride_row_id + k*stride_col_id];
+ if (idv == expert_idx) {
+ slot = k;
+ break;
+ }
+ }
+ local_has_any |= (slot >= 0);
+ slot_map[j] = slot;
+ }
+ has_any = warp_reduce_any(local_has_any);
+ }
+ __syncthreads();
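+ // none of the columns in this tile use this expert -> the whole block can exit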
+ if (has_any == 0) {
+ return;
+ }
+ }
+
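+ // loop over the K dimension: stage tiles of x and y in shared memory, then multiply them with mma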
+ for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) {
+ tile_A A[ntA][warp_size / tile_A::J];
+#pragma unroll
+ for (int itA = 0; itA < ntA; ++itA) {
+#pragma unroll
+ for (int i = 0; i < tile_A::I; ++i) {
+ tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row + col];
+ }
+#pragma unroll
+ for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) {
+ load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded);
+ }
+ }
+
+#pragma unroll
+ for (int itB = 0; itB < ntB; ++itB) {
+ if constexpr (std::is_same_v<T, float>) {
+#pragma unroll
+ for (int j0 = 0; j0 < tile_B::I; ++j0) {
+ const int j = j0 + itB*tile_B::I;
+
+ if constexpr (!has_ids) {
+ tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[j*stride_col_y + col] : 0.0f;
+ } else {
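+ // gather y from the channel that routed to this expert; unmatched columns stay zero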
+ float val = 0.0f;
+ if (j < cols_per_block) {
+ const int slot = slot_map[j];
+ if (slot >= 0) {
+ val = y[slot*stride_channel_y + j*stride_col_y + col];
+ }
+ }
+ tile_xy[j0*tile_k_padded + threadIdx.x] = val;
+ }
+ }
+ } else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
+#pragma unroll
+ for (int j0 = 0; j0 < tile_B::I; ++j0) {
+ const int j = j0 + itB*tile_B::I;
+
+ if constexpr (!has_ids) {
+ const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f);
+ tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
+ } else {
+ float2 tmp = make_float2(0.0f, 0.0f);
+ if (j < cols_per_block) {
+ const int slot = slot_map[j];
+ if (slot >= 0) {
+ const float2 * y2_slot = (const float2 *)(y + slot*stride_channel_y);
+ tmp = y2_slot[j*stride_col_y + col];
+ }
+ }
+ tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
+ }
+ }
+ } else {
+ static_assert(std::is_same_v<T, void>, "unsupported type");
+ }
+#pragma unroll
+ for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
+ tile_B B;
+ load_ldmatrix(B, tile_xy + k0, tile_k_padded);
+#pragma unroll
+ for (int itA = 0; itA < ntA; ++itA) {
+ mma(C[itA][itB], A[itA][k0/tile_B::J], B);
+ }
+ }
+ }
+ }
+
+ float * buf_iw = (float *) compute_base;
+ constexpr int kiw = nwarps*rows_per_block + 4;
+
+ if (nwarps > 1) {
+ __syncthreads();
+ }
+#pragma unroll
+ for (int itB = 0; itB < ntB; ++itB) {
+#pragma unroll
+ for (int itA = 0; itA < ntA; ++itA) {
+#pragma unroll
+ for (int l = 0; l < tile_C::ne; ++l) {
+ const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l);
+ const int j = itB*tile_C::J + tile_C::get_j(l);
+ buf_iw[j*kiw + i] = C[itA][itB].x[l];
+ }
+ }
+ }
+
+ if (nwarps > 1) {
+ __syncthreads();
+ }
+
+#pragma unroll
+ for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
+ const int j = j0 + threadIdx.y;
+
+ if (j0 + nwarps > cols_per_block && j >= cols_per_block) {
+ return;
+ }
+
+ float sum = 0.0f;
+ static_assert(rows_per_block == warp_size, "need loop/check");
+#pragma unroll
+ for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
+ const int i = i0 + threadIdx.x;
+
+ sum += buf_iw[j*kiw + i];
+ }
+
+ if constexpr (!has_ids) {
+ dst[j*stride_col_dst + row0 + threadIdx.x] = sum;
+ } else {
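+ // scatter the result into the dst channel that routed to this expert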
+ const int slot = (j < cols_per_block) ? slot_map[j] : -1;
+ if (slot >= 0) {
+ dst[slot*stride_channel_dst + j*stride_col_dst + row0 + threadIdx.x] = sum;
+ }
+ }
+ }
+#else
+ GGML_UNUSED_VARS(x, y, ids, dst,
+ ncols, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id,
+ channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+ NO_DEVICE_CODE;
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+}
+
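+// helper to turn the runtime ids pointer into the compile-time has_ids template parameter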
+template <typename T, int cols_per_block, int nwarps>
+static inline void mul_mat_f_switch_ids(
+ const T * x, const float * y, const int32_t * ids, float * dst,
+ const int64_t ncols_x, const int64_t nchannels_dst,
+ const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
+ const int64_t stride_col_id, const int64_t stride_row_id,
+ const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
+ const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
+ const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream) {
+ if (ids) {
+ mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true><<<block_nums, block_dims, nbytes_shared_total, stream>>>
+ (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+ } else {
+ mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, false><<<block_nums, block_dims, nbytes_shared_total, stream>>>
+ (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+ }
+}
+
+template <typename T, int cols_per_block>
+void mul_mat_f_cuda(
+ const T * x, const float * y, const int32_t * ids, float * dst,
+ const int64_t ncols_x, const int64_t nrows_x, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
+ const int64_t stride_col_id, const int64_t stride_row_id,
+ const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
+ const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
+ const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
+ cudaStream_t stream) {
+ typedef tile<16, 8, T> tile_A;
+ typedef tile< 8, 8, T> tile_B;
+
+ GGML_ASSERT(ncols_x % 2 == 0);
+ GGML_ASSERT(stride_row % 2 == 0);
+ GGML_ASSERT(stride_col_y % 2 == 0);
+ GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
+ GGML_ASSERT( nsamples_dst % nsamples_x == 0);
+ const int64_t channel_ratio = nchannels_dst / nchannels_x;
+ const int64_t sample_ratio = nsamples_dst / nsamples_x;
+
+ const int device = ggml_cuda_get_device();
+ const int warp_size = ggml_cuda_info().devices[device].warp_size;
+
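+ // use more warps per block only if this reduces the number of iterations over ncols_x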
+ int64_t nwarps_best = 1;
+ int64_t niter_best = (ncols_x + warp_size*2 - 1) / (warp_size*2);
+ int64_t max_block_size = 256;
+ for (int64_t nwarps = 2; nwarps <= max_block_size/warp_size; nwarps++) {
+ const int64_t niter = (ncols_x + nwarps*warp_size*2 - 1) / (nwarps*warp_size*2);
+ if (niter < niter_best) {
+ niter_best = niter;
+ nwarps_best = nwarps;
+ }
+ }
+
+ constexpr int rows_per_block = MMF_ROWS_PER_BLOCK;
+ const int nbytes_shared_iter = nwarps_best * tile_A::I * (warp_size + 4) * 4;
+ const int nbytes_shared_combine = GGML_PAD(cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4) * 4;
+ const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
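+ // mul_mat_id needs extra shared memory for the per-column expert slot map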
+ const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0;
+ const int nbytes_shared_total = nbytes_shared + nbytes_slotmap;
+ const int64_t grid_y = ids ? nchannels_x : nchannels_dst; // per expert when ids present
+
+ const dim3 block_nums(nrows_x/rows_per_block, grid_y, nsamples_dst);
+ const dim3 block_dims(warp_size, nwarps_best, 1);
+
+ switch (nwarps_best) {
+ case 1: {
+ mul_mat_f_switch_ids<T, cols_per_block, 1>(
+ x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
+ } break;
+ case 2: {
+ mul_mat_f_switch_ids<T, cols_per_block, 2>(
+ x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
+ } break;
+ case 3: {
+ mul_mat_f_switch_ids<T, cols_per_block, 3>(
+ x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
+ } break;
+ case 4: {
+ mul_mat_f_switch_ids<T, cols_per_block, 4>(
+ x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
+ } break;
+ case 5: {
+ mul_mat_f_switch_ids<T, cols_per_block, 5>(
+ x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
+ } break;
+ case 6: {
+ mul_mat_f_switch_ids<T, cols_per_block, 6>(
+ x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
+ } break;
+ case 7: {
+ mul_mat_f_switch_ids<T, cols_per_block, 7>(
+ x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
+ } break;
+ case 8: {
+ mul_mat_f_switch_ids<T, cols_per_block, 8>(
+ x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+ sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
+ } break;
+ default: {
+ GGML_ABORT("fatal error");
+ } break;
+ }
+
+ GGML_UNUSED_VARS(nchannels_y);
+}
+
+template <typename T>
+static void mul_mat_f_switch_cols_per_block(
+ const T * x, const float * y, const int32_t * ids, float * dst,
+ const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
+ const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
+ const int64_t stride_col_id, const int64_t stride_row_id,
+ const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
+ const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
+ const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
+ cudaStream_t stream) {
+ switch (ncols_dst) {
+ case 1: {
+ mul_mat_f_cuda<T, 1>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ } break;
+ case 2: {
+ mul_mat_f_cuda<T, 2>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ } break;
+ case 3: {
+ mul_mat_f_cuda<T, 3>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ } break;
+ case 4: {
+ mul_mat_f_cuda<T, 4>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ } break;
+ case 5: {
+ mul_mat_f_cuda<T, 5>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ } break;
+ case 6: {
+ mul_mat_f_cuda<T, 6>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ } break;
+ case 7: {
+ mul_mat_f_cuda<T, 7>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ } break;
+ case 8: {
+ mul_mat_f_cuda<T, 8>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ } break;
+ case 9: {
+ mul_mat_f_cuda<T, 9>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ } break;
+ case 10: {
+ mul_mat_f_cuda<T, 10>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ } break;
+ case 11: {
+ mul_mat_f_cuda<T, 11>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ } break;
+ case 12: {
+ mul_mat_f_cuda<T, 12>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ } break;
+ case 13: {
+ mul_mat_f_cuda<T, 13>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ } break;
+ case 14: {
+ mul_mat_f_cuda<T, 14>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ } break;
+ case 15: {
+ mul_mat_f_cuda<T, 15>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ } break;
+ case 16: {
+ mul_mat_f_cuda<T, 16>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+ stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+ } break;
+ default: {
+ GGML_ABORT("fatal error");
+ } break;
+ }
+}
+
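+// explicit instantiations of mul_mat_f_cuda for each supported ncols_dst;
+// the EXTERN variant declares instantiations that are meant to be compiled in separate source files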
+#define DECL_MMF_CASE_HELPER(T, ncols_dst) \
+ template void mul_mat_f_cuda<T, ncols_dst>( \
+ const T * x, const float * y, const int32_t * ids, float * dst, \
+ const int64_t ncols_x, const int64_t nrows_x, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, \
+ const int64_t stride_col_id, const int64_t stride_row_id, \
+ const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, \
+ const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,\
+ const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, \
+ cudaStream_t stream);
+
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+#define DECL_MMF_CASE_EXTERN(ncols_dst) \
+ extern DECL_MMF_CASE_HELPER(float, ncols_dst) \
+ extern DECL_MMF_CASE_HELPER(half2, ncols_dst) \
+ extern DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst)
+
+#define DECL_MMF_CASE(ncols_dst) \
+ DECL_MMF_CASE_HELPER(float, ncols_dst) \
+ DECL_MMF_CASE_HELPER(half2, ncols_dst) \
+ DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst)
+
+DECL_MMF_CASE_EXTERN(1);
+DECL_MMF_CASE_EXTERN(2);
+DECL_MMF_CASE_EXTERN(3);
+DECL_MMF_CASE_EXTERN(4);
+DECL_MMF_CASE_EXTERN(5);
+DECL_MMF_CASE_EXTERN(6);
+DECL_MMF_CASE_EXTERN(7);
+DECL_MMF_CASE_EXTERN(8);
+DECL_MMF_CASE_EXTERN(9);
+DECL_MMF_CASE_EXTERN(10);
+DECL_MMF_CASE_EXTERN(11);
+DECL_MMF_CASE_EXTERN(12);
+DECL_MMF_CASE_EXTERN(13);
+DECL_MMF_CASE_EXTERN(14);
+DECL_MMF_CASE_EXTERN(15);
+DECL_MMF_CASE_EXTERN(16);
+#else
+#define DECL_MMF_CASE(ncols_dst)
+#endif