From: Aman Gupta Date: Tue, 9 Sep 2025 06:38:02 +0000 (+0800) Subject: CUDA: Add mul_mat_id support for the mmf kernel (llama/15767) X-Git-Tag: v0.9.1~58 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=0ac013c9d9d05c7be80570f786fbd445dd95a406;p=pkg%2Fggml%2Fsources%2Fggml CUDA: Add mul_mat_id support for the mmf kernel (llama/15767) * CUDA: Add mul_mat_id support the mmf Add support for mul_mat_id for bs < 16 * Review: use warp_size, fix should_use_mmf condition * Launch one block per expert, stride along n_expert_used * templatize mul_mat_id * Pad shmem to 16 bytes, add helper function mul_mat_f_switch_ids * Reduce compile times by dividing mmf into f16, bf16 and f32 variants * Divide mmf by ncols_dst * Add missing files * Fix MUSA/HIP builds --- diff --git a/src/ggml-cuda/CMakeLists.txt b/src/ggml-cuda/CMakeLists.txt index 90610af5..bdcefe7b 100644 --- a/src/ggml-cuda/CMakeLists.txt +++ b/src/ggml-cuda/CMakeLists.txt @@ -48,6 +48,8 @@ if (CUDAToolkit_FOUND) list(APPEND GGML_SOURCES_CUDA ${SRCS}) file(GLOB SRCS "template-instances/mmq*.cu") list(APPEND GGML_SOURCES_CUDA ${SRCS}) + file(GLOB SRCS "template-instances/mmf*.cu") + list(APPEND GGML_SOURCES_CUDA ${SRCS}) if (GGML_CUDA_FA_ALL_QUANTS) file(GLOB SRCS "template-instances/fattn-vec*.cu") diff --git a/src/ggml-cuda/ggml-cuda.cu b/src/ggml-cuda/ggml-cuda.cu index 95170ae1..0f68d685 100644 --- a/src/ggml-cuda/ggml-cuda.cu +++ b/src/ggml-cuda/ggml-cuda.cu @@ -2109,6 +2109,11 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * ggml_cuda_mul_mat_q(ctx, src0, src1, ids, dst); return; } + + if (ggml_cuda_should_use_mmf(src0->type, cc, WARP_SIZE, src0->ne, src1->ne[2])) { + ggml_cuda_mul_mat_f(ctx, src0, src1, ids, dst); + return; + } } cudaStream_t stream = ctx.stream(); diff --git a/src/ggml-cuda/mma.cuh b/src/ggml-cuda/mma.cuh index 667deb9c..c1f24243 100644 --- a/src/ggml-cuda/mma.cuh +++ b/src/ggml-cuda/mma.cuh @@ -1,3 +1,4 @@ +#pragma once // This file contains primitives that expose the tensor core PTX instructions for CUDA code. // The primitives can be used in a similar way as the nvcuda::wmma interface but with a well-defined memory layout. // The documentation for the PTX instructions can be found under: diff --git a/src/ggml-cuda/mmf.cu b/src/ggml-cuda/mmf.cu index cfa5c5cc..16331e9e 100644 --- a/src/ggml-cuda/mmf.cu +++ b/src/ggml-cuda/mmf.cu @@ -1,343 +1,12 @@ #include "ggml.h" -#include "common.cuh" -#include "mma.cuh" #include "mmf.cuh" -using namespace ggml_cuda_mma; - -#define MMF_ROWS_PER_BLOCK 32 - -template -__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1) -static __global__ void mul_mat_f( - const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst, - const int ncols, const int nchannels_y, const int stride_row, const int stride_col_y, const int stride_col_dst, - const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, - const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) { -#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) - typedef tile<16, 8, T> tile_A; - typedef tile< 8, 8, T> tile_B; - typedef tile<16, 8, float> tile_C; - - constexpr int warp_size = ggml_cuda_get_physical_warp_size(); - constexpr int tile_k_padded = warp_size + 4; - constexpr int ntA = rows_per_block / tile_A::I; - constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I; - - const int row0 = blockIdx.x * rows_per_block; - const int channel_dst = blockIdx.y; - const int channel_x = channel_dst / channel_ratio; - const int channel_y = channel_dst; - const int sample_dst = blockIdx.z; - const int sample_x = sample_dst / sample_ratio; - const int sample_y = sample_dst; - - x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row0*stride_row ; - y += int64_t(sample_y) *stride_sample_y + channel_y *stride_channel_y; - dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst; - - const float2 * y2 = (const float2 *) y; - - extern __shared__ char data_mmv[]; - - tile_C C[ntA][ntB]; - - T * tile_xy = (T *) data_mmv + threadIdx.y*(tile_A::I * tile_k_padded); - - for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) { - tile_A A[ntA][warp_size / tile_A::J]; -#pragma unroll - for (int itA = 0; itA < ntA; ++itA) { -#pragma unroll - for (int i = 0; i < tile_A::I; ++i) { - tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row + col]; - } -#pragma unroll - for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) { - load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded); - } - } - -#pragma unroll - for (int itB = 0; itB < ntB; ++itB) { - if constexpr (std::is_same_v) { -#pragma unroll - for (int j0 = 0; j0 < tile_B::I; ++j0) { - const int j = j0 + itB*tile_B::I; - - tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[j*stride_col_y + col] : 0.0f; - } - } else if constexpr (std::is_same_v || std::is_same_v) { -#pragma unroll - for (int j0 = 0; j0 < tile_B::I; ++j0) { - const int j = j0 + itB*tile_B::I; - - const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f); - tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y}; - } - } else { - static_assert(std::is_same_v, "unsupported type"); - } -#pragma unroll - for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) { - tile_B B; - load_ldmatrix(B, tile_xy + k0, tile_k_padded); -#pragma unroll - for (int itA = 0; itA < ntA; ++itA) { - mma(C[itA][itB], A[itA][k0/tile_B::J], B); - } - } - } - } - - float * buf_iw = (float *) data_mmv; - constexpr int kiw = nwarps*rows_per_block + 4; - - if (nwarps > 1) { - __syncthreads(); - } -#pragma unroll - for (int itB = 0; itB < ntB; ++itB) { -#pragma unroll - for (int itA = 0; itA < ntA; ++itA) { -#pragma unroll - for (int l = 0; l < tile_C::ne; ++l) { - const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l); - const int j = itB*tile_C::J + tile_C::get_j(l); - buf_iw[j*kiw + i] = C[itA][itB].x[l]; - } - } - } - - if (nwarps > 1) { - __syncthreads(); - } - -#pragma unroll - for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) { - const int j = j0 + threadIdx.y; - - if (j0 + nwarps > cols_per_block && j >= cols_per_block) { - return; - } - - float sum = 0.0f; - static_assert(rows_per_block == warp_size, "need loop/check"); -#pragma unroll - for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) { - const int i = i0 + threadIdx.x; - - sum += buf_iw[j*kiw + i]; - } - dst[j*stride_col_dst + row0 + threadIdx.x] = sum; - } -#else - GGML_UNUSED_VARS(x, y, ids, dst, - ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - NO_DEVICE_CODE; -#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) -} - -template -static void mul_mat_f_cuda( - const T * x, const float * y, const int32_t * ids, float * dst, - const int64_t ncols_x, const int64_t nrows_x, - const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, - const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, - const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, - const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, - cudaStream_t stream) { - typedef tile<16, 8, T> tile_A; - typedef tile< 8, 8, T> tile_B; - - GGML_ASSERT(!ids && "mul_mat_id not implemented"); - - GGML_ASSERT(ncols_x % 2 == 0); - GGML_ASSERT(stride_row % 2 == 0); - GGML_ASSERT(stride_col_y % 2 == 0); - GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0); - GGML_ASSERT( nsamples_dst % nsamples_x == 0); - const int64_t channel_ratio = nchannels_dst / nchannels_x; - const int64_t sample_ratio = nsamples_dst / nsamples_x; - - const int device = ggml_cuda_get_device(); - const int warp_size = ggml_cuda_info().devices[device].warp_size; - - int64_t nwarps_best = 1; - int64_t niter_best = (ncols_x + warp_size*2 - 1) / (warp_size*2); - int64_t max_block_size = 256; - for (int64_t nwarps = 2; nwarps <= max_block_size/warp_size; nwarps++) { - const int64_t niter = (ncols_x + nwarps*warp_size*2 - 1) / (nwarps*warp_size*2); - if (niter < niter_best) { - niter_best = niter; - nwarps_best = nwarps; - } - } - - constexpr int rows_per_block = MMF_ROWS_PER_BLOCK; - const int nbytes_shared_iter = nwarps_best * tile_A::I * (warp_size + 4) * 4; - const int nbytes_shared_combine = GGML_PAD(cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4) * 4; - const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine); - const dim3 block_nums(nrows_x/rows_per_block, nchannels_dst, nsamples_dst); - const dim3 block_dims(warp_size, nwarps_best, 1); - switch (nwarps_best) { - case 1: { - mul_mat_f<<>> - (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - } break; - case 2: { - mul_mat_f<<>> - (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - } break; - case 3: { - mul_mat_f<<>> - (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - } break; - case 4: { - mul_mat_f<<>> - (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - } break; - case 5: { - mul_mat_f<<>> - (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - } break; - case 6: { - mul_mat_f<<>> - (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - } break; - case 7: { - mul_mat_f<<>> - (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - } break; - case 8: { - mul_mat_f<<>> - (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst, - channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, - sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); - } break; - default: { - GGML_ABORT("fatal error"); - } break; - } -} - -template -static void mul_mat_f_switch_cols_per_block( - const T * x, const float * y, const int32_t * ids, float * dst, - const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst, - const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, - const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, - const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, - const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, - cudaStream_t stream) { - switch (ncols_dst) { - case 1: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 2: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 3: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 4: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 5: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 6: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 7: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 8: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 9: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 10: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 11: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 12: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 13: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 14: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 15: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - case 16: { - mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, - nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, - nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); - } break; - default: { - GGML_ABORT("fatal error"); - } break; - } -} - void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) { GGML_ASSERT( src1->type == GGML_TYPE_F32); GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32); GGML_ASSERT( dst->type == GGML_TYPE_F32); + GGML_TENSOR_BINARY_OP_LOCALS; const size_t ts_src0 = ggml_type_size(src0->type); @@ -365,55 +34,72 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr const int64_t s13 = src1->nb[3] / ts_src1; const int64_t s3 = dst->nb[3] / ts_dst; + const int64_t ids_s0 = ids ? ids->nb[0] / ggml_type_size(ids->type) : 0; + const int64_t ids_s1 = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0; + // For MUL_MAT_ID the memory layout is different than for MUL_MAT: const int64_t ncols_dst = ids ? ne2 : ne1; - const int64_t nchannels_y = ids ? ne11 : ne12; - const int64_t nchannels_dst = ids ? ne1 : ne2; - const int64_t stride_channel_dst = ids ? s1 : s2; - const int64_t stride_channel_y = ids ? s11 : s12; + const int64_t nchannels_dst = ids ? ne1 : ne2; + + const int64_t stride_col_dst = ids ? s2 : s1; + const int64_t stride_col_y = ids ? s12 : s11; + const int64_t stride_channel_dst = ids ? s1 : s2; - GGML_ASSERT(!ids || ncols_dst == 1); + int64_t stride_channel_y = ids ? s11 : s12; + int64_t nchannels_y = ids ? ne11 : ne12; + + //mul_mat_id: handle broadcast + if (ids && nchannels_y == 1) { + stride_channel_y = 0; + nchannels_y = ids->ne[0]; + } switch (src0->type) { case GGML_TYPE_F32: { const float * src0_d = (const float *) src0->data; constexpr int vals_per_T = 1; mul_mat_f_switch_cols_per_block( - src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1, - ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, - ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream()); + src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst, + ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, + ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream()); } break; case GGML_TYPE_F16: { const half2 * src0_d = (const half2 *) src0->data; constexpr int vals_per_T = 2; mul_mat_f_switch_cols_per_block( - src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1, - ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, - ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream()); + src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst, + ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, + ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream()); } break; case GGML_TYPE_BF16: { const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data; constexpr int vals_per_T = 2; mul_mat_f_switch_cols_per_block( - src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1, - ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, - ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream()); + src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst, + ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, + ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream()); } break; default: GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type)); } } -bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, int64_t ne11) { +bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, const int src1_ncols) { + + if (ggml_is_quantized(type)) { + return false; + } + if (src0_ne[0] % (warp_size * (4/ggml_type_size(type))) != 0) { return false; } if (src0_ne[1] % MMF_ROWS_PER_BLOCK != 0) { return false; } - if (ne11 > 16) { + if (src1_ncols > 16) { return false; } + switch (type) { case GGML_TYPE_F32: return ampere_mma_available(cc); diff --git a/src/ggml-cuda/mmf.cuh b/src/ggml-cuda/mmf.cuh index 785f9f21..bf724bc5 100644 --- a/src/ggml-cuda/mmf.cuh +++ b/src/ggml-cuda/mmf.cuh @@ -1,5 +1,473 @@ +#pragma once + +#include "mma.cuh" #include "common.cuh" +using namespace ggml_cuda_mma; + +#define MMF_ROWS_PER_BLOCK 32 + void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst); -bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, int64_t ne11); +bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const int src1_ncols); + +template +__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1) +static __global__ void mul_mat_f( + const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst, + const int ncols, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst, + const int stride_col_id, const int stride_row_id, + const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, + const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) { +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) + typedef tile<16, 8, T> tile_A; + typedef tile< 8, 8, T> tile_B; + typedef tile<16, 8, float> tile_C; + + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + constexpr int tile_k_padded = warp_size + 4; + constexpr int ntA = rows_per_block / tile_A::I; + constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I; + + const int row0 = blockIdx.x * rows_per_block; + + const int expert_idx = has_ids ? blockIdx.y : 0; + const int channel_dst = has_ids ? 0 : blockIdx.y; + + const int channel_x = has_ids ? expert_idx : (channel_dst / channel_ratio); + const int channel_y = channel_dst; + const int sample_dst = blockIdx.z; + const int sample_x = sample_dst / sample_ratio; + const int sample_y = sample_dst; + + x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row0*stride_row ; + y += int64_t(sample_y) *stride_sample_y + (has_ids ? 0 : channel_y *stride_channel_y); + dst += int64_t(sample_dst)*stride_sample_dst + (has_ids ? 0 : channel_dst*stride_channel_dst); + + const float2 * y2 = (const float2 *) y; + + extern __shared__ char data_mmv[]; + + char * shmem_base = data_mmv; + int * slot_map = (int *) shmem_base; + char * compute_base = has_ids ? (shmem_base + GGML_PAD(cols_per_block, 16) * sizeof(int)) : shmem_base; + + tile_C C[ntA][ntB]; + + T * tile_xy = (T *) compute_base + threadIdx.y*(tile_A::I * tile_k_padded); + + if constexpr (has_ids) { + __shared__ int has_any; + if (threadIdx.y == 0) { + int local_has_any = 0; + for (int j = threadIdx.x; j < cols_per_block; j += warp_size) { + int slot = -1; + for (int k = 0; k < nchannels_dst; ++k) { + const int idv = ids[j*stride_row_id + k*stride_col_id]; + if (idv == expert_idx) { + slot = k; + break; + } + } + if (j < cols_per_block) { + local_has_any |= (slot >= 0); + slot_map[j] = slot; + } + } + has_any = warp_reduce_any(local_has_any); + } + __syncthreads(); + if (has_any == 0) { + return; + } + } + + for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) { + tile_A A[ntA][warp_size / tile_A::J]; +#pragma unroll + for (int itA = 0; itA < ntA; ++itA) { +#pragma unroll + for (int i = 0; i < tile_A::I; ++i) { + tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row + col]; + } +#pragma unroll + for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) { + load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded); + } + } + +#pragma unroll + for (int itB = 0; itB < ntB; ++itB) { + if constexpr (std::is_same_v) { +#pragma unroll + for (int j0 = 0; j0 < tile_B::I; ++j0) { + const int j = j0 + itB*tile_B::I; + + if constexpr (!has_ids) { + tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[j*stride_col_y + col] : 0.0f; + } else { + float val = 0.0f; + if (j < cols_per_block) { + const int slot = slot_map[j]; + if (slot >= 0) { + val = y[slot*stride_channel_y + j*stride_col_y + col]; + } + } + tile_xy[j0*tile_k_padded + threadIdx.x] = val; + } + } + } else if constexpr (std::is_same_v || std::is_same_v) { +#pragma unroll + for (int j0 = 0; j0 < tile_B::I; ++j0) { + const int j = j0 + itB*tile_B::I; + + if constexpr (!has_ids) { + const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f); + tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y}; + } else { + float2 tmp = make_float2(0.0f, 0.0f); + if (j < cols_per_block) { + const int slot = slot_map[j]; + if (slot >= 0) { + const float2 * y2_slot = (const float2 *)(y + slot*stride_channel_y); + tmp = y2_slot[j*stride_col_y + col]; + } + } + tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y}; + } + } + } else { + static_assert(std::is_same_v, "unsupported type"); + } +#pragma unroll + for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) { + tile_B B; + load_ldmatrix(B, tile_xy + k0, tile_k_padded); +#pragma unroll + for (int itA = 0; itA < ntA; ++itA) { + mma(C[itA][itB], A[itA][k0/tile_B::J], B); + } + } + } + } + + float * buf_iw = (float *) compute_base; + constexpr int kiw = nwarps*rows_per_block + 4; + + if (nwarps > 1) { + __syncthreads(); + } +#pragma unroll + for (int itB = 0; itB < ntB; ++itB) { +#pragma unroll + for (int itA = 0; itA < ntA; ++itA) { +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l); + const int j = itB*tile_C::J + tile_C::get_j(l); + buf_iw[j*kiw + i] = C[itA][itB].x[l]; + } + } + } + + if (nwarps > 1) { + __syncthreads(); + } + +#pragma unroll + for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + if (j0 + nwarps > cols_per_block && j >= cols_per_block) { + return; + } + + float sum = 0.0f; + static_assert(rows_per_block == warp_size, "need loop/check"); +#pragma unroll + for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) { + const int i = i0 + threadIdx.x; + + sum += buf_iw[j*kiw + i]; + } + + if constexpr (!has_ids) { + dst[j*stride_col_dst + row0 + threadIdx.x] = sum; + } else { + const int slot = (j < cols_per_block) ? slot_map[j] : -1; + if (slot >= 0) { + dst[slot*stride_channel_dst + j*stride_col_dst + row0 + threadIdx.x] = sum; + } + } + } +#else + GGML_UNUSED_VARS(x, y, ids, dst, + ncols, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, + channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + NO_DEVICE_CODE; +#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) +} + +template +static inline void mul_mat_f_switch_ids( + const T * x, const float * y, const int32_t * ids, float * dst, + const int64_t ncols_x, const int64_t nchannels_dst, + const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, + const int64_t stride_col_id, const int64_t stride_row_id, + const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, + const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, + const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream) { + if (ids) { + mul_mat_f<<>> + (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + } else { + mul_mat_f<<>> + (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); + } +} + +template +void mul_mat_f_cuda( + const T * x, const float * y, const int32_t * ids, float * dst, + const int64_t ncols_x, const int64_t nrows_x, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, + const int64_t stride_col_id, const int64_t stride_row_id, + const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, + const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, + const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, + cudaStream_t stream) { + typedef tile<16, 8, T> tile_A; + typedef tile< 8, 8, T> tile_B; + + GGML_ASSERT(ncols_x % 2 == 0); + GGML_ASSERT(stride_row % 2 == 0); + GGML_ASSERT(stride_col_y % 2 == 0); + GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0); + GGML_ASSERT( nsamples_dst % nsamples_x == 0); + const int64_t channel_ratio = nchannels_dst / nchannels_x; + const int64_t sample_ratio = nsamples_dst / nsamples_x; + + const int device = ggml_cuda_get_device(); + const int warp_size = ggml_cuda_info().devices[device].warp_size; + + int64_t nwarps_best = 1; + int64_t niter_best = (ncols_x + warp_size*2 - 1) / (warp_size*2); + int64_t max_block_size = 256; + for (int64_t nwarps = 2; nwarps <= max_block_size/warp_size; nwarps++) { + const int64_t niter = (ncols_x + nwarps*warp_size*2 - 1) / (nwarps*warp_size*2); + if (niter < niter_best) { + niter_best = niter; + nwarps_best = nwarps; + } + } + + constexpr int rows_per_block = MMF_ROWS_PER_BLOCK; + const int nbytes_shared_iter = nwarps_best * tile_A::I * (warp_size + 4) * 4; + const int nbytes_shared_combine = GGML_PAD(cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4) * 4; + const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine); + const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0; + const int nbytes_shared_total = nbytes_shared + nbytes_slotmap; + const int64_t grid_y = ids ? nchannels_x : nchannels_dst; // per expert when ids present + + const dim3 block_nums(nrows_x/rows_per_block, grid_y, nsamples_dst); + const dim3 block_dims(warp_size, nwarps_best, 1); + + switch (nwarps_best) { + case 1: { + mul_mat_f_switch_ids( + x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + } break; + case 2: { + mul_mat_f_switch_ids( + x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + } break; + case 3: { + mul_mat_f_switch_ids( + x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + } break; + case 4: { + mul_mat_f_switch_ids( + x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + } break; + case 5: { + mul_mat_f_switch_ids( + x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + } break; + case 6: { + mul_mat_f_switch_ids( + x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + } break; + case 7: { + mul_mat_f_switch_ids( + x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + } break; + case 8: { + mul_mat_f_switch_ids( + x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, + sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream); + } break; + default: { + GGML_ABORT("fatal error"); + } break; + } + + GGML_UNUSED_VARS(nchannels_y); +} + +template +static void mul_mat_f_switch_cols_per_block( + const T * x, const float * y, const int32_t * ids, float * dst, + const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst, + const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, + const int64_t stride_col_id, const int stride_row_id, + const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, + const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, + const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, + cudaStream_t stream) { + switch (ncols_dst) { + case 1: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 2: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 3: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 4: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 5: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 6: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 7: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 8: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 9: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 10: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 11: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 12: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 13: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 14: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 15: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + case 16: { + mul_mat_f_cuda(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst, + stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, + nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); + } break; + default: { + GGML_ABORT("fatal error"); + } break; + } +} + +#define DECL_MMF_CASE_HELPER(T, ncols_dst) \ + template void mul_mat_f_cuda( \ + const T * x, const float * y, const int32_t * ids, float * dst, \ + const int64_t ncols_x, const int64_t nrows_x, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, \ + const int64_t stride_col_id, const int64_t stride_row_id, \ + const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, \ + const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,\ + const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, \ + cudaStream_t stream); + +#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) +#define DECL_MMF_CASE_EXTERN(ncols_dst) \ + extern DECL_MMF_CASE_HELPER(float, ncols_dst) \ + extern DECL_MMF_CASE_HELPER(half2, ncols_dst) \ + extern DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst) + +#define DECL_MMF_CASE(ncols_dst) \ + DECL_MMF_CASE_HELPER(float, ncols_dst) \ + DECL_MMF_CASE_HELPER(half2, ncols_dst) \ + DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst) + +DECL_MMF_CASE_EXTERN(1); +DECL_MMF_CASE_EXTERN(2); +DECL_MMF_CASE_EXTERN(3); +DECL_MMF_CASE_EXTERN(4); +DECL_MMF_CASE_EXTERN(5); +DECL_MMF_CASE_EXTERN(6); +DECL_MMF_CASE_EXTERN(7); +DECL_MMF_CASE_EXTERN(8); +DECL_MMF_CASE_EXTERN(9); +DECL_MMF_CASE_EXTERN(10); +DECL_MMF_CASE_EXTERN(11); +DECL_MMF_CASE_EXTERN(12); +DECL_MMF_CASE_EXTERN(13); +DECL_MMF_CASE_EXTERN(14); +DECL_MMF_CASE_EXTERN(15); +DECL_MMF_CASE_EXTERN(16); +#else +#define DECL_MMF_CASE(ncols_dst) +#endif diff --git a/src/ggml-cuda/template-instances/generate_cu_files.py b/src/ggml-cuda/template-instances/generate_cu_files.py index a92a9e52..da2d7b7c 100755 --- a/src/ggml-cuda/template-instances/generate_cu_files.py +++ b/src/ggml-cuda/template-instances/generate_cu_files.py @@ -34,6 +34,13 @@ SOURCE_MMQ = """// This file has been autogenerated by generate_cu_files.py, do DECL_MMQ_CASE({type}); """ +SOURCE_MMF = """// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE({type}); +""" + def get_short_name(long_quant_name): return long_quant_name.replace("GGML_TYPE_", "").lower() @@ -76,3 +83,7 @@ for ncols in [8, 16, 32, 64]: for type in TYPES_MMQ: with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f: f.write(SOURCE_MMQ.format(type=type)) + +for type in range(1, 17): + with open(f"mmf-instance-ncols_{type}.cu", "w") as f: + f.write(SOURCE_MMF.format(type=type)) diff --git a/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu b/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu new file mode 100644 index 00000000..f594d5d5 --- /dev/null +++ b/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(1); diff --git a/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu b/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu new file mode 100644 index 00000000..9cc67725 --- /dev/null +++ b/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(10); diff --git a/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu b/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu new file mode 100644 index 00000000..317f487d --- /dev/null +++ b/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(11); diff --git a/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu b/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu new file mode 100644 index 00000000..dc003322 --- /dev/null +++ b/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(12); diff --git a/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu b/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu new file mode 100644 index 00000000..07821017 --- /dev/null +++ b/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(13); diff --git a/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu b/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu new file mode 100644 index 00000000..a23ad6ae --- /dev/null +++ b/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(14); diff --git a/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu b/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu new file mode 100644 index 00000000..0fe3f782 --- /dev/null +++ b/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(15); diff --git a/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu b/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu new file mode 100644 index 00000000..54408637 --- /dev/null +++ b/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(16); diff --git a/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu b/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu new file mode 100644 index 00000000..3b901797 --- /dev/null +++ b/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(2); diff --git a/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu b/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu new file mode 100644 index 00000000..56e940bb --- /dev/null +++ b/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(3); diff --git a/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu b/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu new file mode 100644 index 00000000..a7665d49 --- /dev/null +++ b/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(4); diff --git a/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu b/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu new file mode 100644 index 00000000..3a1dff25 --- /dev/null +++ b/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(5); diff --git a/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu b/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu new file mode 100644 index 00000000..400fb7c6 --- /dev/null +++ b/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(6); diff --git a/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu b/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu new file mode 100644 index 00000000..954a1c7e --- /dev/null +++ b/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(7); diff --git a/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu b/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu new file mode 100644 index 00000000..f1bd09c9 --- /dev/null +++ b/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(8); diff --git a/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu b/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu new file mode 100644 index 00000000..1255ac2a --- /dev/null +++ b/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmf.cuh" + +DECL_MMF_CASE(9); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index cca86a8c..a4776ee7 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -6261,7 +6261,7 @@ static std::vector> make_test_cases_eval() { for (int n_mats : {4, 8}) { for (int n_used : {1, 2, 4}) { for (bool b : {false, true}) { - for (int n : {1, 32, 129}) { + for (int n : {1, 4, 5, 32, 129}) { int m = 512; int k = 256; test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, n_used, b, m, n, k));