CUDA: Add mul_mat_id support for the mmf kernel (llama/15767)
author Aman Gupta <redacted>
Tue, 9 Sep 2025 06:38:02 +0000 (14:38 +0800)
committer Georgi Gerganov <redacted>
Sat, 20 Sep 2025 10:33:50 +0000 (13:33 +0300)
* CUDA: Add mul_mat_id support to the mmf kernel

Add support for mul_mat_id for batch sizes (bs) < 16

* Review: use warp_size, fix should_use_mmf condition

* Launch one block per expert, stride along n_expert_used (sketched after this list)

* Templatize mul_mat_id

* Pad shmem to 16 bytes, add helper function mul_mat_f_switch_ids

* Reduce compile times by dividing mmf into f16, bf16 and f32 variants

* Divide mmf by ncols_dst

* Add missing files

* Fix MUSA/HIP builds

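How the new path works, in brief: with an ids matrix present, the kernel launches one block (in blockIdx.y) per expert. Each block first builds a shared-memory slot map recording, for every output column, which of the n_expert_used slots routes to that expert (or -1), and exits early when none do; the slot map is padded with GGML_PAD so the compute region behind it stays 16-byte aligned. Below is a minimal sketch of just that mapping step, with simplified names (illustrative only; the real kernel is mul_mat_f in mmf.cuh further down):

// Illustrative sketch only -- simplified from mul_mat_f in mmf.cuh.
// One block in y per expert; ids holds n_expert_used entries per output column.
__global__ void slot_map_sketch(const int32_t * ids, int n_expert_used,
                                int cols_per_block,
                                int stride_row_id, int stride_col_id) {
    const int expert_idx = blockIdx.y;  // launch grid: one block.y per expert
    extern __shared__ int slot_map[];   // slot_map[j]: slot of column j, or -1

    for (int j = threadIdx.x; j < cols_per_block; j += blockDim.x) {
        int slot = -1;
        for (int k = 0; k < n_expert_used; ++k) { // stride along n_expert_used
            if (ids[j*stride_row_id + k*stride_col_id] == expert_idx) {
                slot = k;
                break;
            }
        }
        slot_map[j] = slot;
    }
    __syncthreads();
    // The real kernel early-exits here if no column maps to this expert,
    // then gathers y rows through slot_map and scatters dst the same way.
}
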
23 files changed:
src/ggml-cuda/CMakeLists.txt
src/ggml-cuda/ggml-cuda.cu
src/ggml-cuda/mma.cuh
src/ggml-cuda/mmf.cu
src/ggml-cuda/mmf.cuh
src/ggml-cuda/template-instances/generate_cu_files.py
src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu [new file with mode: 0644]
src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu [new file with mode: 0644]
src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu [new file with mode: 0644]
src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu [new file with mode: 0644]
src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu [new file with mode: 0644]
src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu [new file with mode: 0644]
src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu [new file with mode: 0644]
src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu [new file with mode: 0644]
src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu [new file with mode: 0644]
src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu [new file with mode: 0644]
src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu [new file with mode: 0644]
src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu [new file with mode: 0644]
src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu [new file with mode: 0644]
src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu [new file with mode: 0644]
src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu [new file with mode: 0644]
src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu [new file with mode: 0644]
tests/test-backend-ops.cpp

index 90610af530f2213d9c724f1c329d323242ec5ab5..bdcefe7b7ed7a8e8ab4fdd5e85a609a3df7176d6 100644 (file)
@@ -48,6 +48,8 @@ if (CUDAToolkit_FOUND)
     list(APPEND GGML_SOURCES_CUDA ${SRCS})
     file(GLOB   SRCS "template-instances/mmq*.cu")
     list(APPEND GGML_SOURCES_CUDA ${SRCS})
+    file(GLOB   SRCS "template-instances/mmf*.cu")
+    list(APPEND GGML_SOURCES_CUDA ${SRCS})
 
     if (GGML_CUDA_FA_ALL_QUANTS)
         file(GLOB   SRCS "template-instances/fattn-vec*.cu")
index 95170ae11bad559a5d9a3fbccbbbfb0468ac1313..0f68d685348d27b52ebddaae51d58512d5c43424 100644 (file)
@@ -2109,6 +2109,11 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
             ggml_cuda_mul_mat_q(ctx, src0, src1, ids, dst);
             return;
         }
+
+        if (ggml_cuda_should_use_mmf(src0->type, cc, WARP_SIZE, src0->ne, src1->ne[2])) {
+            ggml_cuda_mul_mat_f(ctx, src0, src1, ids, dst);
+            return;
+        }
     }
 
     cudaStream_t stream = ctx.stream();
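
For orientation, the tensor shapes assumed at this MUL_MAT_ID call site (per ggml's convention; the comments below are annotations, not part of the patch):

// src0: [ncols, nrows, n_expert]          -- one weight matrix per expert
// src1: [ncols, n_expert_used, n_tokens]  -- src1->ne[2] is the batch size
// ids : [n_expert_used, n_tokens]         -- I32 routing matrix
// Hence ggml_cuda_should_use_mmf receives src1->ne[2] (tokens) rather than
// src1->ne[1] as the batch-size argument on the mul_mat_id path.
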
index 667deb9c650175e83ad6c27c15f646743a450557..c1f24243fe3883745c6512520ab54fbc9d6cdc55 100644 (file)
@@ -1,3 +1,4 @@
+#pragma once
 // This file contains primitives that expose the tensor core PTX instructions for CUDA code.
 // The primitives can be used in a similar way as the nvcuda::wmma interface but with a well-defined memory layout.
 // The documentation for the PTX instructions can be found under:
index cfa5c5cce2c23dd4b0f3375edccf146955492f7c..16331e9ecfad0d9d89fe8ce0eb1057cdf5d4fdd3 100644 (file)
 #include "ggml.h"
-#include "common.cuh"
-#include "mma.cuh"
 #include "mmf.cuh"
 
-using namespace ggml_cuda_mma;
-
-#define MMF_ROWS_PER_BLOCK 32
-
-template <typename T, int rows_per_block, int cols_per_block, int nwarps>
-__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
-static __global__ void mul_mat_f(
-        const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
-        const int ncols, const int nchannels_y, const int stride_row, const int stride_col_y, const int stride_col_dst,
-        const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
-        const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
-#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
-    typedef tile<16, 8, T>     tile_A;
-    typedef tile< 8, 8, T>     tile_B;
-    typedef tile<16, 8, float> tile_C;
-
-    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
-    constexpr int tile_k_padded = warp_size + 4;
-    constexpr int ntA = rows_per_block / tile_A::I;
-    constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;
-
-    const int row0        = blockIdx.x * rows_per_block;
-    const int channel_dst = blockIdx.y;
-    const int channel_x   = channel_dst / channel_ratio;
-    const int channel_y   = channel_dst;
-    const int sample_dst  = blockIdx.z;
-    const int sample_x    = sample_dst / sample_ratio;
-    const int sample_y    = sample_dst;
-
-    x   += int64_t(sample_x)  *stride_sample_x   + channel_x  *stride_channel_x   + row0*stride_row ;
-    y   += int64_t(sample_y)  *stride_sample_y   + channel_y  *stride_channel_y;
-    dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst;
-
-    const float2 * y2 = (const float2 *) y;
-
-    extern __shared__ char data_mmv[];
-
-    tile_C C[ntA][ntB];
-
-    T * tile_xy = (T *) data_mmv + threadIdx.y*(tile_A::I * tile_k_padded);
-
-    for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) {
-        tile_A A[ntA][warp_size / tile_A::J];
-#pragma unroll
-        for (int itA = 0; itA < ntA; ++itA) {
-#pragma unroll
-            for (int i = 0; i < tile_A::I; ++i) {
-                tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row  + col];
-            }
-#pragma unroll
-            for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) {
-                load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded);
-            }
-        }
-
-#pragma unroll
-        for (int itB = 0; itB < ntB; ++itB) {
-            if constexpr (std::is_same_v<T, float>) {
-#pragma unroll
-                for (int j0 = 0; j0 < tile_B::I; ++j0) {
-                    const int j = j0 + itB*tile_B::I;
-
-                    tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[j*stride_col_y + col] : 0.0f;
-                }
-            } else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
-#pragma unroll
-                for (int j0 = 0; j0 < tile_B::I; ++j0) {
-                    const int j = j0 + itB*tile_B::I;
-
-                    const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f);
-                    tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
-                }
-            } else {
-                static_assert(std::is_same_v<T, void>, "unsupported type");
-            }
-#pragma unroll
-            for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
-                tile_B B;
-                load_ldmatrix(B, tile_xy + k0, tile_k_padded);
-#pragma unroll
-                for (int itA = 0; itA < ntA; ++itA) {
-                    mma(C[itA][itB], A[itA][k0/tile_B::J], B);
-                }
-            }
-        }
-    }
-
-    float * buf_iw = (float *) data_mmv;
-    constexpr int kiw = nwarps*rows_per_block + 4;
-
-    if (nwarps > 1) {
-        __syncthreads();
-    }
-#pragma unroll
-    for (int itB = 0; itB < ntB; ++itB) {
-#pragma unroll
-        for (int itA = 0; itA < ntA; ++itA) {
-#pragma unroll
-            for (int l = 0; l < tile_C::ne; ++l) {
-                const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l);
-                const int j = itB*tile_C::J + tile_C::get_j(l);
-                buf_iw[j*kiw + i] = C[itA][itB].x[l];
-            }
-        }
-    }
-
-    if (nwarps > 1) {
-        __syncthreads();
-    }
-
-#pragma unroll
-    for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
-        const int j = j0 + threadIdx.y;
-
-        if (j0 + nwarps > cols_per_block && j >= cols_per_block) {
-            return;
-        }
-
-        float sum = 0.0f;
-        static_assert(rows_per_block == warp_size, "need loop/check");
-#pragma unroll
-        for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
-            const int i = i0 + threadIdx.x;
-
-            sum += buf_iw[j*kiw + i];
-        }
-        dst[j*stride_col_dst + row0 + threadIdx.x] = sum;
-    }
-#else
-    GGML_UNUSED_VARS(x, y, ids, dst,
-        ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst,
-        channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-        sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
-    NO_DEVICE_CODE;
-#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
-}
-
-template <typename T, int cols_per_block>
-static void mul_mat_f_cuda(
-        const T * x, const float * y, const int32_t * ids, float * dst,
-        const int64_t ncols_x, const int64_t nrows_x,
-        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
-        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
-        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
-        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
-        cudaStream_t stream) {
-    typedef tile<16, 8, T>     tile_A;
-    typedef tile< 8, 8, T>     tile_B;
-
-    GGML_ASSERT(!ids && "mul_mat_id not implemented");
-
-    GGML_ASSERT(ncols_x      % 2 == 0);
-    GGML_ASSERT(stride_row   % 2 == 0);
-    GGML_ASSERT(stride_col_y % 2 == 0);
-    GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
-    GGML_ASSERT(       nsamples_dst  % nsamples_x  == 0);
-    const int64_t channel_ratio = nchannels_dst / nchannels_x;
-    const int64_t sample_ratio  = nsamples_dst  / nsamples_x;
-
-    const int device = ggml_cuda_get_device();
-    const int warp_size = ggml_cuda_info().devices[device].warp_size;
-
-    int64_t nwarps_best     = 1;
-    int64_t niter_best      = (ncols_x + warp_size*2 - 1) / (warp_size*2);
-    int64_t max_block_size  = 256;
-    for (int64_t nwarps = 2; nwarps <= max_block_size/warp_size; nwarps++) {
-        const int64_t niter = (ncols_x + nwarps*warp_size*2 - 1) / (nwarps*warp_size*2);
-        if (niter < niter_best) {
-            niter_best  = niter;
-            nwarps_best = nwarps;
-        }
-    }
-
-    constexpr int rows_per_block = MMF_ROWS_PER_BLOCK;
-    const int nbytes_shared_iter = nwarps_best * tile_A::I * (warp_size + 4) * 4;
-    const int nbytes_shared_combine = GGML_PAD(cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4) * 4;
-    const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
-    const dim3 block_nums(nrows_x/rows_per_block, nchannels_dst, nsamples_dst);
-    const dim3 block_dims(warp_size, nwarps_best, 1);
-    switch (nwarps_best) {
-        case 1: {
-            mul_mat_f<T, rows_per_block, cols_per_block, 1><<<block_nums, block_dims, nbytes_shared, stream>>>
-                (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
-                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
-        } break;
-        case 2: {
-            mul_mat_f<T, rows_per_block, cols_per_block, 2><<<block_nums, block_dims, nbytes_shared, stream>>>
-                (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
-                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
-        } break;
-        case 3: {
-            mul_mat_f<T, rows_per_block, cols_per_block, 3><<<block_nums, block_dims, nbytes_shared, stream>>>
-                (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
-                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
-        } break;
-        case 4: {
-            mul_mat_f<T, rows_per_block, cols_per_block, 4><<<block_nums, block_dims, nbytes_shared, stream>>>
-                (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
-                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
-        } break;
-        case 5: {
-            mul_mat_f<T, rows_per_block, cols_per_block, 5><<<block_nums, block_dims, nbytes_shared, stream>>>
-                (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
-                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
-        } break;
-        case 6: {
-            mul_mat_f<T, rows_per_block, cols_per_block, 6><<<block_nums, block_dims, nbytes_shared, stream>>>
-                (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
-                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
-        } break;
-        case 7: {
-            mul_mat_f<T, rows_per_block, cols_per_block, 7><<<block_nums, block_dims, nbytes_shared, stream>>>
-                (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
-                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
-        } break;
-        case 8: {
-            mul_mat_f<T, rows_per_block, cols_per_block, 8><<<block_nums, block_dims, nbytes_shared, stream>>>
-                (x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
-                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
-        } break;
-        default: {
-            GGML_ABORT("fatal error");
-        } break;
-    }
-}
-
-template <typename T>
-static void mul_mat_f_switch_cols_per_block(
-        const T * x, const float * y, const int32_t * ids, float * dst,
-        const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
-        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
-        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
-        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
-        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
-        cudaStream_t stream) {
-    switch (ncols_dst) {
-        case  1: {
-            mul_mat_f_cuda<T,  1>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
-                nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream);
-        } break;
-        case  2: {
-            mul_mat_f_cuda<T,  2>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
-                nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream);
-        } break;
-        case  3: {
-            mul_mat_f_cuda<T,  3>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
-                nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream);
-        } break;
-        case  4: {
-            mul_mat_f_cuda<T,  4>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
-                nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream);
-        } break;
-        case  5: {
-            mul_mat_f_cuda<T,  5>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
-                nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream);
-        } break;
-        case  6: {
-            mul_mat_f_cuda<T,  6>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
-                nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream);
-        } break;
-        case  7: {
-            mul_mat_f_cuda<T,  7>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
-                nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream);
-        } break;
-        case  8: {
-            mul_mat_f_cuda<T,  8>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
-                nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream);
-        } break;
-        case  9: {
-            mul_mat_f_cuda<T,  9>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
-                nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream);
-        } break;
-        case 10: {
-            mul_mat_f_cuda<T, 10>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
-                nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream);
-        } break;
-        case 11: {
-            mul_mat_f_cuda<T, 11>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
-                nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream);
-        } break;
-        case 12: {
-            mul_mat_f_cuda<T, 12>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
-                nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream);
-        } break;
-        case 13: {
-            mul_mat_f_cuda<T, 13>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
-                nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream);
-        } break;
-        case 14: {
-            mul_mat_f_cuda<T, 14>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
-                nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream);
-        } break;
-        case 15: {
-            mul_mat_f_cuda<T, 15>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
-                nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream);
-        } break;
-        case 16: {
-            mul_mat_f_cuda<T, 16>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
-                nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                nsamples_x,                nsamples_dst,  stride_sample_x,  stride_sample_y,  stride_sample_dst,  stream);
-        } break;
-        default: {
-            GGML_ABORT("fatal error");
-        } break;
-    }
-}
-
 void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
     GGML_ASSERT(        src1->type == GGML_TYPE_F32);
     GGML_ASSERT(!ids ||  ids->type == GGML_TYPE_I32);
     GGML_ASSERT(         dst->type == GGML_TYPE_F32);
 
+
     GGML_TENSOR_BINARY_OP_LOCALS;
 
     const size_t ts_src0 = ggml_type_size(src0->type);
@@ -365,55 +34,72 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
     const int64_t s13 = src1->nb[3] / ts_src1;
     const int64_t s3  =  dst->nb[3] / ts_dst;
 
+    const int64_t ids_s0 = ids ? ids->nb[0] / ggml_type_size(ids->type) : 0;
+    const int64_t ids_s1 = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0;
+
     // For MUL_MAT_ID the memory layout is different than for MUL_MAT:
     const int64_t ncols_dst          = ids ? ne2  : ne1;
-    const int64_t nchannels_y        = ids ? ne11 : ne12;
-    const int64_t nchannels_dst      = ids ? ne1  : ne2;
-    const int64_t stride_channel_dst = ids ? s1   : s2;
-    const int64_t stride_channel_y   = ids ? s11  : s12;
+    const int64_t nchannels_dst      = ids ? ne1 : ne2;
+
+    const int64_t stride_col_dst     = ids ? s2   : s1;
+    const int64_t stride_col_y       = ids ? s12  : s11;
+    const int64_t stride_channel_dst = ids ? s1 : s2;
 
-    GGML_ASSERT(!ids || ncols_dst == 1);
+    int64_t stride_channel_y         = ids ? s11  : s12;
+    int64_t nchannels_y              = ids ? ne11 : ne12;
+
+    //mul_mat_id: handle broadcast
+    if (ids && nchannels_y == 1) {
+        stride_channel_y = 0;
+        nchannels_y      = ids->ne[0];
+    }
 
     switch (src0->type) {
         case GGML_TYPE_F32: {
             const float * src0_d = (const float *) src0->data;
             constexpr int vals_per_T = 1;
             mul_mat_f_switch_cols_per_block(
-                src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1,
-                ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
-                ne03,              ne3,           s03/vals_per_T, s13,              s3,                 ctx.stream());
+                src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
+                ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
+                ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
         } break;
         case GGML_TYPE_F16: {
             const half2 * src0_d = (const half2 *) src0->data;
             constexpr int vals_per_T = 2;
             mul_mat_f_switch_cols_per_block(
-                src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1,
-                ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
-                ne03,              ne3,           s03/vals_per_T, s13,              s3,                 ctx.stream());
+                src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
+                ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
+                ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
         } break;
         case GGML_TYPE_BF16: {
             const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data;
             constexpr int vals_per_T = 2;
             mul_mat_f_switch_cols_per_block(
-                src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1,
-                ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
-                ne03,              ne3,           s03/vals_per_T, s13,              s3,                 ctx.stream());
+                src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
+                ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
+                ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
         } break;
         default:
             GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
     }
 }
 
-bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, int64_t ne11) {
+bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, const int src1_ncols) {
+
+    if (ggml_is_quantized(type)) {
+        return false;
+    }
+
     if (src0_ne[0] % (warp_size * (4/ggml_type_size(type))) != 0) {
         return false;
     }
     if (src0_ne[1] % MMF_ROWS_PER_BLOCK != 0) {
         return false;
     }
-    if (ne11 > 16) {
+    if (src1_ncols > 16) {
         return false;
     }
+
     switch (type) {
         case GGML_TYPE_F32:
             return ampere_mma_available(cc);
index 785f9f211c32a52867732b8f821ad99175023347..bf724bc57b8a027b01f8c5d5f993c9966fea52eb 100644 (file)
@@ -1,5 +1,473 @@
+#pragma once
+
+#include "mma.cuh"
 #include "common.cuh"
 
+using namespace ggml_cuda_mma;
+
+#define MMF_ROWS_PER_BLOCK 32
+
 void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
 
-bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, int64_t ne11);
+bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const int src1_ncols);
+
+template <typename T, int rows_per_block, int cols_per_block, int nwarps, bool has_ids>
+__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
+static __global__ void mul_mat_f(
+        const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
+        const int ncols, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst,
+        const int stride_col_id, const int stride_row_id,
+        const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
+        const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+    typedef tile<16, 8, T>     tile_A;
+    typedef tile< 8, 8, T>     tile_B;
+    typedef tile<16, 8, float> tile_C;
+
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+    constexpr int tile_k_padded = warp_size + 4;
+    constexpr int ntA = rows_per_block / tile_A::I;
+    constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;
+
+    const int row0        = blockIdx.x * rows_per_block;
+
+    const int expert_idx  = has_ids ? blockIdx.y : 0;
+    const int channel_dst = has_ids ? 0 : blockIdx.y;
+
+    const int channel_x   = has_ids ? expert_idx : (channel_dst / channel_ratio);
+    const int channel_y   = channel_dst;
+    const int sample_dst  = blockIdx.z;
+    const int sample_x    = sample_dst / sample_ratio;
+    const int sample_y    = sample_dst;
+
+    x   += int64_t(sample_x)  *stride_sample_x   + channel_x  *stride_channel_x  + row0*stride_row ;
+    y   += int64_t(sample_y)  *stride_sample_y   + (has_ids ? 0 : channel_y  *stride_channel_y);
+    dst += int64_t(sample_dst)*stride_sample_dst + (has_ids ? 0 : channel_dst*stride_channel_dst);
+
+    const float2 * y2 = (const float2 *) y;
+
+    extern __shared__ char data_mmv[];
+
+    char * shmem_base = data_mmv;
+    int  * slot_map   = (int *) shmem_base;
+    char * compute_base = has_ids ? (shmem_base + GGML_PAD(cols_per_block, 16) * sizeof(int)) : shmem_base;
+
+    tile_C C[ntA][ntB];
+
+    T * tile_xy = (T *) compute_base + threadIdx.y*(tile_A::I * tile_k_padded);
+
+    if constexpr (has_ids) {
+        __shared__ int has_any;
+        if (threadIdx.y == 0) {
+            int local_has_any = 0;
+            for (int j = threadIdx.x; j < cols_per_block; j += warp_size) {
+                int slot = -1;
+                for (int k = 0; k < nchannels_dst; ++k) {
+                    const int idv = ids[j*stride_row_id + k*stride_col_id];
+                    if (idv == expert_idx) {
+                        slot = k;
+                        break;
+                    }
+                }
+                if (j < cols_per_block) {
+                    local_has_any |= (slot >= 0);
+                    slot_map[j] = slot;
+                }
+            }
+            has_any = warp_reduce_any(local_has_any);
+        }
+        __syncthreads();
+        if (has_any == 0) {
+            return;
+        }
+    }
+
+    for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) {
+        tile_A A[ntA][warp_size / tile_A::J];
+#pragma unroll
+        for (int itA = 0; itA < ntA; ++itA) {
+#pragma unroll
+            for (int i = 0; i < tile_A::I; ++i) {
+                tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row  + col];
+            }
+#pragma unroll
+            for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) {
+                load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded);
+            }
+        }
+
+#pragma unroll
+        for (int itB = 0; itB < ntB; ++itB) {
+            if constexpr (std::is_same_v<T, float>) {
+#pragma unroll
+                for (int j0 = 0; j0 < tile_B::I; ++j0) {
+                    const int j = j0 + itB*tile_B::I;
+
+                    if constexpr (!has_ids) {
+                        tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[j*stride_col_y + col] : 0.0f;
+                    } else {
+                        float val = 0.0f;
+                        if (j < cols_per_block) {
+                            const int slot = slot_map[j];
+                            if (slot >= 0) {
+                                val = y[slot*stride_channel_y + j*stride_col_y + col];
+                            }
+                        }
+                        tile_xy[j0*tile_k_padded + threadIdx.x] = val;
+                    }
+                }
+            } else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
+#pragma unroll
+                for (int j0 = 0; j0 < tile_B::I; ++j0) {
+                    const int j = j0 + itB*tile_B::I;
+
+                    if constexpr (!has_ids) {
+                        const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f);
+                        tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
+                    } else {
+                        float2 tmp = make_float2(0.0f, 0.0f);
+                        if (j < cols_per_block) {
+                            const int slot = slot_map[j];
+                            if (slot >= 0) {
+                                const float2 * y2_slot = (const float2 *)(y + slot*stride_channel_y);
+                                tmp = y2_slot[j*stride_col_y + col];
+                            }
+                        }
+                        tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
+                    }
+                }
+            } else {
+                static_assert(std::is_same_v<T, void>, "unsupported type");
+            }
+#pragma unroll
+            for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
+                tile_B B;
+                load_ldmatrix(B, tile_xy + k0, tile_k_padded);
+#pragma unroll
+                for (int itA = 0; itA < ntA; ++itA) {
+                    mma(C[itA][itB], A[itA][k0/tile_B::J], B);
+                }
+            }
+        }
+    }
+
+    float * buf_iw = (float *) compute_base;
+    constexpr int kiw = nwarps*rows_per_block + 4;
+
+    if (nwarps > 1) {
+        __syncthreads();
+    }
+#pragma unroll
+    for (int itB = 0; itB < ntB; ++itB) {
+#pragma unroll
+        for (int itA = 0; itA < ntA; ++itA) {
+#pragma unroll
+            for (int l = 0; l < tile_C::ne; ++l) {
+                const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l);
+                const int j = itB*tile_C::J + tile_C::get_j(l);
+                buf_iw[j*kiw + i] = C[itA][itB].x[l];
+            }
+        }
+    }
+
+    if (nwarps > 1) {
+        __syncthreads();
+    }
+
+#pragma unroll
+    for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
+        const int j = j0 + threadIdx.y;
+
+        if (j0 + nwarps > cols_per_block && j >= cols_per_block) {
+            return;
+        }
+
+        float sum = 0.0f;
+        static_assert(rows_per_block == warp_size, "need loop/check");
+#pragma unroll
+        for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
+            const int i = i0 + threadIdx.x;
+
+            sum += buf_iw[j*kiw + i];
+        }
+
+        if constexpr (!has_ids) {
+            dst[j*stride_col_dst + row0 + threadIdx.x] = sum;
+        } else {
+            const int slot = (j < cols_per_block) ? slot_map[j] : -1;
+            if (slot >= 0) {
+                dst[slot*stride_channel_dst + j*stride_col_dst + row0 + threadIdx.x] = sum;
+            }
+        }
+    }
+#else
+    GGML_UNUSED_VARS(x, y, ids, dst,
+        ncols, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+        stride_col_id, stride_row_id,
+        channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+        sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+    NO_DEVICE_CODE;
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+}
+
+template<typename T, int cols_per_block, int nwarps>
+static inline void mul_mat_f_switch_ids(
+        const T * x, const float * y, const int32_t * ids, float * dst,
+        const int64_t ncols_x, const int64_t nchannels_dst,
+        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
+        const int64_t stride_col_id, const int64_t stride_row_id,
+        const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
+        const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
+        const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream) {
+    if (ids) {
+        mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true><<<block_nums, block_dims, nbytes_shared_total, stream>>>
+            (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+             stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+             sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+    } else {
+        mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, false><<<block_nums, block_dims, nbytes_shared_total, stream>>>
+            (x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+             stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+             sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
+    }
+}
+
+template <typename T, int cols_per_block>
+void mul_mat_f_cuda(
+        const T * x, const float * y, const int32_t * ids, float * dst,
+        const int64_t ncols_x, const int64_t nrows_x, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
+        const int64_t stride_col_id, const int64_t stride_row_id,
+        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
+        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
+        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
+        cudaStream_t stream) {
+    typedef tile<16, 8, T>     tile_A;
+    typedef tile< 8, 8, T>     tile_B;
+
+    GGML_ASSERT(ncols_x      % 2 == 0);
+    GGML_ASSERT(stride_row   % 2 == 0);
+    GGML_ASSERT(stride_col_y % 2 == 0);
+    GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
+    GGML_ASSERT(       nsamples_dst  % nsamples_x  == 0);
+    const int64_t channel_ratio = nchannels_dst / nchannels_x;
+    const int64_t sample_ratio  = nsamples_dst  / nsamples_x;
+
+    const int device = ggml_cuda_get_device();
+    const int warp_size = ggml_cuda_info().devices[device].warp_size;
+
+    int64_t nwarps_best     = 1;
+    int64_t niter_best      = (ncols_x + warp_size*2 - 1) / (warp_size*2);
+    int64_t max_block_size  = 256;
+    for (int64_t nwarps = 2; nwarps <= max_block_size/warp_size; nwarps++) {
+        const int64_t niter = (ncols_x + nwarps*warp_size*2 - 1) / (nwarps*warp_size*2);
+        if (niter < niter_best) {
+            niter_best  = niter;
+            nwarps_best = nwarps;
+        }
+    }
+
+    constexpr int rows_per_block = MMF_ROWS_PER_BLOCK;
+    const int nbytes_shared_iter = nwarps_best * tile_A::I * (warp_size + 4) * 4;
+    const int nbytes_shared_combine = GGML_PAD(cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4) * 4;
+    const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
+    const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0;
+    const int nbytes_shared_total = nbytes_shared + nbytes_slotmap;
+    const int64_t grid_y = ids ? nchannels_x : nchannels_dst; // per expert when ids present
+
+    const dim3 block_nums(nrows_x/rows_per_block, grid_y, nsamples_dst);
+    const dim3 block_dims(warp_size, nwarps_best, 1);
+
+    switch (nwarps_best) {
+        case 1: {
+            mul_mat_f_switch_ids<T, cols_per_block, 1>(
+                x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
+        } break;
+        case 2: {
+            mul_mat_f_switch_ids<T, cols_per_block, 2>(
+                x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
+        } break;
+        case 3: {
+            mul_mat_f_switch_ids<T, cols_per_block, 3>(
+                x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
+        } break;
+        case 4: {
+            mul_mat_f_switch_ids<T, cols_per_block, 4>(
+                x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
+        } break;
+        case 5: {
+            mul_mat_f_switch_ids<T, cols_per_block, 5>(
+                x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
+        } break;
+        case 6: {
+            mul_mat_f_switch_ids<T, cols_per_block, 6>(
+                x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
+        } break;
+        case 7: {
+            mul_mat_f_switch_ids<T, cols_per_block, 7>(
+                x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
+        } break;
+        case 8: {
+            mul_mat_f_switch_ids<T, cols_per_block, 8>(
+                x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+                stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
+        } break;
+        default: {
+            GGML_ABORT("fatal error");
+        } break;
+    }
+
+    GGML_UNUSED_VARS(nchannels_y);
+}
+
+template <typename T>
+static void mul_mat_f_switch_cols_per_block(
+        const T * x, const float * y, const int32_t * ids, float * dst,
+        const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
+        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
+        const int64_t stride_col_id, const int stride_row_id,
+        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
+        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
+        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
+        cudaStream_t stream) {
+    switch (ncols_dst) {
+        case  1: {
+            mul_mat_f_cuda<T,  1>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+        } break;
+        case  2: {
+            mul_mat_f_cuda<T,  2>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+        } break;
+        case  3: {
+            mul_mat_f_cuda<T,  3>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+        } break;
+        case  4: {
+            mul_mat_f_cuda<T,  4>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+        } break;
+        case  5: {
+            mul_mat_f_cuda<T,  5>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y,  stride_sample_dst, stream);
+        } break;
+        case  6: {
+            mul_mat_f_cuda<T,  6>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+        } break;
+        case  7: {
+            mul_mat_f_cuda<T,  7>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+        } break;
+        case  8: {
+            mul_mat_f_cuda<T,  8>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+        } break;
+        case  9: {
+            mul_mat_f_cuda<T,  9>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+        } break;
+        case 10: {
+            mul_mat_f_cuda<T, 10>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+        } break;
+        case 11: {
+            mul_mat_f_cuda<T, 11>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+        } break;
+        case 12: {
+            mul_mat_f_cuda<T, 12>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+        } break;
+        case 13: {
+            mul_mat_f_cuda<T, 13>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+        } break;
+        case 14: {
+            mul_mat_f_cuda<T, 14>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+        } break;
+        case 15: {
+            mul_mat_f_cuda<T, 15>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+        } break;
+        case 16: {
+            mul_mat_f_cuda<T, 16>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
+                stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+        } break;
+        default: {
+            GGML_ABORT("fatal error");
+        } break;
+    }
+}
+
+#define DECL_MMF_CASE_HELPER(T, ncols_dst) \
+    template void mul_mat_f_cuda<T, ncols_dst>( \
+        const T * x, const float * y, const int32_t * ids, float * dst, \
+        const int64_t ncols_x, const int64_t nrows_x, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, \
+        const int64_t stride_col_id, const int64_t stride_row_id, \
+        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, \
+        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,\
+        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, \
+        cudaStream_t stream);
+
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+#define DECL_MMF_CASE_EXTERN(ncols_dst) \
+    extern DECL_MMF_CASE_HELPER(float, ncols_dst) \
+    extern DECL_MMF_CASE_HELPER(half2, ncols_dst) \
+    extern DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst)
+
+#define DECL_MMF_CASE(ncols_dst) \
+    DECL_MMF_CASE_HELPER(float, ncols_dst) \
+    DECL_MMF_CASE_HELPER(half2, ncols_dst) \
+    DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst)
+
+DECL_MMF_CASE_EXTERN(1);
+DECL_MMF_CASE_EXTERN(2);
+DECL_MMF_CASE_EXTERN(3);
+DECL_MMF_CASE_EXTERN(4);
+DECL_MMF_CASE_EXTERN(5);
+DECL_MMF_CASE_EXTERN(6);
+DECL_MMF_CASE_EXTERN(7);
+DECL_MMF_CASE_EXTERN(8);
+DECL_MMF_CASE_EXTERN(9);
+DECL_MMF_CASE_EXTERN(10);
+DECL_MMF_CASE_EXTERN(11);
+DECL_MMF_CASE_EXTERN(12);
+DECL_MMF_CASE_EXTERN(13);
+DECL_MMF_CASE_EXTERN(14);
+DECL_MMF_CASE_EXTERN(15);
+DECL_MMF_CASE_EXTERN(16);
+#else
+#define DECL_MMF_CASE(ncols_dst)
+#endif
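
The DECL_MMF_CASE machinery above is the standard extern-template split for cutting compile times: the template definition stays visible in the header, every ncols_dst instantiation is declared extern so ordinary includers skip compiling it, and one autogenerated .cu per value provides the single explicit instantiation. A minimal sketch of the pattern, with hypothetical names:

// sketch.cuh -- definition visible everywhere, instantiation suppressed:
template <typename T, int ncols>
void heavy_matmul(const T * x, float * dst) { /* expensive to compile */ }

extern template void heavy_matmul<float, 1>(const float *, float *);

// sketch-instance-ncols_1.cu -- the one TU that actually compiles ncols == 1:
// #include "sketch.cuh"
template void heavy_matmul<float, 1>(const float *, float *);
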
index a92a9e52cf369afbcdfdebf055640dbbee9c5e59..da2d7b7c3b38f7538b21eedfa1bc0e1d9f4614aa 100755 (executable)
@@ -34,6 +34,13 @@ SOURCE_MMQ = """// This file has been autogenerated by generate_cu_files.py, do
 DECL_MMQ_CASE({type});
 """
 
+SOURCE_MMF = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmf.cuh"
+
+DECL_MMF_CASE({type});
+"""
+
 
 def get_short_name(long_quant_name):
     return long_quant_name.replace("GGML_TYPE_", "").lower()
@@ -76,3 +83,7 @@ for ncols in [8, 16, 32, 64]:
 for type in TYPES_MMQ:
     with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f:
         f.write(SOURCE_MMQ.format(type=type))
+
+for type in range(1, 17):
+    with open(f"mmf-instance-ncols_{type}.cu", "w") as f:
+        f.write(SOURCE_MMF.format(type=type))
diff --git a/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu b/src/ggml-cuda/template-instances/mmf-instance-ncols_1.cu
new file mode 100644 (file)
index 0000000..f594d5d
--- /dev/null
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmf.cuh"
+
+DECL_MMF_CASE(1);
diff --git a/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu b/src/ggml-cuda/template-instances/mmf-instance-ncols_10.cu
new file mode 100644 (file)
index 0000000..9cc6772
--- /dev/null
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmf.cuh"
+
+DECL_MMF_CASE(10);
diff --git a/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu b/src/ggml-cuda/template-instances/mmf-instance-ncols_11.cu
new file mode 100644 (file)
index 0000000..317f487
--- /dev/null
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmf.cuh"
+
+DECL_MMF_CASE(11);
diff --git a/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu b/src/ggml-cuda/template-instances/mmf-instance-ncols_12.cu
new file mode 100644 (file)
index 0000000..dc00332
--- /dev/null
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmf.cuh"
+
+DECL_MMF_CASE(12);
diff --git a/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu b/src/ggml-cuda/template-instances/mmf-instance-ncols_13.cu
new file mode 100644 (file)
index 0000000..0782101
--- /dev/null
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmf.cuh"
+
+DECL_MMF_CASE(13);
diff --git a/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu b/src/ggml-cuda/template-instances/mmf-instance-ncols_14.cu
new file mode 100644 (file)
index 0000000..a23ad6a
--- /dev/null
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmf.cuh"
+
+DECL_MMF_CASE(14);
diff --git a/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu b/src/ggml-cuda/template-instances/mmf-instance-ncols_15.cu
new file mode 100644 (file)
index 0000000..0fe3f78
--- /dev/null
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmf.cuh"
+
+DECL_MMF_CASE(15);
diff --git a/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu b/src/ggml-cuda/template-instances/mmf-instance-ncols_16.cu
new file mode 100644 (file)
index 0000000..5440863
--- /dev/null
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmf.cuh"
+
+DECL_MMF_CASE(16);
diff --git a/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu b/src/ggml-cuda/template-instances/mmf-instance-ncols_2.cu
new file mode 100644 (file)
index 0000000..3b90179
--- /dev/null
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmf.cuh"
+
+DECL_MMF_CASE(2);
diff --git a/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu b/src/ggml-cuda/template-instances/mmf-instance-ncols_3.cu
new file mode 100644 (file)
index 0000000..56e940b
--- /dev/null
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmf.cuh"
+
+DECL_MMF_CASE(3);
diff --git a/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu b/src/ggml-cuda/template-instances/mmf-instance-ncols_4.cu
new file mode 100644 (file)
index 0000000..a7665d4
--- /dev/null
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmf.cuh"
+
+DECL_MMF_CASE(4);
diff --git a/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu b/src/ggml-cuda/template-instances/mmf-instance-ncols_5.cu
new file mode 100644 (file)
index 0000000..3a1dff2
--- /dev/null
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmf.cuh"
+
+DECL_MMF_CASE(5);
diff --git a/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu b/src/ggml-cuda/template-instances/mmf-instance-ncols_6.cu
new file mode 100644 (file)
index 0000000..400fb7c
--- /dev/null
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmf.cuh"
+
+DECL_MMF_CASE(6);
diff --git a/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu b/src/ggml-cuda/template-instances/mmf-instance-ncols_7.cu
new file mode 100644 (file)
index 0000000..954a1c7
--- /dev/null
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmf.cuh"
+
+DECL_MMF_CASE(7);
diff --git a/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu b/src/ggml-cuda/template-instances/mmf-instance-ncols_8.cu
new file mode 100644 (file)
index 0000000..f1bd09c
--- /dev/null
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmf.cuh"
+
+DECL_MMF_CASE(8);
diff --git a/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu b/src/ggml-cuda/template-instances/mmf-instance-ncols_9.cu
new file mode 100644 (file)
index 0000000..1255ac2
--- /dev/null
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmf.cuh"
+
+DECL_MMF_CASE(9);
index cca86a8ce8842b2adb95033ad2e68501b317c10a..a4776ee7f15dab6d416f81cc5d3f6755159e653a 100644 (file)
@@ -6261,7 +6261,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
             for (int n_mats : {4, 8}) {
                 for (int n_used : {1, 2, 4}) {
                     for (bool b : {false, true}) {
-                        for (int n : {1, 32, 129}) {
+                        for (int n : {1, 4, 5, 32, 129}) {
                             int m = 512;
                             int k = 256;
                             test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, n_used, b, m, n, k));