]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CUDA: MoE helper in device code, better tile sizes (#15525)
authorJohannes Gäßler <redacted>
Mon, 25 Aug 2025 15:23:40 +0000 (17:23 +0200)
committerGitHub <redacted>
Mon, 25 Aug 2025 15:23:40 +0000 (17:23 +0200)
* CUDA: MoE helper in device code, better tile sizes

* reduce superfluous CUDA blocks

ggml/src/ggml-cuda/common.cuh
ggml/src/ggml-cuda/mmq.cu
ggml/src/ggml-cuda/mmq.cuh
ggml/src/ggml-cuda/vendors/hip.h

index 767ad83f60eb50942ede80e12ec0fc3947dff381..48de1649cf5fd941c80cf886d90bfac0012baf61 100644 (file)
@@ -420,16 +420,28 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
 
 template<int width = WARP_SIZE>
 static __device__ __forceinline__ int warp_reduce_all(int x) {
-#ifdef GGML_USE_HIP
+    if (width == ggml_cuda_get_physical_warp_size()) {
+        return __all_sync(0xffffffff, x);
+    } else {
 #pragma unroll
-    for (int offset = width/2; offset > 0; offset >>= 1) {
-        x = x && __shfl_xor_sync(0xffffffff, x, offset, width);
+        for (int offset = width/2; offset > 0; offset >>= 1) {
+            x = __shfl_xor_sync(0xffffffff, x, offset, width) && x;
+        }
+        return x;
+    }
+}
+
+template<int width = WARP_SIZE>
+static __device__ __forceinline__ int warp_reduce_any(int x) {
+    if (width == ggml_cuda_get_physical_warp_size()) {
+        return __any_sync(0xffffffff, x);
+    } else {
+#pragma unroll
+        for (int offset = width/2; offset > 0; offset >>= 1) {
+            x = __shfl_xor_sync(0xffffffff, x, offset, width) || x;
+        }
+        return x;
     }
-    return x;
-#else
-    static_assert(width == WARP_SIZE, "width != WARP_SIZE not implemented");
-    return __all_sync(0xffffffff, x);
-#endif // GGML_USE_HIP
 }
 
 template<int width = WARP_SIZE>
index 576032a0ce0dd49c7c1755a3c7bb8cb50b7fda1c..714b23f9f49aaec6bb7ec108a36a61192da3a0b7 100644 (file)
@@ -3,6 +3,140 @@
 
 #include <vector>
 
+// To reduce shared memory use, store "it" and "iex_used" with 22/10 bits each.
+struct mmq_ids_helper_store {
+    uint32_t data;
+
+    __device__ mmq_ids_helper_store(const uint32_t it, const uint32_t iex_used) {
+        data = (it & 0x003FFFFF) | (iex_used << 22);
+    }
+
+    __device__ uint32_t it() const {
+        return data & 0x003FFFFF;
+    }
+
+    __device__ uint32_t iex_used() const {
+        return data >> 22;
+    }
+};
+static_assert(sizeof(mmq_ids_helper_store) == 4, "unexpected size for mmq_ids_helper_store");
+
+// Helper function for mul_mat_id, converts ids to a more convenient format.
+// ids_src1 describes how to permute the flattened column indices of src1 in order to get a compact src1 tensor sorted by expert.
+// ids_dst describes the same mapping but for the dst tensor.
+// The upper and lower bounds for the ith expert in the compact src1 tensor are stored in expert_bounds[i:i+1].
+template <int n_expert_used_template>
+__launch_bounds__(ggml_cuda_get_physical_warp_size(), 1)
+static __global__ void mmq_ids_helper(
+        const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
+        const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1) {
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+    const int n_expert_used = n_expert_used_template == 0 ? n_expert_used_var : n_expert_used_template;
+    const int expert = blockIdx.x;
+
+    extern __shared__ char data_mmq_ids_helper[];
+    mmq_ids_helper_store * store = (mmq_ids_helper_store *) data_mmq_ids_helper;
+
+    int nex_prev   = 0; // Number of columns for experts with a lower index.
+    int it_compact = 0; // Running index for the compact slice of this expert.
+
+    if constexpr (n_expert_used_template == 0) {
+        // Generic implementation:
+        for (int it = 0; it < n_tokens; ++it) {
+            int iex_used = -1; // The index at which the expert is used, if any.
+            for (int iex = threadIdx.x; iex < n_expert_used; iex += warp_size) {
+                const int expert_used = ids[it*si1 + iex];
+                nex_prev += expert_used < expert;
+                if (expert_used == expert) {
+                    iex_used = iex;
+                }
+            }
+
+            if (iex_used != -1) {
+                store[it_compact] = mmq_ids_helper_store(it, iex_used);
+            }
+
+            if (warp_reduce_any<warp_size>(iex_used != -1)) {
+                it_compact++;
+            }
+        }
+    } else {
+        // Implementation optimized for specific numbers of experts used:
+        static_assert(n_expert_used == 6 || warp_size % n_expert_used == 0, "bad n_expert_used");
+        const int neu_padded = n_expert_used == 6 ? 8 : n_expert_used; // Padded to next higher power of 2.
+        for (int it0 = 0; it0 < n_tokens; it0 += warp_size/neu_padded) {
+            const int it = it0 + threadIdx.x / neu_padded;
+
+            const int iex = threadIdx.x % neu_padded; // The index at which the expert is used, if any.
+            const int expert_used = (neu_padded == n_expert_used || iex < n_expert_used) && it < n_tokens ?
+                ids[it*si1 + iex] : INT_MAX;
+            const int iex_used = expert_used == expert ? iex : -1;
+            nex_prev += expert_used < expert;
+
+            // Whether the threads at this token position have used the expert:
+            const int it_compact_add_self = warp_reduce_any<neu_padded>(iex_used != -1);
+
+            // Do a scan over threads at lower token positions in warp to get the correct index for writing data:
+            int it_compact_add_lower = 0;
+#pragma unroll
+            for (int offset = neu_padded; offset < warp_size; offset += neu_padded) {
+                const int tmp = __shfl_up_sync(0xFFFFFFFF, it_compact_add_self, offset, warp_size);
+                if (threadIdx.x >= offset) {
+                    it_compact_add_lower += tmp;
+                }
+            }
+
+            if (iex_used != -1) {
+                store[it_compact + it_compact_add_lower] = mmq_ids_helper_store(it, iex_used);
+            }
+
+            // The thread with the highest index in the warp always has the sum over the whole warp, use it to increment all threads:
+            it_compact += __shfl_sync(0xFFFFFFFF, it_compact_add_lower + it_compact_add_self, warp_size - 1, warp_size);
+        }
+    }
+    nex_prev = warp_reduce_sum<warp_size>(nex_prev);
+
+    for (int itc = threadIdx.x; itc < it_compact; itc += warp_size) {
+        const mmq_ids_helper_store store_it = store[itc];
+        const int it       = store_it.it();
+        const int iex_used = store_it.iex_used();
+        ids_src1[nex_prev + itc] = it*sis1          + iex_used % nchannels_y;
+        ids_dst [nex_prev + itc] = it*n_expert_used + iex_used;
+    }
+
+    if (threadIdx.x != 0) {
+        return;
+    }
+
+    expert_bounds[expert] = nex_prev;
+
+    if (expert < gridDim.x - 1) {
+        return;
+    }
+
+    expert_bounds[gridDim.x] = nex_prev + it_compact;
+}
+
+template <int n_expert_used_template>
+static void launch_mmq_ids_helper(
+        const int32_t * __restrict__ ids, int32_t * __restrict__ ids_src1, int32_t * __restrict__ ids_dst, int32_t * __restrict__ expert_bounds,
+        const int n_experts, const int n_tokens, const int n_expert_used_var, const int nchannels_y, const int si1, const int sis1, cudaStream_t stream) {
+    GGML_ASSERT(n_tokens          < (1 << 22) && "too few bits in mmq_ids_helper_store");
+    GGML_ASSERT(n_expert_used_var < (1 << 10) && "too few bits in mmq_ids_helper_store");
+
+    const int id = ggml_cuda_get_device();
+    const int warp_size = ggml_cuda_info().devices[id].warp_size;
+    const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
+    CUDA_SET_SHARED_MEMORY_LIMIT(mmq_ids_helper<n_expert_used_template>, smpbo);
+
+    const dim3 num_blocks(n_experts, 1, 1);
+    const dim3 block_size(warp_size, 1, 1);
+    const size_t nbytes_shared = n_tokens*sizeof(mmq_ids_helper_store);
+    GGML_ASSERT(nbytes_shared <= smpbo);
+    mmq_ids_helper<n_expert_used_template><<<num_blocks, block_size, nbytes_shared, stream>>>
+        (ids, ids_src1, ids_dst, expert_bounds, n_tokens, n_expert_used_var, nchannels_y, si1, sis1);
+}
+
 static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
     switch (args.type_x) {
         case GGML_TYPE_Q4_0:
@@ -137,7 +271,7 @@ void ggml_cuda_mul_mat_q(
             ne00, ne01, ne1, s01, ne11, s1,
             ne02, ne12, s02, s12, s2,
             ne03, ne13, s03, s13, s3,
-            use_stream_k};
+            use_stream_k, ne1};
         ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);
         return;
     }
@@ -148,53 +282,49 @@ void ggml_cuda_mul_mat_q(
 
     const int64_t n_expert_used = ids->ne[0];
     const int64_t ne_get_rows = ne12 * n_expert_used;
+    GGML_ASSERT(ne1 == n_expert_used);
 
-    std::vector<char> ids_host(ggml_nbytes(ids));
-    std::vector<int32_t> ids_src1_host;
-    ids_src1_host.reserve(ne_get_rows);
-    std::vector<int32_t> ids_dst_host;
-    ids_dst_host.reserve(ne_get_rows);
-    std::vector<int32_t> tokens_per_expert_host(ne02);
-    std::vector<int32_t> expert_bounds_host(ne02 + 1);
-    ggml_cuda_pool_alloc<int32_t> ids_buf_dev(ctx.pool());
-
-    CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids->data, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
-    CUDA_CHECK(cudaStreamSynchronize(stream));
-
-    for (int64_t i02 = 0; i02 < ne02; ++i02) { // expert matrices
-        for (int64_t i12 = 0; i12 < ne12; ++i12) { // tokens
-            for (int64_t iex = 0; iex < n_expert_used; ++iex) {
-                const int32_t expert_to_use = *(const int32_t *)(ids_host.data() + i12*ids->nb[1] + iex*ids->nb[0]);
-                assert(expert_to_use >= 0 && expert_to_use < ne02);
-                if (expert_to_use == i02) {
-                    ids_src1_host.push_back(i12*(nb12/nb11) + iex % ne11);
-                    ids_dst_host.push_back(i12*ne1 + iex);
-                    tokens_per_expert_host[i02]++;
-                    break;
-                }
-            }
-        }
-    }
+    ggml_cuda_pool_alloc<int32_t> ids_src1(ctx.pool(), ne_get_rows);
+    ggml_cuda_pool_alloc<int32_t> ids_dst(ctx.pool(), ne_get_rows);
+    ggml_cuda_pool_alloc<int32_t> expert_bounds(ctx.pool(), ne02 + 1);
 
-    int32_t cumsum = 0;
-    for (int64_t i = 0; i < ne02; ++i) {
-        expert_bounds_host[i] = cumsum;
-        cumsum += tokens_per_expert_host[i];
+    {
+        GGML_ASSERT(ids->nb[0] == ggml_element_size(ids));
+        const int si1  = ids->nb[1] / ggml_element_size(ids);
+        const int sis1 = nb12 / nb11;
+
+        switch (n_expert_used) {
+            case  2:
+                launch_mmq_ids_helper< 2> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+                    ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
+                break;
+            case  4:
+                launch_mmq_ids_helper< 4> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+                    ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
+                break;
+            case  6:
+                launch_mmq_ids_helper< 6> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+                    ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
+                break;
+            case  8:
+                launch_mmq_ids_helper< 8> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+                    ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
+                break;
+            case 16:
+                launch_mmq_ids_helper<16> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+                    ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
+                break;
+            case 32:
+                launch_mmq_ids_helper<32> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+                    ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
+                break;
+            default:
+                launch_mmq_ids_helper< 0> ((const int32_t *) ids->data, ids_src1.get(), ids_dst.get(), expert_bounds.get(),
+                    ne02, ne12, n_expert_used, ne11, si1, sis1, stream);
+                break;
+        }
+        CUDA_CHECK(cudaGetLastError());
     }
-    expert_bounds_host[ne02] = cumsum;
-
-    std::vector<int32_t> ids_buf_host;
-    ids_buf_host.reserve(ids_src1_host.size() + ids_dst_host.size() + expert_bounds_host.size());
-    ids_buf_host.insert(ids_buf_host.end(), ids_src1_host.begin(), ids_src1_host.end());
-    ids_buf_host.insert(ids_buf_host.end(), ids_dst_host.begin(), ids_dst_host.end());
-    ids_buf_host.insert(ids_buf_host.end(), expert_bounds_host.begin(), expert_bounds_host.end());
-    ids_buf_dev.alloc(ids_buf_host.size() + get_mmq_x_max_host(cc)); // Expert bounds are padded on device.
-    CUDA_CHECK(cudaMemcpyAsync(ids_buf_dev.ptr, ids_buf_host.data(), ids_buf_host.size()*sizeof(int32_t), cudaMemcpyHostToDevice, stream));
-    CUDA_CHECK(cudaStreamSynchronize(stream));
-
-    const int32_t * ids_src1_dev      = ids_buf_dev.ptr;
-    const int32_t * ids_dst_dev       = ids_src1_dev + ids_src1_host.size();
-    const int32_t * expert_bounds_dev = ids_dst_dev + ids_dst_host.size();
 
     const size_t nbytes_src1_q8_1 = ne12*n_expert_used*ne10_padded * sizeof(block_q8_1)/QK8_1 +
         get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq);
@@ -208,7 +338,7 @@ void ggml_cuda_mul_mat_q(
         const int64_t s11 = src1->nb[1] / ts_src1;
         const int64_t s12 = src1->nb[2] / ts_src1;
         const int64_t s13 = src1->nb[2] / ts_src1;
-        quantize_mmq_q8_1_cuda(src1_d, ids_src1_dev, src1_q8_1.get(), src0->type,
+        quantize_mmq_q8_1_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type,
             ne10, s11, s12, s13, ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
         CUDA_CHECK(cudaGetLastError());
     }
@@ -218,11 +348,11 @@ void ggml_cuda_mul_mat_q(
 
     // Note that ne02 is used instead of ne12 because the number of y channels determines the z dimension of the CUDA grid.
     const mmq_args args = {
-        src0_d, src0->type, (const int *) src1_q8_1.ptr, ids_dst_dev, expert_bounds_dev, dst_d,
+        src0_d, src0->type, (const int *) src1_q8_1.get(), ids_dst.get(), expert_bounds.get(), dst_d,
         ne00, ne01, ne_get_rows, s01, ne_get_rows, s1,
         ne02, ne02, s02, s12, s2,
         ne03, ne13, s03, s13, s3,
-        use_stream_k};
+        use_stream_k, ne12};
 
     ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);
 }
@@ -262,7 +392,7 @@ void ggml_cuda_op_mul_mat_q(
         ne00, row_diff, src1_ncols, stride01, ne11, nrows_dst,
         1, 1, 0, 0, 0,
         1, 1, 0, 0, 0,
-        use_stream_k};
+        use_stream_k, src1_ncols};
 
     ggml_cuda_mul_mat_q_switch_type(ctx, args, stream);
 
index 650f7080677ad9f0bb795142e29ddffbf0eaecd9..c9a07e82fedf2e8fc85562179b426414c6e937ca 100644 (file)
@@ -3138,7 +3138,8 @@ static __global__ void mul_mat_q(
         const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, float * __restrict__ tmp_fixup,
         const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_row_x, const int ncols_y, const int stride_col_dst,
         const int channel_ratio, const int nchannels_y, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
-        const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
+        const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
+        const int ncols_max) {
 
     // Skip unused template specializations for faster compilation:
     if (mmq_x > get_mmq_x_max_device() || mmq_x % mmq_get_granularity_device(mmq_x) != 0) {
@@ -3152,7 +3153,7 @@ static __global__ void mul_mat_q(
     constexpr int qk    = ggml_cuda_type_traits<type>::qk;
     constexpr int mmq_y = get_mmq_y_device();
 
-    const int ntx = (ncols_dst + mmq_x - 1) / mmq_x; // Number of tiles x
+    const int ntx = (ncols_max + mmq_x - 1) / mmq_x; // Number of tiles x
     const int nty = (nrows_x   + mmq_y - 1) / mmq_y; // Number of tiles y
 
     // Initialize the ids for writing back data with just the index.
@@ -3376,7 +3377,8 @@ template <ggml_type type, int mmq_x, bool need_check>
 static __global__ void mul_mat_q_stream_k_fixup(
         const int32_t * ids_dst, const int32_t * expert_bounds, float * __restrict__ dst, const float * __restrict__ tmp_last_tile,
         const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_col_dst,
-        const int nchannels_y, const int stride_channel_dst, const int nsamples_y, const int stride_sample_dst) {
+        const int nchannels_y, const int stride_channel_dst, const int nsamples_y, const int stride_sample_dst,
+        const int ncols_max) {
     constexpr int     mmq_y           = get_mmq_y_device();
     constexpr int     qk              = ggml_cuda_type_traits<type>::qk;
     constexpr int     blocks_per_iter = MMQ_ITER_K / qk;
@@ -3387,7 +3389,7 @@ static __global__ void mul_mat_q_stream_k_fixup(
 
     float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f};
 
-    const int ntx  = (ncols_dst + mmq_x - 1) / mmq_x;
+    const int ntx  = (ncols_max + mmq_x - 1) / mmq_x;
     const int nty  = (nrows_x   + mmq_y - 1) / mmq_y;
 
     const int bidx0 = blockIdx.x;
@@ -3528,7 +3530,7 @@ struct mmq_args {
     int64_t ncols_x; int64_t nrows_x; int64_t ncols_dst; int64_t stride_row_x; int64_t ncols_y; int64_t nrows_dst;
     int64_t nchannels_x; int64_t nchannels_y; int64_t stride_channel_x; int64_t stride_channel_y; int64_t stride_channel_dst;
     int64_t nsamples_x; int64_t nsamples_y; int64_t stride_sample_x; int64_t stride_sample_y; int64_t stride_sample_dst;
-    bool use_stream_k;
+    bool use_stream_k; int64_t ncols_max;
 };
 
 template<ggml_type type>
@@ -3558,7 +3560,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
     CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q<type, mmq_x,  true>), nbytes_shared);
 
     const int nty  = (args.nrows_x   + mmq_y - 1) / mmq_y;
-    const int ntx  = (args.ncols_dst + mmq_x - 1) / mmq_x;
+    const int ntx  = (args.ncols_max + mmq_x - 1) / mmq_x;
     const int ntzw = args.nchannels_y * args.nsamples_y;
     const dim3 block_nums_xy_tiling(nty, ntx, ntzw);
 
@@ -3574,14 +3576,16 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
                 (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
                  args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
                  channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
-                 sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
+                 sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
+                 args.ncols_max);
         } else {
             constexpr bool need_check = true;
             mul_mat_q<type, mmq_x, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
                 (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
                  args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
                  channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
-                 sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
+                 sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
+                 args.ncols_max);
         }
         return;
     }
@@ -3601,7 +3605,8 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
             (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
              args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
              channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
-             sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
+             sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
+             args.ncols_max);
 
         if (!fixup_needed) {
             return;
@@ -3609,14 +3614,16 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
 
         mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
             (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
-             args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst);
+             args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst,
+             args.ncols_max);
     } else {
         constexpr bool need_check = true;
         mul_mat_q<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
             (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
              args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
              channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
-             sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
+             sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst,
+             args.ncols_max);
 
         if (!fixup_needed) {
             return;
@@ -3624,7 +3631,8 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
 
         mul_mat_q_stream_k_fixup<type, mmq_x, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
             (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
-             args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst);
+             args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst,
+             args.ncols_max);
     }
 }
 
@@ -3649,7 +3657,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda
             continue;
         }
 
-        const int ntiles_x = (args.ncols_y + mmq_x - 1) / mmq_x;
+        const int ntiles_x = (args.ncols_max + mmq_x - 1) / mmq_x;
 
         if (ntiles_x < ntiles_x_best) {
             mmq_x_best = mmq_x;
index 6e9c67aca096e68c136c2103b5d178fff603ad60..c6a33d5de310f632523ab975c9060cfeb2b944ab 100644 (file)
 #define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite
 #define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }}
 #define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width)
+#define __shfl_up_sync(mask, var, laneMask, width) __shfl_up(var, laneMask, width)
 #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
+#define __all_sync(mask, var) __all(var)
+#define __any_sync(mask, var) __any(var)
 #define cublasCreate hipblasCreate
 #define cublasDestroy hipblasDestroy
 #define cublasGemmEx hipblasGemmEx