]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
CUDA: add fp kernel for larger batch size MoE (llama/16512)
authorAman Gupta <redacted>
Tue, 14 Oct 2025 11:15:15 +0000 (19:15 +0800)
committerGeorgi Gerganov <redacted>
Wed, 15 Oct 2025 06:29:17 +0000 (09:29 +0300)
* CUDA: kernel for larger batch sizes for MoE

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* fixup

* tests

* Move mmq_ids_helper to mmid

* cleanup

* Remove redundant checks

ggml/src/ggml-cuda/mmf.cu
ggml/src/ggml-cuda/mmf.cuh
ggml/src/ggml-cuda/mmid.cu [new file with mode: 0644]
ggml/src/ggml-cuda/mmid.cuh [new file with mode: 0644]
ggml/src/ggml-cuda/mmq.cu

index 599e085ee91b7e058ea7184bab002064f38b5a4b..9e2aaf52d6ccefc2ad0122eb6ab097dc191bec3d 100644 (file)
@@ -1,5 +1,7 @@
 #include "ggml.h"
 #include "mmf.cuh"
+#include "mmid.cuh"
+
 
 void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
     GGML_ASSERT(        src1->type == GGML_TYPE_F32);
@@ -37,6 +39,12 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
     const int64_t ids_s0 = ids ? ids->nb[0] / ggml_type_size(ids->type) : 0;
     const int64_t ids_s1 = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0;
 
+    mmf_ids_data ids_info{};
+    mmf_ids_data * ids_info_ptr = nullptr;
+    ggml_cuda_pool_alloc<int32_t> ids_src_compact_dev;
+    ggml_cuda_pool_alloc<int32_t> ids_dst_compact_dev;
+    ggml_cuda_pool_alloc<int32_t> expert_bounds_dev;
+
     // For MUL_MAT_ID the memory layout is different than for MUL_MAT:
     const int64_t ncols_dst          = ids ? ne2  : ne1;
     const int64_t nchannels_dst      = ids ? ne1 : ne2;
@@ -54,6 +62,33 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
         nchannels_y      = ids->ne[0];
     }
 
+    if (ids && ncols_dst > 16) {
+        const int64_t n_expert_used = ids->ne[0];
+        const int64_t n_experts     = ne02;
+        const int64_t n_tokens      = ne12;
+        const int64_t ne_get_rows   = n_tokens * n_expert_used;
+
+        ids_src_compact_dev.alloc(ctx.pool(), ne_get_rows);
+        ids_dst_compact_dev.alloc(ctx.pool(), ne_get_rows);
+        expert_bounds_dev.alloc(ctx.pool(), n_experts + 1);
+
+        const int si1  = static_cast<int>(ids_s1);
+        const int sis1 = static_cast<int>(src1->nb[2] / src1->nb[1]);
+
+        GGML_ASSERT(sis1 > 0);
+
+        ggml_cuda_launch_mm_ids_helper(ids_d, ids_src_compact_dev.get(), ids_dst_compact_dev.get(), expert_bounds_dev.get(),
+            static_cast<int>(n_experts), static_cast<int>(n_tokens), static_cast<int>(n_expert_used), static_cast<int>(ne11), si1, sis1, ctx.stream());
+        CUDA_CHECK(cudaGetLastError());
+
+        ids_info.ids_src_compact   = ids_src_compact_dev.get();
+        ids_info.ids_dst_compact   = ids_dst_compact_dev.get();
+        ids_info.expert_bounds_dev = expert_bounds_dev.get();
+        ids_info.n_experts         = static_cast<int>(n_experts);
+        ids_info.sis1              = sis1;
+        ids_info_ptr = &ids_info;
+    }
+
     switch (src0->type) {
         case GGML_TYPE_F32: {
             const float * src0_d = (const float *) src0->data;
@@ -61,7 +96,7 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
             mul_mat_f_switch_cols_per_block(
                 src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
                 ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
-                ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
+                ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
         } break;
         case GGML_TYPE_F16: {
             const half2 * src0_d = (const half2 *) src0->data;
@@ -69,7 +104,7 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
             mul_mat_f_switch_cols_per_block(
                 src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
                 ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
-                ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
+                ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
         } break;
         case GGML_TYPE_BF16: {
             const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data;
@@ -77,7 +112,7 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
             mul_mat_f_switch_cols_per_block(
                 src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
                 ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
-                ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
+                ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
         } break;
         default:
             GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
@@ -98,10 +133,9 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
     }
 
     if (mul_mat_id) {
-        if (type == GGML_TYPE_F32 && src1_ncols > 32) {
+        if (src0_ne[1] <= 1024 && src1_ncols > 512) {
             return false;
-        }
-        if ((type == GGML_TYPE_F16 || type == GGML_TYPE_BF16) && src1_ncols > 64) {
+        } else if(src0_ne[1] > 1024 && src1_ncols > 128) {
             return false;
         }
     } else {
index a6c3adfcf1704a99538805898f148dff3d11eb27..49d5295be0ea0515d78aeb011129aa8da6648066 100644 (file)
@@ -7,6 +7,14 @@ using namespace ggml_cuda_mma;
 
 #define MMF_ROWS_PER_BLOCK 32
 
+struct mmf_ids_data {
+    const int32_t * ids_src_compact = nullptr;
+    const int32_t * ids_dst_compact = nullptr;
+    const int32_t * expert_bounds_dev = nullptr;
+    int n_experts = 0;
+    int sis1 = 0;
+};
+
 void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
 
 bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const int src1_ncols, bool mul_mat_id);
@@ -224,6 +232,250 @@ static __global__ void mul_mat_f(
 #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
 }
 
+
+//This kernel is for larger batch sizes of mul_mat_id
+template <typename T, int rows_per_block, int cols_per_block, int nwarps>
+__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
+static __global__ void mul_mat_f_ids(
+        const T * __restrict__ x, const float * __restrict__ y,
+        const int32_t * __restrict__ ids_src_compact, const int32_t * __restrict__ ids_dst_compact,
+        const int32_t * __restrict__ expert_bounds, float * __restrict__ dst,
+        const int ncols, const int ncols_dst_total, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst,
+        const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
+        const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
+        const uint3 sis1_fd, const uint3 nch_fd) {
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+    typedef tile<16, 8, T>     tile_A;
+    typedef tile< 8, 8, T>     tile_B;
+    typedef tile<16, 8, float> tile_C;
+
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+    constexpr int tile_k_padded = warp_size + 4;
+    constexpr int ntA = rows_per_block / tile_A::I;
+    constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;
+
+    const int row0        = blockIdx.x * rows_per_block;
+
+    const int expert_idx = blockIdx.y;
+    const int expert_start = expert_bounds[expert_idx];
+    const int expert_end   = expert_bounds[expert_idx + 1];
+    const int ncols_expert = expert_end - expert_start;
+
+    const int tiles_for_expert = (ncols_expert + cols_per_block - 1) / cols_per_block;
+    const int tile_idx = blockIdx.z;
+    if (tile_idx >= tiles_for_expert) {
+        return;
+    }
+
+    const int col_base = tile_idx * cols_per_block;
+
+    GGML_UNUSED(channel_ratio);
+
+    const int channel_x   = expert_idx;
+    const int sample_dst  = 0;
+    const int sample_x    = sample_dst / sample_ratio;
+    const int sample_y    = sample_dst;
+
+    x   += int64_t(sample_x)  *stride_sample_x   + channel_x  *stride_channel_x  + row0*stride_row;
+    y   += int64_t(sample_y)  *stride_sample_y;
+    dst += int64_t(sample_dst)*stride_sample_dst;
+
+    const int32_t * ids_src_expert = ids_src_compact + expert_start;
+    const int32_t * ids_dst_expert = ids_dst_compact + expert_start;
+
+    extern __shared__ char data_mmv[];
+    char * compute_base = data_mmv;
+
+    //const float2 * y2 = (const float2 *) y;
+
+    tile_C C[ntA][ntB];
+
+    T * tile_xy = (T *) compute_base + threadIdx.y*(tile_A::I * tile_k_padded);
+
+    for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) {
+        tile_A A[ntA][warp_size / tile_A::J];
+#pragma unroll
+        for (int itA = 0; itA < ntA; ++itA) {
+#pragma unroll
+            for (int i = 0; i < tile_A::I; ++i) {
+                tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row  + col];
+            }
+#pragma unroll
+            for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) {
+                load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded);
+            }
+        }
+
+        if constexpr (std::is_same_v<T, float>) {
+            float vals_buf[2][tile_B::I];
+            auto gather_tile = [&](int tile_idx_local, float *vals) {
+#pragma unroll
+                for (int j0 = 0; j0 < tile_B::I; ++j0) {
+                    const int j = j0 + tile_idx_local*tile_B::I;
+                    const int global_j = col_base + j;
+                    float val = 0.0f;
+                    if (j < cols_per_block && global_j < ncols_expert) {
+                        const int src_entry = ids_src_expert[global_j];
+                        const uint2 qrm = fast_div_modulo((uint32_t) src_entry, sis1_fd);
+                        const int token   = (int) qrm.x;
+                        const int channel = (int) qrm.y;
+                        if (token < ncols_dst_total) {
+                            val = y[channel*stride_channel_y + token*stride_col_y + col];
+                        }
+                    }
+                    vals[j0] = val;
+                }
+            };
+
+            gather_tile(0, vals_buf[0]);
+
+            int curr_buf = 0;
+            int next_buf = 1;
+#pragma unroll
+            for (int itB = 0; itB < ntB; ++itB) {
+#pragma unroll
+                for (int j0 = 0; j0 < tile_B::I; ++j0) {
+                    tile_xy[j0*tile_k_padded + threadIdx.x] = vals_buf[curr_buf][j0];
+                }
+
+                if (itB + 1 < ntB) {
+                    gather_tile(itB + 1, vals_buf[next_buf]);
+                }
+
+#pragma unroll
+                for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
+                    tile_B B;
+                    load_ldmatrix(B, tile_xy + k0, tile_k_padded);
+#pragma unroll
+                    for (int itA = 0; itA < ntA; ++itA) {
+                        mma(C[itA][itB], A[itA][k0/tile_B::J], B);
+                    }
+                }
+
+                if (itB + 1 < ntB) {
+                    curr_buf ^= 1;
+                    next_buf ^= 1;
+                }
+            }
+        } else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
+            float2 vals_buf[2][tile_B::I];
+            auto gather_tile = [&](int tile_idx_local, float2 *vals) {
+#pragma unroll
+                for (int j0 = 0; j0 < tile_B::I; ++j0) {
+                    const int j = j0 + tile_idx_local*tile_B::I;
+                    const int global_j = col_base + j;
+                    float2 tmp = make_float2(0.0f, 0.0f);
+                    if (j < cols_per_block && global_j < ncols_expert) {
+                        const int src_entry = ids_src_expert[global_j];
+                        const uint2 qrm = fast_div_modulo((uint32_t) src_entry, sis1_fd);
+                        const int token   = (int) qrm.x;
+                        const int channel = (int) qrm.y;
+                        if (token < ncols_dst_total) {
+                            tmp = *(const float2*) &y[channel*stride_channel_y + 2*(token*stride_col_y + col)];
+                        }
+                    }
+                    vals[j0] = tmp;
+                }
+            };
+
+            if (ntB > 0) {
+                gather_tile(0, vals_buf[0]);
+            }
+
+            int curr_buf = 0;
+            int next_buf = 1;
+#pragma unroll
+            for (int itB = 0; itB < ntB; ++itB) {
+#pragma unroll
+                for (int j0 = 0; j0 < tile_B::I; ++j0) {
+                    const float2 tmp = vals_buf[curr_buf][j0];
+                    tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
+                }
+
+                if (itB + 1 < ntB) {
+                    gather_tile(itB + 1, vals_buf[next_buf]);
+                }
+
+#pragma unroll
+                for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
+                    tile_B B;
+                    load_ldmatrix(B, tile_xy + k0, tile_k_padded);
+#pragma unroll
+                    for (int itA = 0; itA < ntA; ++itA) {
+                        mma(C[itA][itB], A[itA][k0/tile_B::J], B);
+                    }
+                }
+
+                if (itB + 1 < ntB) {
+                    curr_buf ^= 1;
+                    next_buf ^= 1;
+                }
+            }
+        } else {
+            static_assert(std::is_same_v<T, void>, "unsupported type");
+        }
+    }
+
+    float * buf_iw = (float *) compute_base;
+    constexpr int kiw = nwarps*rows_per_block + 4;
+
+    if (nwarps > 1) {
+        __syncthreads();
+    }
+#pragma unroll
+    for (int itB = 0; itB < ntB; ++itB) {
+#pragma unroll
+        for (int itA = 0; itA < ntA; ++itA) {
+#pragma unroll
+            for (int l = 0; l < tile_C::ne; ++l) {
+                const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l);
+                const int j = itB*tile_C::J + tile_C::get_j(l);
+                buf_iw[j*kiw + i] = C[itA][itB].x[l];
+            }
+        }
+    }
+
+    if (nwarps > 1) {
+        __syncthreads();
+    }
+
+#pragma unroll
+    for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
+        const int j = j0 + threadIdx.y;
+
+        if (j0 + nwarps > cols_per_block && j >= cols_per_block) {
+            return;
+        }
+
+        float sum = 0.0f;
+        static_assert(rows_per_block == warp_size, "need loop/check");
+#pragma unroll
+        for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
+            const int i = i0 + threadIdx.x;
+
+            sum += buf_iw[j*kiw + i];
+        }
+
+        const int global_j = col_base + j;
+        if (j < cols_per_block && global_j < ncols_expert && nchannels_dst > 0) {
+            const int dst_entry = ids_dst_expert[global_j];
+            const uint2 qrm = fast_div_modulo((uint32_t) dst_entry, nch_fd);
+            const int token = (int) qrm.x;
+            if (token < ncols_dst_total) {
+                const int slot = (int) qrm.y;
+                dst[slot*stride_channel_dst + token*stride_col_dst + row0 + threadIdx.x] = sum;
+            }
+        }
+    }
+#else
+    GGML_UNUSED_VARS(x, y, ids_src_compact, ids_dst_compact, expert_bounds, dst,
+        ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+        channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+        sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd);
+    NO_DEVICE_CODE;
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+}
+
 template<typename T, int cols_per_block, int nwarps>
 static inline void mul_mat_f_switch_ids(
         const T * x, const float * y, const int32_t * ids, float * dst,
@@ -232,13 +484,35 @@ static inline void mul_mat_f_switch_ids(
         const int64_t stride_col_id, const int64_t stride_row_id,
         const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
         const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
-        const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream) {
-    if (ids) {
+        const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream,
+        const mmf_ids_data * ids_data) {
+    const bool has_ids_data = ids_data && ids_data->ids_src_compact;
+
+    // Use the compact-ids kernel only for larger tiles; for small ncols_dst (< 16)
+    // we prefer the normal mul_mat_f path with has_ids=true.
+    if (has_ids_data && ncols_dst > 16) {
+        const int max_tiles = (int) ((ncols_dst + cols_per_block - 1) / cols_per_block);
+        if (max_tiles == 0) {
+            return;
+        }
+        dim3 block_nums_ids(block_nums.x, ids_data->n_experts, max_tiles);
+
+        const uint3 sis1_fd = ids_data->sis1 > 0 ? init_fastdiv_values((uint32_t) ids_data->sis1) : make_uint3(0, 0, 1);
+        const uint3 nch_fd  = init_fastdiv_values((uint32_t) nchannels_dst);
+
+        mul_mat_f_ids<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
+            (x, y, ids_data->ids_src_compact, ids_data->ids_dst_compact, ids_data->expert_bounds_dev, dst,
+            ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+            channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
+            sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst,
+            sis1_fd, nch_fd);
+    } else if (ids) {
         const int64_t col_tiles = (ncols_dst + cols_per_block - 1) / cols_per_block;
         dim3 block_nums_ids = block_nums;
         block_nums_ids.y *= col_tiles;
+
         mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
-             (x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
+            (x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
              stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
              sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
     } else {
@@ -258,7 +532,7 @@ void mul_mat_f_cuda(
         const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
         const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
         const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
-        cudaStream_t stream) {
+        cudaStream_t stream, const mmf_ids_data * ids_data) {
     typedef tile<16, 8, T>     tile_A;
     typedef tile< 8, 8, T>     tile_B;
 
@@ -290,7 +564,7 @@ void mul_mat_f_cuda(
     const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
     const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0;
     const int nbytes_shared_total = nbytes_shared + nbytes_slotmap;
-    const int64_t grid_y = ids ? nchannels_x : nchannels_dst; // per expert when ids present
+    const int64_t grid_y = ids ? nchannels_x : nchannels_dst;
 
     const dim3 block_nums(nrows_x/rows_per_block, grid_y, nsamples_dst);
     const dim3 block_dims(warp_size, nwarps_best, 1);
@@ -300,49 +574,57 @@ void mul_mat_f_cuda(
             mul_mat_f_switch_ids<T, cols_per_block, 1>(
                 x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
+                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
+                ids_data);
         } break;
         case 2: {
             mul_mat_f_switch_ids<T, cols_per_block, 2>(
                 x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
+                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
+                ids_data);
         } break;
         case 3: {
             mul_mat_f_switch_ids<T, cols_per_block, 3>(
                 x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
+                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
+                ids_data);
         } break;
         case 4: {
             mul_mat_f_switch_ids<T, cols_per_block, 4>(
                 x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
+                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
+                ids_data);
         } break;
         case 5: {
             mul_mat_f_switch_ids<T, cols_per_block, 5>(
                 x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
+                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
+                ids_data);
         } break;
         case 6: {
             mul_mat_f_switch_ids<T, cols_per_block, 6>(
                 x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
+                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
+                ids_data);
         } break;
         case 7: {
             mul_mat_f_switch_ids<T, cols_per_block, 7>(
                 x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
+                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
+                ids_data);
         } break;
         case 8: {
             mul_mat_f_switch_ids<T, cols_per_block, 8>(
                 x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
-                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
+                sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
+                ids_data);
         } break;
         default: {
             GGML_ABORT("fatal error");
@@ -361,7 +643,7 @@ static void mul_mat_f_switch_cols_per_block(
         const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
         const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
         const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
-        cudaStream_t stream) {
+        cudaStream_t stream, const mmf_ids_data * ids_data) {
 
     const int ncols_case = (ids && ncols_dst > 16) ? 16 : ncols_dst;
 
@@ -371,82 +653,82 @@ static void mul_mat_f_switch_cols_per_block(
         case  1: {
             mul_mat_f_cuda<T,  1>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
         } break;
         case  2: {
             mul_mat_f_cuda<T,  2>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
         } break;
         case  3: {
             mul_mat_f_cuda<T,  3>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
         } break;
         case  4: {
             mul_mat_f_cuda<T,  4>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
         } break;
         case  5: {
             mul_mat_f_cuda<T,  5>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y,  stride_sample_dst, stream);
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y,  stride_sample_dst, stream, ids_data);
         } break;
         case  6: {
             mul_mat_f_cuda<T,  6>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
         } break;
         case  7: {
             mul_mat_f_cuda<T,  7>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
         } break;
         case  8: {
             mul_mat_f_cuda<T,  8>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
         } break;
         case  9: {
             mul_mat_f_cuda<T,  9>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
         } break;
         case 10: {
             mul_mat_f_cuda<T, 10>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
         } break;
         case 11: {
             mul_mat_f_cuda<T, 11>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
         } break;
         case 12: {
             mul_mat_f_cuda<T, 12>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
         } break;
         case 13: {
             mul_mat_f_cuda<T, 13>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
         } break;
         case 14: {
             mul_mat_f_cuda<T, 14>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
         } break;
         case 15: {
             mul_mat_f_cuda<T, 15>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
         } break;
         case 16: {
             mul_mat_f_cuda<T, 16>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 stride_col_id, stride_row_id, nchannels_x, nchannels_y,  nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
-                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
+                nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
         } break;
         default: {
             GGML_ABORT("fatal error");
@@ -462,7 +744,7 @@ static void mul_mat_f_switch_cols_per_block(
         const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, \
         const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,\
         const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, \
-        cudaStream_t stream);
+        cudaStream_t stream, const mmf_ids_data * ids_data);
 
 #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
 #define DECL_MMF_CASE_EXTERN(ncols_dst) \
diff --git a/ggml/src/ggml-cuda/mmid.cu b/ggml/src/ggml-cuda/mmid.cu
new file mode 100644 (file)
index 0000000..3c61e45
--- /dev/null
@@ -0,0 +1,164 @@
+#include "common.cuh"
+#include "mmid.cuh"
+
+// To reduce shared memory use, store "it" and "iex_used" with 22/10 bits each.
+struct mm_ids_helper_store {
+    uint32_t data;
+
+    __device__ mm_ids_helper_store(const uint32_t it, const uint32_t iex_used) {
+        data = (it & 0x003FFFFF) | (iex_used << 22);
+    }
+
+    __device__ uint32_t it() const {
+        return data & 0x003FFFFF;
+    }
+
+    __device__ uint32_t iex_used() const {
+        return data >> 22;
+    }
+};
+static_assert(sizeof(mm_ids_helper_store) == 4, "unexpected size for mm_ids_helper_store");
+
+// Helper function for mul_mat_id, converts ids to a more convenient format.
+// ids_src1 describes how to permute the flattened column indices of src1 in order to get a compact src1 tensor sorted by expert.
+// ids_dst describes the same mapping but for the dst tensor.
+// The upper and lower bounds for the ith expert in the compact src1 tensor are stored in expert_bounds[i:i+1].
+template <int n_expert_used_template>
+__launch_bounds__(ggml_cuda_get_physical_warp_size(), 1)
+static __global__ void mm_ids_helper(
+        const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
+        const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1) {
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+    const int n_expert_used = n_expert_used_template == 0 ? n_expert_used_var : n_expert_used_template;
+    const int expert = blockIdx.x;
+
+    extern __shared__ char data_mm_ids_helper[];
+    mm_ids_helper_store * store = (mm_ids_helper_store *) data_mm_ids_helper;
+
+    int nex_prev   = 0; // Number of columns for experts with a lower index.
+    int it_compact = 0; // Running index for the compact slice of this expert.
+
+    if constexpr (n_expert_used_template == 0) {
+        // Generic implementation:
+        for (int it = 0; it < n_tokens; ++it) {
+            int iex_used = -1; // The index at which the expert is used, if any.
+            for (int iex = threadIdx.x; iex < n_expert_used; iex += warp_size) {
+                const int expert_used = ids[it*si1 + iex];
+                nex_prev += expert_used < expert;
+                if (expert_used == expert) {
+                    iex_used = iex;
+                }
+            }
+
+            if (iex_used != -1) {
+                store[it_compact] = mm_ids_helper_store(it, iex_used);
+            }
+
+            if (warp_reduce_any<warp_size>(iex_used != -1)) {
+                it_compact++;
+            }
+        }
+    } else {
+        // Implementation optimized for specific numbers of experts used:
+        static_assert(n_expert_used == 6 || warp_size % n_expert_used == 0, "bad n_expert_used");
+        const int neu_padded = n_expert_used == 6 ? 8 : n_expert_used; // Padded to next higher power of 2.
+        for (int it0 = 0; it0 < n_tokens; it0 += warp_size/neu_padded) {
+            const int it = it0 + threadIdx.x / neu_padded;
+
+            const int iex = threadIdx.x % neu_padded; // The index at which the expert is used, if any.
+            const int expert_used = (neu_padded == n_expert_used || iex < n_expert_used) && it < n_tokens ?
+                ids[it*si1 + iex] : INT_MAX;
+            const int iex_used = expert_used == expert ? iex : -1;
+            nex_prev += expert_used < expert;
+
+            // Whether the threads at this token position have used the expert:
+            const int it_compact_add_self = warp_reduce_any<neu_padded>(iex_used != -1);
+
+            // Do a scan over threads at lower token positions in warp to get the correct index for writing data:
+            int it_compact_add_lower = 0;
+#pragma unroll
+            for (int offset = neu_padded; offset < warp_size; offset += neu_padded) {
+                const int tmp = __shfl_up_sync(0xFFFFFFFF, it_compact_add_self, offset, warp_size);
+                if (threadIdx.x >= static_cast<unsigned int>(offset)) {
+                    it_compact_add_lower += tmp;
+                }
+            }
+
+            if (iex_used != -1) {
+                store[it_compact + it_compact_add_lower] = mm_ids_helper_store(it, iex_used);
+            }
+
+            // The thread with the highest index in the warp always has the sum over the whole warp, use it to increment all threads:
+            it_compact += __shfl_sync(0xFFFFFFFF, it_compact_add_lower + it_compact_add_self, warp_size - 1, warp_size);
+        }
+    }
+    nex_prev = warp_reduce_sum<warp_size>(nex_prev);
+
+    for (int itc = threadIdx.x; itc < it_compact; itc += warp_size) {
+        const mm_ids_helper_store store_it = store[itc];
+        const int it       = store_it.it();
+        const int iex_used = store_it.iex_used();
+        ids_src1[nex_prev + itc] = it*sis1          + iex_used % nchannels_y;
+        ids_dst [nex_prev + itc] = it*n_expert_used + iex_used;
+    }
+
+    if (threadIdx.x != 0) {
+        return;
+    }
+
+    expert_bounds[expert] = nex_prev;
+
+    if (expert < static_cast<int>(gridDim.x) - 1) {
+        return;
+    }
+
+    expert_bounds[gridDim.x] = nex_prev + it_compact;
+}
+
+template <int n_expert_used_template>
+static void launch_mm_ids_helper(
+        const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
+        const int n_experts, const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) {
+    GGML_ASSERT(n_tokens          < (1 << 22) && "too few bits in mm_ids_helper_store");
+    GGML_ASSERT(n_expert_used_var < (1 << 10) && "too few bits in mm_ids_helper_store");
+
+    const int id = ggml_cuda_get_device();
+    const int warp_size = ggml_cuda_info().devices[id].warp_size;
+    const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
+    CUDA_SET_SHARED_MEMORY_LIMIT(mm_ids_helper<n_expert_used_template>, smpbo);
+
+    const dim3 num_blocks(n_experts, 1, 1);
+    const dim3 block_size(warp_size, 1, 1);
+    const size_t nbytes_shared = n_tokens*sizeof(mm_ids_helper_store);
+    GGML_ASSERT(nbytes_shared <= smpbo);
+    mm_ids_helper<n_expert_used_template><<<num_blocks, block_size, nbytes_shared, stream>>>
+        (ids, ids_src1, ids_dst, expert_bounds, n_tokens, n_expert_used_var, nchannels_y, si1, sis1);
+}
+
+void ggml_cuda_launch_mm_ids_helper(
+        const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
+        const int n_experts, const int n_tokens, const int n_expert_used, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) {
+    switch (n_expert_used) {
+        case  2:
+            launch_mm_ids_helper< 2>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
+            break;
+        case  4:
+            launch_mm_ids_helper< 4>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
+            break;
+        case  6:
+            launch_mm_ids_helper< 6>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
+            break;
+        case  8:
+            launch_mm_ids_helper< 8>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
+            break;
+        case 16:
+            launch_mm_ids_helper<16>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
+            break;
+        case 32:
+            launch_mm_ids_helper<32>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
+            break;
+        default:
+            launch_mm_ids_helper< 0>(ids, ids_src1, ids_dst, expert_bounds, n_experts, n_tokens, n_expert_used, nchannels_y, si1, sis1, stream);
+            break;
+    }
+}
diff --git a/ggml/src/ggml-cuda/mmid.cuh b/ggml/src/ggml-cuda/mmid.cuh
new file mode 100644 (file)
index 0000000..ac090ae
--- /dev/null
@@ -0,0 +1,5 @@
+#pragma once
+
+void ggml_cuda_launch_mm_ids_helper(
+        const int32_t * ids, int32_t * ids_src1, int32_t * ids_dst, int32_t * expert_bounds,
+        int n_experts, int n_tokens, int n_expert_used, int nchannels_y, int si1, int sis1, cudaStream_t stream);
index 12bdc629bd6b200837cc238c05c6230490da42bd..a2c8760abea93cdb1e80b098b11bed1d61d53df3 100644 (file)
@@ -1,141 +1,6 @@
 #include "mmq.cuh"
 #include "quantize.cuh"
-
-#include <vector>
-
-// To reduce shared memory use, store "it" and "iex_used" with 22/10 bits each.
-struct mmq_ids_helper_store {
-    uint32_t data;
-
-    __device__ mmq_ids_helper_store(const uint32_t it, const uint32_t iex_used) {
-        data = (it & 0x003FFFFF) | (iex_used << 22);
-    }
-
-    __device__ uint32_t it() const {
-        return data & 0x003FFFFF;
-    }
-
-    __device__ uint32_t iex_used() const {
-        return data >> 22;
-    }
-};
-static_assert(sizeof(mmq_ids_helper_store) == 4, "unexpected size for mmq_ids_helper_store");
-
-// Helper function for mul_mat_id, converts ids to a more convenient format.
-// ids_src1 describes how to permute the flattened column indices of src1 in order to get a compact src1 tensor sorted by expert.
-// ids_dst describes the same mapping but for the dst tensor.
-// The upper and lower bounds for the ith expert in the compact src1 tensor are stored in expert_bounds[i:i+1].
-template <int n_expert_used_template>
-__launch_bounds__(ggml_cuda_get_physical_warp_size(), 1)
-static __global__ void mmq_ids_helper(
-        const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
-        const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1) {
-    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
-    const int n_expert_used = n_expert_used_template == 0 ? n_expert_used_var : n_expert_used_template;
-    const int expert = blockIdx.x;
-
-    extern __shared__ char data_mmq_ids_helper[];
-    mmq_ids_helper_store * store = (mmq_ids_helper_store *) data_mmq_ids_helper;
-
-    int nex_prev   = 0; // Number of columns for experts with a lower index.
-    int it_compact = 0; // Running index for the compact slice of this expert.
-
-    if constexpr (n_expert_used_template == 0) {
-        // Generic implementation:
-        for (int it = 0; it < n_tokens; ++it) {
-            int iex_used = -1; // The index at which the expert is used, if any.
-            for (int iex = threadIdx.x; iex < n_expert_used; iex += warp_size) {
-                const int expert_used = ids[it*si1 + iex];
-                nex_prev += expert_used < expert;
-                if (expert_used == expert) {
-                    iex_used = iex;
-                }
-            }
-
-            if (iex_used != -1) {
-                store[it_compact] = mmq_ids_helper_store(it, iex_used);
-            }
-
-            if (warp_reduce_any<warp_size>(iex_used != -1)) {
-                it_compact++;
-            }
-        }
-    } else {
-        // Implementation optimized for specific numbers of experts used:
-        static_assert(n_expert_used == 6 || warp_size % n_expert_used == 0, "bad n_expert_used");
-        const int neu_padded = n_expert_used == 6 ? 8 : n_expert_used; // Padded to next higher power of 2.
-        for (int it0 = 0; it0 < n_tokens; it0 += warp_size/neu_padded) {
-            const int it = it0 + threadIdx.x / neu_padded;
-
-            const int iex = threadIdx.x % neu_padded; // The index at which the expert is used, if any.
-            const int expert_used = (neu_padded == n_expert_used || iex < n_expert_used) && it < n_tokens ?
-                ids[it*si1 + iex] : INT_MAX;
-            const int iex_used = expert_used == expert ? iex : -1;
-            nex_prev += expert_used < expert;
-
-            // Whether the threads at this token position have used the expert:
-            const int it_compact_add_self = warp_reduce_any<neu_padded>(iex_used != -1);
-
-            // Do a scan over threads at lower token positions in warp to get the correct index for writing data:
-            int it_compact_add_lower = 0;
-#pragma unroll
-            for (int offset = neu_padded; offset < warp_size; offset += neu_padded) {
-                const int tmp = __shfl_up_sync(0xFFFFFFFF, it_compact_add_self, offset, warp_size);
-                if (threadIdx.x >= static_cast<unsigned int>(offset)) {
-                    it_compact_add_lower += tmp;
-                }
-            }
-
-            if (iex_used != -1) {
-                store[it_compact + it_compact_add_lower] = mmq_ids_helper_store(it, iex_used);
-            }
-
-            // The thread with the highest index in the warp always has the sum over the whole warp, use it to increment all threads:
-            it_compact += __shfl_sync(0xFFFFFFFF, it_compact_add_lower + it_compact_add_self, warp_size - 1, warp_size);
-        }
-    }
-    nex_prev = warp_reduce_sum<warp_size>(nex_prev);
-
-    for (int itc = threadIdx.x; itc < it_compact; itc += warp_size) {
-        const mmq_ids_helper_store store_it = store[itc];
-        const int it       = store_it.it();
-        const int iex_used = store_it.iex_used();
-        ids_src1[nex_prev + itc] = it*sis1          + iex_used % nchannels_y;
-        ids_dst [nex_prev + itc] = it*n_expert_used + iex_used;
-    }
-
-    if (threadIdx.x != 0) {
-        return;
-    }
-
-    expert_bounds[expert] = nex_prev;
-
-    if (expert < static_cast<int>(gridDim.x) - 1) {
-        return;
-    }
-
-    expert_bounds[gridDim.x] = nex_prev + it_compact;
-}
-
-template <int n_expert_used_template>
-static void launch_mmq_ids_helper(
-        const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
-        const int n_experts, const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) {
-    GGML_ASSERT(n_tokens          < (1 << 22) && "too few bits in mmq_ids_helper_store");
-    GGML_ASSERT(n_expert_used_var < (1 << 10) && "too few bits in mmq_ids_helper_store");
-
-    const int id = ggml_cuda_get_device();
-    const int warp_size = ggml_cuda_info().devices[id].warp_size;
-    const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
-    CUDA_SET_SHARED_MEMORY_LIMIT(mmq_ids_helper<n_expert_used_template>, smpbo);
-
-    const dim3 num_blocks(n_experts, 1, 1);
-    const dim3 block_size(warp_size, 1, 1);
-    const size_t nbytes_shared = n_tokens*sizeof(mmq_ids_helper_store);
-    GGML_ASSERT(nbytes_shared <= smpbo);
-    mmq_ids_helper<n_expert_used_template><<<num_blocks, block_size, nbytes_shared, stream>>>
-        (ids, ids_src1, ids_dst, expert_bounds, n_tokens, n_expert_used_var, nchannels_y, si1, sis1);
-}
+#include "mmid.cuh"
 
 static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
     switch (args.type_x) {
@@ -293,36 +158,8 @@ void ggml_cuda_mul_mat_q(
         const int si1  = ids->nb[1] / ggml_element_size(ids);
         const int sis1 = nb12 / nb11;
 
-        switch (n_expert_used) {
-            case  2:
-                launch_mmq_ids_helper< 2> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
-                    ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
-                break;
-            case  4:
-                launch_mmq_ids_helper< 4> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
-                    ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
-                break;
-            case  6:
-                launch_mmq_ids_helper< 6> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
-                    ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
-                break;
-            case  8:
-                launch_mmq_ids_helper< 8> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
-                    ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
-                break;
-            case 16:
-                launch_mmq_ids_helper<16> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
-                    ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
-                break;
-            case 32:
-                launch_mmq_ids_helper<32> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
-                    ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
-                break;
-            default:
-                launch_mmq_ids_helper< 0> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
-                    ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
-                break;
-        }
+        ggml_cuda_launch_mm_ids_helper((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+            ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
         CUDA_CHECK(cudaGetLastError());
     }