]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
vulkan: Implement topk_moe fused shader, ported from CUDA (#16641)
authorJeff Bolz <redacted>
Sat, 18 Oct 2025 10:22:57 +0000 (05:22 -0500)
committerGitHub <redacted>
Sat, 18 Oct 2025 10:22:57 +0000 (12:22 +0200)
This is similar to the CUDA shader from #16130, but doesn't use shared memory
and handles different subgroup sizes.

ggml/src/ggml-impl.h
ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp [new file with mode: 0644]
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

index d0fb3bccad2250d99d97f70484a0577d5aef4176..18f095b8960104075d8a87cc782af5a05d088430 100644 (file)
@@ -565,14 +565,23 @@ static inline ggml_bf16_t ggml_compute_fp32_to_bf16(float s) {
 #define GGML_FP32_TO_BF16(x) ggml_compute_fp32_to_bf16(x)
 #define GGML_BF16_TO_FP32(x) ggml_compute_bf16_to_fp32(x)
 
+static inline int32_t ggml_node_get_use_count(const struct ggml_cgraph * cgraph, int node_idx) {
+    const struct ggml_tensor * node = cgraph->nodes[node_idx];
+
+    size_t hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node);
+    if (!ggml_bitset_get(cgraph->visited_hash_set.used, hash_pos)) {
+        return 0;
+    }
+    return cgraph->use_counts[hash_pos];
+}
+
 // return true if the node's results are only used by N other nodes
 // and can be fused into their calculations.
 static inline bool ggml_node_has_n_uses(const struct ggml_cgraph * cgraph, int node_idx, int32_t n_uses) {
     const struct ggml_tensor * node = cgraph->nodes[node_idx];
 
     // check the use count against how many we're replacing
-    size_t hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node);
-    if (!ggml_bitset_get(cgraph->visited_hash_set.used, hash_pos) || cgraph->use_counts[hash_pos] != n_uses) {
+    if (ggml_node_get_use_count(cgraph, node_idx) != n_uses) {
         return false;
     }
 
index bc703611f0a94a11516a574e8dd6e1d1970cfc75..21bd0522555643e5806e7f6d3eb2a34fcf350487 100644 (file)
@@ -385,6 +385,14 @@ enum shader_reduction_mode {
 
 static constexpr uint32_t num_argsort_pipelines = 11;
 static constexpr uint32_t max_argsort_cols = 1 << (num_argsort_pipelines-1);
+static constexpr uint32_t num_topk_moe_pipelines = 10;
+
+static constexpr std::array topk_moe_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE,  GGML_OP_ARGSORT,
+                                           GGML_OP_VIEW,     GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
+                                           GGML_OP_SUM_ROWS, GGML_OP_DIV,      GGML_OP_RESHAPE };
+static constexpr std::array topk_moe     { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE,  GGML_OP_ARGSORT,
+                                           GGML_OP_VIEW,     GGML_OP_GET_ROWS };
+
 
 struct vk_device_struct {
     std::recursive_mutex mutex;
@@ -598,6 +606,9 @@ struct vk_device_struct {
 
     vk_pipeline pipeline_flash_attn_split_k_reduce;
 
+    // [2] is {!norm, norm}
+    vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][2];
+
     std::vector<vk_pipeline_ref> all_pipelines;
 
     std::vector<std::tuple<void*, size_t, vk_buffer>> pinned_memory;
@@ -941,6 +952,11 @@ struct vk_op_multi_add_push_constants {
 static_assert(MAX_PARAMETER_COUNT == 12);
 static_assert(sizeof(vk_op_multi_add_push_constants) <= 256);
 
+struct vk_op_topk_moe_push_constants {
+    uint32_t n_rows;
+    uint32_t n_expert_used;
+};
+
 struct vk_op_add_id_push_constants {
     uint32_t ne0;
     uint32_t ne1;
@@ -3722,6 +3738,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f16_f32, "conv2d_dw_whcn_f16_f32", conv2d_dw_whcn_f16_f32_len, conv2d_dw_whcn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f16_f32, "conv2d_dw_cwhn_f16_f32", conv2d_dw_cwhn_f16_f32_len, conv2d_dw_cwhn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
 
+    for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) {
+        ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][0], "topk_moe_f32_"+std::to_string(i),   topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0}, 1, true, true);
+        ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][1], "topk_moe_f32_"+std::to_string(i),   topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 1}, 1, true, true);
+    }
+
     for (auto &c : compiles) {
         c.wait();
     }
@@ -8004,6 +8025,13 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
         GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32);
 
+        if (ctx->num_additional_fused_ops) {
+            uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
+            GGML_ASSERT(idx < num_topk_moe_pipelines);
+            bool with_norm = ctx->num_additional_fused_ops == topk_moe_norm.size() - 1;
+            return ctx->device->pipeline_topk_moe[idx][with_norm];
+        }
+
         if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
             return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_wg512 : ctx->device->pipeline_soft_max_f32;
         }
@@ -9589,6 +9617,87 @@ static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& sub
     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), op_params[0], op_params[1] }, dryrun);
 }
 
+static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx, bool dryrun = false) {
+
+    bool with_norm = ctx->num_additional_fused_ops == topk_moe_norm.size() - 1;
+    ggml_tensor * logits = cgraph->nodes[node_idx + 0]->src[0];
+    ggml_tensor * weights = with_norm ? cgraph->nodes[node_idx + 8] : cgraph->nodes[node_idx + 4];
+    ggml_tensor * ids = cgraph->nodes[node_idx + 3];
+
+    GGML_ASSERT(logits->type == GGML_TYPE_F32);
+    GGML_ASSERT(weights->type == GGML_TYPE_F32);
+    GGML_ASSERT(ids->type == GGML_TYPE_I32);
+
+    const int n_experts = logits->ne[0];
+    const int n_rows    = logits->ne[1];
+    const int n_expert_used = weights->ne[1];
+
+    GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts);
+
+    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, nullptr, nullptr, nullptr, cgraph->nodes[node_idx], GGML_OP_SOFT_MAX);
+
+    if (dryrun) {
+        ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
+        return;
+    }
+
+    ggml_backend_vk_buffer_context * logits_buf_ctx = (ggml_backend_vk_buffer_context *)logits->buffer->context;
+    ggml_backend_vk_buffer_context * weights_buf_ctx = (ggml_backend_vk_buffer_context *)weights->buffer->context;
+    ggml_backend_vk_buffer_context * ids_buf_ctx = (ggml_backend_vk_buffer_context *)ids->buffer->context;
+
+    vk_buffer d_logits = nullptr;
+    size_t logits_buf_offset = 0;
+    vk_buffer d_weights = nullptr;
+    size_t weights_buf_offset = 0;
+    vk_buffer d_ids = nullptr;
+    size_t ids_buf_offset = 0;
+
+    bool logits_uma = false;
+    bool weights_uma = false;
+    bool ids_uma = false;
+
+    if (ctx->device->uma) {
+        ggml_vk_host_get(ctx->device, logits->data, d_logits, logits_buf_offset);
+        ggml_vk_host_get(ctx->device, weights->data, d_weights, weights_buf_offset);
+        ggml_vk_host_get(ctx->device, ids->data, d_ids, ids_buf_offset);
+        logits_uma = d_logits != nullptr;
+        weights_uma = d_weights != nullptr;
+        ids_uma = d_ids != nullptr;
+    }
+
+    if (!logits_uma) {
+        d_logits = logits_buf_ctx->dev_buffer;
+        logits_buf_offset = vk_tensor_offset(logits) + logits->view_offs;
+        GGML_ASSERT(d_logits != nullptr);
+    }
+    if (!weights_uma) {
+        d_weights = weights_buf_ctx->dev_buffer;
+        weights_buf_offset = vk_tensor_offset(weights) + weights->view_offs;
+        GGML_ASSERT(d_weights != nullptr);
+    }
+    if (!ids_uma) {
+        d_ids = ids_buf_ctx->dev_buffer;
+        ids_buf_offset = vk_tensor_offset(ids) + ids->view_offs;
+        GGML_ASSERT(d_ids != nullptr);
+    }
+
+    vk_op_topk_moe_push_constants pc;
+    pc.n_rows = n_rows;
+    pc.n_expert_used = n_expert_used;
+
+    GGML_ASSERT(n_expert_used <= n_experts);
+
+    const uint32_t rows_per_block = 4;
+    std::array<uint32_t, 3> elements = { CEIL_DIV(n_rows, rows_per_block), 1, 1 };
+
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
+        {
+            ggml_vk_subbuffer(ctx, d_logits, logits_buf_offset),
+            ggml_vk_subbuffer(ctx, d_weights, weights_buf_offset),
+            ggml_vk_subbuffer(ctx, d_ids, ids_buf_offset),
+        }, pc, elements);
+}
+
 static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool backprop, bool dryrun = false) {
     const int n_dims        = ((int32_t *) dst->op_params)[1];
     const int mode          = ((int32_t *) dst->op_params)[2];
@@ -11174,11 +11283,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
             ctx->unsynced_nodes_read.clear();
             ggml_vk_sync_buffers(ctx, compute_ctx);
         }
-        // Add the last fused node and all fused source nodes to the unsynchronized list.
-        const ggml_tensor * last_node = cgraph->nodes[node_idx + ctx->num_additional_fused_ops];
-        ctx->unsynced_nodes_written.push_back(last_node);
+        // Add all fused nodes to the unsynchronized lists.
         for (int32_t i = 0; i < ctx->num_additional_fused_ops + 1; ++i) {
             const ggml_tensor *cur_node = cgraph->nodes[node_idx + i];
+            // Multiple outputs could be written, e.g. in topk_moe. Add them all to the list.
+            ctx->unsynced_nodes_written.push_back(cur_node);
             for (uint32_t j = 0; j < GGML_MAX_SRC; ++j) {
                 if (!cur_node->src[j]) {
                     continue;
@@ -11345,7 +11454,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
 
         break;
     case GGML_OP_SOFT_MAX:
-        ggml_vk_soft_max(ctx, compute_ctx, src0, src1, src2, node, dryrun);
+        if (ctx->num_additional_fused_ops) {
+            ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx, dryrun);
+        } else {
+            ggml_vk_soft_max(ctx, compute_ctx, src0, src1, src2, node, dryrun);
+        }
 
         break;
     case GGML_OP_SOFT_MAX_BACK:
@@ -12141,6 +12254,120 @@ static bool ggml_vk_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, st
     return true;
 }
 
+static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph,
+                                      int node_idx, bool with_norm) {
+
+    if (with_norm) {
+        if (node_idx + (int)topk_moe_norm.size() > cgraph->n_nodes) {
+            return false;
+        }
+        for (size_t i = 0; i < topk_moe_norm.size(); ++i) {
+            if (cgraph->nodes[node_idx + i]->op != topk_moe_norm[i]) {
+                return false;
+            }
+        }
+    } else {
+        if (node_idx + (int)topk_moe.size() > cgraph->n_nodes) {
+            return false;
+        }
+        for (size_t i = 0; i < topk_moe.size(); ++i) {
+            if (cgraph->nodes[node_idx + i]->op != topk_moe[i]) {
+                return false;
+            }
+        }
+    }
+
+    const ggml_tensor * softmax =  cgraph->nodes[node_idx + 0];
+    const ggml_tensor * weights = with_norm ? cgraph->nodes[node_idx + 8] : cgraph->nodes[node_idx + 4];
+
+    const float * op_params = (const float *)softmax->op_params;
+
+    float scale = op_params[0];
+    float max_bias = op_params[1];
+
+    if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) {
+        return false;
+    }
+
+    if (scale != 1.0f || max_bias != 0.0f) {
+        return false;
+    }
+
+    // don't fuse when masks or sinks are present
+    if (softmax->src[1] || softmax->src[2]) {
+        return false;
+    }
+
+    const int n_expert = softmax->ne[0];
+    // n_expert must be a power of 2
+    if (!is_pow2(n_expert) || n_expert > (1 << (num_topk_moe_pipelines-1))) {
+        return false;
+    }
+
+    // Check that the nodes don't have any unexpected uses
+    const ggml_tensor * reshape1 =  cgraph->nodes[node_idx + 1];
+    const ggml_tensor * argsort =   cgraph->nodes[node_idx + 2];
+    const ggml_tensor * view =      cgraph->nodes[node_idx + 3];
+    const ggml_tensor * get_rows =  cgraph->nodes[node_idx + 4];
+    const ggml_tensor * reshape5 =  with_norm ? cgraph->nodes[node_idx + 5] : nullptr;
+    const ggml_tensor * sum_rows =  with_norm ? cgraph->nodes[node_idx + 6] : nullptr;
+    const ggml_tensor * div =       with_norm ? cgraph->nodes[node_idx + 7] : nullptr;
+    const ggml_tensor * reshape8 =  with_norm ? cgraph->nodes[node_idx + 8] : nullptr;
+
+    // softmax is used by reshape and argsort
+    if (ggml_node_get_use_count(cgraph, node_idx) != 2 ||
+        reshape1->src[0] != softmax ||
+        argsort->src[0] != softmax) {
+        return false;
+    }
+    // reshape is used by get_rows
+    if (ggml_node_get_use_count(cgraph, node_idx + 1) != 1 ||
+        get_rows->src[0] != reshape1) {
+        return false;
+    }
+    // argsort is used by view
+    if (ggml_node_get_use_count(cgraph, node_idx + 2) != 1 ||
+        view->src[0] != argsort) {
+        return false;
+    }
+    // view is written (via argsort), we can skip checking it
+
+    if (with_norm) {
+        // get_rows is used by reshape
+        if (ggml_node_get_use_count(cgraph, node_idx + 4) != 1 ||
+            reshape5->src[0] != get_rows) {
+            return false;
+        }
+
+        // reshape is used by sum_rows and div
+        if (ggml_node_get_use_count(cgraph, node_idx + 5) != 2 ||
+            sum_rows->src[0] != reshape5 ||
+            div->src[0] != reshape5) {
+            return false;
+        }
+
+        // sum_rows is used by div
+        if (ggml_node_get_use_count(cgraph, node_idx + 6) != 1 ||
+            div->src[1] != sum_rows) {
+            return false;
+        }
+
+        // div/reshape are written
+        if (reshape8->src[0] != div) {
+            return false;
+        }
+    }
+
+    if (!ctx->device->subgroup_arithmetic ||
+        !ctx->device->subgroup_shuffle ||
+        !ctx->device->subgroup_require_full_support ||
+        ctx->device->disable_fusion) {
+        return false;
+    }
+
+    return true;
+}
+
 static uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx) {
 
     const ggml_tensor *first_node = cgraph->nodes[node_idx];
@@ -12216,6 +12443,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
                 ctx->num_additional_fused_ops = num_adds - 1;
             } else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
                 ctx->num_additional_fused_ops = 1;
+            } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, true)) {
+                ctx->num_additional_fused_ops = topk_moe_norm.size() - 1;
+            } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, false)) {
+                ctx->num_additional_fused_ops = topk_moe.size() - 1;
             }
         }
         ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
@@ -12313,6 +12544,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
                 ctx->num_additional_fused_ops = num_adds - 1;
             } else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
                 ctx->num_additional_fused_ops = 1;
+            } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, true)) {
+                ctx->num_additional_fused_ops = topk_moe_norm.size() - 1;
+            } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, false)) {
+                ctx->num_additional_fused_ops = topk_moe.size() - 1;
             }
         }
 
@@ -12320,10 +12555,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
         bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5;
         bool submit = (submitted_nodes >= nodes_per_submit) ||
                       (mul_mat_bytes >= mul_mat_bytes_per_submit) ||
-                      (i + ctx->num_additional_fused_ops == last_node) ||
+                      (i + ctx->num_additional_fused_ops >= last_node) ||
                       (almost_ready && !ctx->almost_ready_fence_pending);
 
-        bool enqueued = ggml_vk_build_graph(ctx, cgraph, i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i + ctx->num_additional_fused_ops == last_node, almost_ready, submit);
+        bool enqueued = ggml_vk_build_graph(ctx, cgraph, i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i + ctx->num_additional_fused_ops >= last_node, almost_ready, submit);
 
         if (vk_perf_logger_enabled) {
             if (ctx->compute_ctx.expired()) {
@@ -12444,6 +12679,25 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
     while (first_unused < graph->n_nodes) {
         std::vector<int> current_set;
 
+        // Avoid reordering topk_moe_norm
+        if (first_unused + (int)topk_moe_norm.size() <= graph->n_nodes) {
+            bool is_topk_moe_norm = true;
+            for (size_t j = 0; j < topk_moe_norm.size(); ++j) {
+                if (graph->nodes[first_unused + j]->op != topk_moe_norm[j] || used[first_unused + j]) {
+                    is_topk_moe_norm = false;
+                }
+            }
+            if (is_topk_moe_norm) {
+                for (size_t j = 0; j < topk_moe_norm.size(); ++j) {
+                    new_order.push_back(graph->nodes[first_unused + j]);
+                    used[first_unused + j] = true;
+                }
+                while (first_unused < graph->n_nodes && used[first_unused]) {
+                    first_unused++;
+                }
+                continue;
+            }
+        }
         // First, grab the next unused node.
         current_set.push_back(first_unused);
 
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp b/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp
new file mode 100644 (file)
index 0000000..9e56d5f
--- /dev/null
@@ -0,0 +1,139 @@
+#version 450
+
+#extension GL_EXT_control_flow_attributes : require
+#extension GL_KHR_shader_subgroup_basic : enable
+#extension GL_KHR_shader_subgroup_arithmetic : enable
+#extension GL_KHR_shader_subgroup_shuffle : enable
+
+#include "types.glsl"
+
+layout (push_constant) uniform parameter
+{
+    uint n_rows;
+    uint n_expert_used;
+};
+
+layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;
+
+layout(constant_id = 0) const uint WARP_SIZE = 32;
+layout(constant_id = 1) const uint n_experts = 512;
+layout(constant_id = 2) const bool with_norm = true;
+
+const uint experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;
+
+layout (binding = 0, std430) readonly buffer Logits {float logits[];};
+layout (binding = 1, std430) writeonly buffer Weights {float weights[];};
+layout (binding = 2, std430) writeonly buffer Ids {uint ids[];};
+
+void main() {
+    const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_LocalInvocationID.y;
+    if (row >= n_rows) {
+        return;
+    }
+
+    const uint logits_offset = n_experts * row;
+    const uint weights_offset = n_expert_used * row;
+    const uint ids_offset = n_experts * row;
+
+    float logits_r[experts_per_thread];
+
+    const float INFINITY = 1.0 / 0.0;
+
+    [[unroll]]
+    for (uint i = 0; i < n_experts; i += WARP_SIZE) {
+        const uint expert        = i + gl_LocalInvocationID.x;
+        logits_r[i / WARP_SIZE] = n_experts % WARP_SIZE == 0 || expert < n_experts ? logits[logits_offset + expert] : -INFINITY;
+    }
+
+    float max_val = logits_r[0];
+
+    [[unroll]]
+    for (int i = 1; i < experts_per_thread; i++) {
+        const float val = logits_r[i];
+        max_val         = max(val, max_val);
+    }
+
+    max_val = subgroupMax(max_val);
+
+    float wt[experts_per_thread];
+    float tmp = 0.f;
+
+    [[unroll]]
+    for (int i = 0; i < experts_per_thread; i++) {
+        const float val = logits_r[i];
+        wt[i]           = exp(val - max_val);
+        tmp += wt[i];
+    }
+
+    tmp = subgroupAdd(tmp);
+
+    const float inv_sum = 1.0f / tmp;
+
+    [[unroll]]
+    for (int i = 0; i < experts_per_thread; i++) {
+        wt[i] = wt[i] * inv_sum;
+    }
+
+    // at this point, each thread holds a portion of softmax,
+    // we do the argmax reduce over n_expert_used, each time marking
+    // the expert weight as -inf to exclude from the next iteration
+
+    float wt_sum = 0.f;
+
+    float output_weights[experts_per_thread];
+
+    for (int k = 0; k < n_expert_used; k++) {
+        float max_val    = wt[0];
+        uint   max_expert = gl_LocalInvocationID.x;
+
+        [[unroll]]
+        for (int i = 1; i < experts_per_thread; i++) {
+            const uint expert = gl_LocalInvocationID.x + i * WARP_SIZE;
+            if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) {
+                max_val    = wt[i];
+                max_expert = expert;
+            }
+        }
+
+        [[unroll]]
+        for (uint mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
+            const float val    = subgroupShuffleXor(max_val, mask);
+            const uint  expert = subgroupShuffleXor(max_expert, mask);
+            if (val > max_val || (val == max_val && expert < max_expert)) {
+                max_val    = val;
+                max_expert = expert;
+            }
+        }
+
+        if ((k & (WARP_SIZE - 1)) == gl_LocalInvocationID.x) {
+            output_weights[k / WARP_SIZE] = max_val;
+        }
+
+        if ((max_expert & (WARP_SIZE - 1)) == gl_LocalInvocationID.x) {
+            wt[max_expert / WARP_SIZE] = -INFINITY;
+
+            ids[ids_offset + k] = max_expert;
+            if (with_norm) {
+                wt_sum += max_val;
+            }
+        }
+    }
+
+    if (with_norm) {
+        wt_sum              = subgroupAdd(wt_sum);
+        const float inv_sum = 1.0f / wt_sum;
+
+        [[unroll]]
+        for (uint i = 0; i < experts_per_thread; ++i) {
+            output_weights[i] *= inv_sum;
+        }
+    }
+
+    [[unroll]]
+    for (uint i = 0; i < experts_per_thread; ++i) {
+        uint idx = i * WARP_SIZE + gl_LocalInvocationID.x;
+        if (idx < n_expert_used) {
+            weights[weights_offset + idx] = output_weights[i];
+        }
+    }
+}
index 1d04a812a038a3f5f88f1bf19b1783fa272bcf90..49bf6c764f726efd1d86dffc20395795f579a34e 100644 (file)
@@ -920,6 +920,8 @@ void process_shaders() {
 
     string_to_spv("ssm_conv_f32", "ssm_conv.comp", {{"A_TYPE", "float"}});
 
+    string_to_spv("topk_moe_f32", "topk_moe.comp", {});
+
     for (auto &c : compiles) {
         c.wait();
     }