]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
vulkan: Update topk_moe fusion to handle gpt's late softmax (#16656)
authorJeff Bolz <redacted>
Wed, 29 Oct 2025 13:44:29 +0000 (08:44 -0500)
committerGitHub <redacted>
Wed, 29 Oct 2025 13:44:29 +0000 (14:44 +0100)
* vulkan: Update topk_moe fusion to handle gpt's late softmax

Based on #16649.

* Add ggml_check_edges

* Add sync logging to show fusion effects

* handle clamp added in #16655

* Update ggml/src/ggml-impl.h

Co-authored-by: Diego Devesa <redacted>
ggml/src/ggml-impl.h
ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp

index e9201cdc685dcd5efc8aa1b0554acbe0d82a1a3d..ec37a25337b649a94d12fc456beef1cf17de4702 100644 (file)
@@ -682,6 +682,7 @@ static inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph * cgraph,
 #endif
 
 #ifdef __cplusplus
+#include <array>
 #include <initializer_list>
 #include <vector>
 
@@ -697,6 +698,21 @@ inline bool ggml_can_fuse_subgraph(const struct ggml_cgraph *          cgraph,
     return ggml_can_fuse_subgraph(cgraph, start_idx, ops.size(), ops.begin(), outputs.begin(), outputs.size());
 }
 
+// Return true if the edges in the graph match expectations.
+inline bool ggml_check_edges(const struct ggml_cgraph *                cgraph,
+                             int                                       start_idx,
+                             std::initializer_list<std::array<int, 3>> edges) {
+    for (const auto & edge : edges) {
+        int dst_node = edge[0];
+        int src_idx  = edge[1];
+        int src_node = edge[2];
+        if (cgraph->nodes[start_idx + dst_node]->src[src_idx] != cgraph->nodes[start_idx + src_node]) {
+            return false;
+        }
+    }
+    return true;
+}
+
 // expose GGUF internals for test code
 GGML_API size_t gguf_type_size(enum gguf_type type);
 GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params);
index 3d10aa07b0885c9a55d248f21d11df3df4748d9a..50e7922dc604cc720baf9a26d545ddcb6a8ec964 100644 (file)
@@ -385,12 +385,76 @@ static constexpr uint32_t num_argsort_pipelines = 11;
 static constexpr uint32_t max_argsort_cols = 1 << (num_argsort_pipelines-1);
 static constexpr uint32_t num_topk_moe_pipelines = 10;
 
-static constexpr std::array topk_moe_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE,  GGML_OP_ARGSORT,
-                                           GGML_OP_VIEW,     GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
-                                           GGML_OP_SUM_ROWS, GGML_OP_DIV,      GGML_OP_RESHAPE };
-static constexpr std::array topk_moe     { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE,  GGML_OP_ARGSORT,
-                                           GGML_OP_VIEW,     GGML_OP_GET_ROWS };
+static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax_norm{ GGML_OP_SOFT_MAX, GGML_OP_RESHAPE,  GGML_OP_ARGSORT,
+                                                                             GGML_OP_VIEW,     GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
+                                                                             GGML_OP_SUM_ROWS, GGML_OP_CLAMP,    GGML_OP_DIV,
+                                                                             GGML_OP_RESHAPE };
+static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax     { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE,  GGML_OP_ARGSORT,
+                                                                             GGML_OP_VIEW,     GGML_OP_GET_ROWS };
+static constexpr std::initializer_list<ggml_op> topk_moe_late_softmax      { GGML_OP_ARGSORT,  GGML_OP_VIEW,
+                                                                             GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
+                                                                             GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
+
+//node #978 (  SOFT_MAX):     ffn_moe_probs-15 (   0K) [Vulka         ] use=2:    ffn_moe_logits-15 (   0K) [Vulka         ]
+//node #979 (   RESHAPE): ffn_moe_probs-15 (re (   0K) [Vulka         ] use=1:     ffn_moe_probs-15 (   0K) [Vulka         ]
+//node #980 (   ARGSORT):   ffn_moe_argsort-15 (   0K) [Vulka         ] use=1:     ffn_moe_probs-15 (   0K) [Vulka         ]
+//node #981 (      VIEW):      ffn_moe_topk-15 (   0K) [Vulka         ] use=4:   ffn_moe_argsort-15 (   0K) [Vulka         ]
+//node #982 (  GET_ROWS):   ffn_moe_weights-15 (   0K) [Vulka         ] use=1: ffn_moe_probs-15 (re (   0K) [Vulka         ]      ffn_moe_topk-15 (   0K) [Vulka         ]
+//node #983 (   RESHAPE): ffn_moe_weights-15 ( (   0K) [Vulka         ] use=2:   ffn_moe_weights-15 (   0K) [Vulka         ]
+//node #984 (  SUM_ROWS): ffn_moe_weights_sum- (   0K) [Vulka         ] use=1: ffn_moe_weights-15 ( (   0K) [Vulka         ]
+//node #985 (     CLAMP): ffn_moe_weights_sum_ (   0K) [Vulka         ] use=1: ffn_moe_weights_sum- (   0K) [Vulka         ]
+//node #986 (       DIV): ffn_moe_weights_norm (   0K) [Vulka         ] use=1: ffn_moe_weights-15 ( (   0K) [Vulka         ] ffn_moe_weights_sum_ (   0K) [Vulka         ]
+//node #987 (   RESHAPE): ffn_moe_weights_norm (   0K) [Vulka         ] use=1: ffn_moe_weights_norm (   0K) [Vulka         ]
+static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softmax_norm_edges {
+    { 1, 0, 0 }, // reshape->src[0]  == softmax
+    { 2, 0, 0 }, // argsort->src[0]  == softmax
+    { 3, 0, 2 }, // view->src[0]     == argsort
+    { 4, 0, 1 }, // get_rows->src[0] == reshape
+    { 4, 1, 3 }, // get_rows->src[1] == view
+    { 5, 0, 4 }, // reshape->src[0]  == get_rows
+    { 6, 0, 5 }, // sum_rows->src[0] == reshape
+    { 7, 0, 6 }, // clamp->src[0]    == sum_rows
+    { 8, 0, 5 }, // div->src[0]      == reshape
+    { 8, 1, 7 }, // div->src[1]      == clamp
+    { 9, 0, 8 }, // reshape->src[0]  == div
+};
+
+// same as early_softmax_norm but ending after the get_rows
+static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softmax_edges {
+    { 1, 0, 0 }, // reshape->src[0]  == softmax
+    { 2, 0, 0 }, // argsort->src[0]  == softmax
+    { 3, 0, 2 }, // view->src[0]     == argsort
+    { 4, 0, 1 }, // get_rows->src[0] == reshape
+    { 4, 1, 3 }, // get_rows->src[1] == view
+};
 
+//node #652 (   ARGSORT):   ffn_moe_argsort-11 (   0K) [Vulka         ] use=1:     ffn_moe_probs-11 (   0K) [Vulka         ]
+//node #653 (      VIEW):      ffn_moe_topk-11 (   0K) [Vulka         ] use=7:   ffn_moe_argsort-11 (   0K) [Vulka         ]
+//node #654 (  GET_ROWS):   ffn_moe_weights-11 (   0K) [Vulka         ] use=1: ffn_moe_probs-11 (re (   0K) [Vulka         ]      ffn_moe_topk-11 (   0K) [Vulka         ]
+//node #655 (   RESHAPE): ffn_moe_weights-11 ( (   0K) [Vulka         ] use=1:   ffn_moe_weights-11 (   0K) [Vulka         ]
+//node #656 (  SOFT_MAX):             node_656 (   0K) [Vulka         ] use=1: ffn_moe_weights-11 ( (   0K) [Vulka         ]
+//node #657 (   RESHAPE): ffn_moe_weights_soft (   0K) [Vulka         ] use=1:             node_656 (   0K) [Vulka         ]
+static constexpr std::initializer_list<std::array<int, 3>> topk_moe_late_softmax_edges {
+    { 1, 0, 0 }, // view->src[0]     == argsort
+    { 2, 1, 1 }, // get_rows->src[1] == view
+    { 3, 0, 2 }, // reshape->src[0]  == get_rows
+    { 4, 0, 3 }, // soft_max->src[0] == reshape
+    { 5, 0, 4 }, // reshape->src[0]  == soft_max
+};
+
+enum topk_moe_mode {
+    TOPK_MOE_EARLY_SOFTMAX,
+    TOPK_MOE_EARLY_SOFTMAX_NORM,
+    TOPK_MOE_LATE_SOFTMAX,
+    TOPK_MOE_COUNT,
+};
+
+static topk_moe_mode ggml_vk_num_additional_ops_to_topk_moe_mode(uint32_t num) {
+    topk_moe_mode mode = num == topk_moe_early_softmax_norm.size() - 1 ? TOPK_MOE_EARLY_SOFTMAX_NORM :
+                         num == topk_moe_early_softmax.size() - 1      ? TOPK_MOE_EARLY_SOFTMAX :
+                                                                         TOPK_MOE_LATE_SOFTMAX;
+    return mode;
+}
 
 struct vk_device_struct {
     std::recursive_mutex mutex;
@@ -605,8 +669,7 @@ struct vk_device_struct {
 
     vk_pipeline pipeline_flash_attn_split_k_reduce;
 
-    // [2] is {!norm, norm}
-    vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][2];
+    vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][TOPK_MOE_COUNT];
 
     std::vector<vk_pipeline_ref> all_pipelines;
 
@@ -954,6 +1017,8 @@ static_assert(sizeof(vk_op_multi_add_push_constants) <= 256);
 struct vk_op_topk_moe_push_constants {
     uint32_t n_rows;
     uint32_t n_expert_used;
+    float clamp_min;
+    float clamp_max;
 };
 
 struct vk_op_add_id_push_constants {
@@ -3804,8 +3869,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f16_f32, "conv2d_dw_cwhn_f16_f32", conv2d_dw_cwhn_f16_f32_len, conv2d_dw_cwhn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
 
     for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) {
-        ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][0], "topk_moe_f32_"+std::to_string(i),   topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0}, 1, true, true);
-        ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][1], "topk_moe_f32_"+std::to_string(i),   topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 1}, 1, true, true);
+        ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX],      "topk_moe_f32_early_softmax_"+std::to_string(i),       topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 0}, 1, true, true);
+        ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX_NORM], "topk_moe_f32_early_softmax_norm"+std::to_string(i),   topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 1, 0}, 1, true, true);
+        ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_LATE_SOFTMAX],       "topk_moe_f32_late_softmax"+std::to_string(i),         topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 1}, 1, true, true);
     }
 
     for (auto &c : compiles) {
@@ -8083,8 +8149,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         if (ctx->num_additional_fused_ops) {
             uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
             GGML_ASSERT(idx < num_topk_moe_pipelines);
-            bool with_norm = ctx->num_additional_fused_ops == topk_moe_norm.size() - 1;
-            return ctx->device->pipeline_topk_moe[idx][with_norm];
+            topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
+            return ctx->device->pipeline_topk_moe[idx][mode];
         }
 
         if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
@@ -8139,6 +8205,13 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return nullptr;
         }
     case GGML_OP_ARGSORT:
+        if (ctx->num_additional_fused_ops) {
+            uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
+            GGML_ASSERT(idx < num_topk_moe_pipelines);
+            topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
+            return ctx->device->pipeline_topk_moe[idx][mode];
+        }
+
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
             uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
             return ctx->device->pipeline_argsort_f32[idx];
@@ -9678,10 +9751,12 @@ static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& sub
 
 static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx, bool dryrun = false) {
 
-    bool with_norm = ctx->num_additional_fused_ops == topk_moe_norm.size() - 1;
+    topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
     ggml_tensor * logits = cgraph->nodes[node_idx + 0]->src[0];
-    ggml_tensor * weights = with_norm ? cgraph->nodes[node_idx + 8] : cgraph->nodes[node_idx + 4];
-    ggml_tensor * ids = cgraph->nodes[node_idx + 3];
+    ggml_tensor * weights = (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) ? cgraph->nodes[node_idx + 9] :
+                            (mode == TOPK_MOE_EARLY_SOFTMAX)      ? cgraph->nodes[node_idx + 4] :
+                                                                    cgraph->nodes[node_idx + 5];
+    ggml_tensor * ids = (mode == TOPK_MOE_LATE_SOFTMAX) ? cgraph->nodes[node_idx + 1] : cgraph->nodes[node_idx + 3];
 
     GGML_ASSERT(logits->type == GGML_TYPE_F32);
     GGML_ASSERT(weights->type == GGML_TYPE_F32);
@@ -9740,9 +9815,14 @@ static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx,
         GGML_ASSERT(d_ids != nullptr);
     }
 
-    vk_op_topk_moe_push_constants pc;
+    vk_op_topk_moe_push_constants pc {};
     pc.n_rows = n_rows;
     pc.n_expert_used = n_expert_used;
+    if (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) {
+        ggml_tensor * clamp = cgraph->nodes[node_idx + 7];
+        pc.clamp_min = ggml_get_op_params_f32(clamp, 0);
+        pc.clamp_max = ggml_get_op_params_f32(clamp, 1);
+    }
 
     GGML_ASSERT(n_expert_used <= n_experts);
 
@@ -11337,7 +11417,13 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
                 }
             }
         }
+
+#define ENABLE_SYNC_LOGGING 0
+
         if (need_sync) {
+#if ENABLE_SYNC_LOGGING
+            std::cerr <<  "sync" << std::endl;
+#endif
             ctx->unsynced_nodes_written.clear();
             ctx->unsynced_nodes_read.clear();
             ggml_vk_sync_buffers(ctx, compute_ctx);
@@ -11355,6 +11441,18 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
             }
         }
     }
+#if ENABLE_SYNC_LOGGING
+    if (!dryrun) {
+        for (int i = 0; i < ctx->num_additional_fused_ops + 1; ++i) {
+            auto *n = cgraph->nodes[node_idx + i];
+            std::cerr << node_idx + i << " " << ggml_op_name(n->op) << " " <<  n->name;
+            if (n->op == GGML_OP_GLU) {
+                std::cerr << " " << ggml_glu_op_name(ggml_get_glu_op(n)) << " " << (n->src[1] ? "split" : "single") << " ";
+            }
+            std::cerr << std::endl;
+        }
+    }
+#endif
 
     switch (node->op) {
     case GGML_OP_REPEAT:
@@ -11533,7 +11631,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
 
         break;
     case GGML_OP_ARGSORT:
-        ggml_vk_argsort(ctx, compute_ctx, src0, node, dryrun);
+        if (ctx->num_additional_fused_ops) {
+            ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx, dryrun);
+        } else {
+            ggml_vk_argsort(ctx, compute_ctx, src0, node, dryrun);
+        }
 
         break;
     case GGML_OP_SUM:
@@ -12306,30 +12408,27 @@ static bool ggml_vk_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, st
 }
 
 static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph,
-                                      int node_idx, bool with_norm) {
+                                      int node_idx, topk_moe_mode mode) {
 
-    if (with_norm) {
-        if (node_idx + (int)topk_moe_norm.size() > cgraph->n_nodes) {
-            return false;
-        }
-        for (size_t i = 0; i < topk_moe_norm.size(); ++i) {
-            if (cgraph->nodes[node_idx + i]->op != topk_moe_norm[i]) {
-                return false;
-            }
-        }
-    } else {
-        if (node_idx + (int)topk_moe.size() > cgraph->n_nodes) {
-            return false;
-        }
-        for (size_t i = 0; i < topk_moe.size(); ++i) {
-            if (cgraph->nodes[node_idx + i]->op != topk_moe[i]) {
-                return false;
-            }
-        }
-    }
+    const ggml_tensor * softmax;
+    const ggml_tensor * weights;
 
-    const ggml_tensor * softmax =  cgraph->nodes[node_idx + 0];
-    const ggml_tensor * weights = with_norm ? cgraph->nodes[node_idx + 8] : cgraph->nodes[node_idx + 4];
+    switch (mode) {
+    case TOPK_MOE_EARLY_SOFTMAX_NORM:
+        softmax = cgraph->nodes[node_idx + 0];
+        weights = cgraph->nodes[node_idx + 9];
+        break;
+    case TOPK_MOE_EARLY_SOFTMAX:
+        softmax = cgraph->nodes[node_idx + 0];
+        weights = cgraph->nodes[node_idx + 4];
+        break;
+    case TOPK_MOE_LATE_SOFTMAX:
+        softmax = cgraph->nodes[node_idx + 4];
+        weights = cgraph->nodes[node_idx + 5];
+        break;
+    default:
+        return false;
+    }
 
     const float * op_params = (const float *)softmax->op_params;
 
@@ -12355,60 +12454,6 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc
         return false;
     }
 
-    // Check that the nodes don't have any unexpected uses
-    const ggml_tensor * reshape1 =  cgraph->nodes[node_idx + 1];
-    const ggml_tensor * argsort =   cgraph->nodes[node_idx + 2];
-    const ggml_tensor * view =      cgraph->nodes[node_idx + 3];
-    const ggml_tensor * get_rows =  cgraph->nodes[node_idx + 4];
-    const ggml_tensor * reshape5 =  with_norm ? cgraph->nodes[node_idx + 5] : nullptr;
-    const ggml_tensor * sum_rows =  with_norm ? cgraph->nodes[node_idx + 6] : nullptr;
-    const ggml_tensor * div =       with_norm ? cgraph->nodes[node_idx + 7] : nullptr;
-    const ggml_tensor * reshape8 =  with_norm ? cgraph->nodes[node_idx + 8] : nullptr;
-
-    // softmax is used by reshape and argsort
-    if (ggml_node_get_use_count(cgraph, node_idx) != 2 ||
-        reshape1->src[0] != softmax ||
-        argsort->src[0] != softmax) {
-        return false;
-    }
-    // reshape is used by get_rows
-    if (ggml_node_get_use_count(cgraph, node_idx + 1) != 1 ||
-        get_rows->src[0] != reshape1) {
-        return false;
-    }
-    // argsort is used by view
-    if (ggml_node_get_use_count(cgraph, node_idx + 2) != 1 ||
-        view->src[0] != argsort) {
-        return false;
-    }
-    // view is written (via argsort), we can skip checking it
-
-    if (with_norm) {
-        // get_rows is used by reshape
-        if (ggml_node_get_use_count(cgraph, node_idx + 4) != 1 ||
-            reshape5->src[0] != get_rows) {
-            return false;
-        }
-
-        // reshape is used by sum_rows and div
-        if (ggml_node_get_use_count(cgraph, node_idx + 5) != 2 ||
-            sum_rows->src[0] != reshape5 ||
-            div->src[0] != reshape5) {
-            return false;
-        }
-
-        // sum_rows is used by div
-        if (ggml_node_get_use_count(cgraph, node_idx + 6) != 1 ||
-            div->src[1] != sum_rows) {
-            return false;
-        }
-
-        // div/reshape are written
-        if (reshape8->src[0] != div) {
-            return false;
-        }
-    }
-
     if (!ctx->device->subgroup_arithmetic ||
         !ctx->device->subgroup_shuffle ||
         !ctx->device->subgroup_require_full_support ||
@@ -12494,10 +12539,18 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
                 ctx->num_additional_fused_ops = num_adds - 1;
             } else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
                 ctx->num_additional_fused_ops = 1;
-            } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, true)) {
-                ctx->num_additional_fused_ops = topk_moe_norm.size() - 1;
-            } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, false)) {
-                ctx->num_additional_fused_ops = topk_moe.size() - 1;
+            } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) &&
+                       ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) &&
+                       ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) {
+                ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1;
+            } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) &&
+                       ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) &&
+                       ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) {
+                ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1;
+            } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) &&
+                       ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) &&
+                       ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) {
+                ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1;
             }
         }
         ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
@@ -12595,10 +12648,18 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
                 ctx->num_additional_fused_ops = num_adds - 1;
             } else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
                 ctx->num_additional_fused_ops = 1;
-            } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, true)) {
-                ctx->num_additional_fused_ops = topk_moe_norm.size() - 1;
-            } else if (ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, false)) {
-                ctx->num_additional_fused_ops = topk_moe.size() - 1;
+            } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) &&
+                       ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) &&
+                       ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) {
+                ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1;
+            } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) &&
+                       ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) &&
+                       ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) {
+                ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1;
+            } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) &&
+                       ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) &&
+                       ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) {
+                ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1;
             }
         }
 
@@ -12730,25 +12791,44 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
     while (first_unused < graph->n_nodes) {
         std::vector<int> current_set;
 
-        // Avoid reordering topk_moe_norm
-        if (first_unused + (int)topk_moe_norm.size() <= graph->n_nodes) {
-            bool is_topk_moe_norm = true;
-            for (size_t j = 0; j < topk_moe_norm.size(); ++j) {
-                if (graph->nodes[first_unused + j]->op != topk_moe_norm[j] || used[first_unused + j]) {
-                    is_topk_moe_norm = false;
+        // Check for fusion patterns and avoid reordering them
+        auto const &match_pattern = [&](const std::initializer_list<ggml_op> &pattern, int start) -> bool {
+            if (start + (int)pattern.size() <= graph->n_nodes) {
+                bool is_pattern = true;
+                for (size_t j = 0; j < pattern.size(); ++j) {
+                    if (graph->nodes[start + j]->op != pattern.begin()[j] || used[start + j]) {
+                        is_pattern = false;
+                    }
                 }
+                return is_pattern;
             }
-            if (is_topk_moe_norm) {
-                for (size_t j = 0; j < topk_moe_norm.size(); ++j) {
+            return false;
+        };
+
+        auto const &keep_pattern = [&](const std::initializer_list<ggml_op> &pattern) -> bool {
+            if (match_pattern(pattern, first_unused)) {
+                for (size_t j = 0; j < pattern.size(); ++j) {
                     new_order.push_back(graph->nodes[first_unused + j]);
                     used[first_unused + j] = true;
                 }
                 while (first_unused < graph->n_nodes && used[first_unused]) {
                     first_unused++;
                 }
-                continue;
+                return true;
             }
+            return false;
+        };
+
+        if (keep_pattern(topk_moe_early_softmax_norm)) {
+            continue;
+        }
+        if (keep_pattern(topk_moe_early_softmax)) {
+            continue;
         }
+        if (keep_pattern(topk_moe_late_softmax)) {
+            continue;
+        }
+
         // First, grab the next unused node.
         current_set.push_back(first_unused);
 
@@ -12766,6 +12846,12 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
             if (is_empty(graph->nodes[j])) {
                 continue;
             }
+            // Don't pull forward nodes from fusion patterns
+            if (match_pattern(topk_moe_early_softmax_norm, j) ||
+                match_pattern(topk_moe_early_softmax, j) ||
+                match_pattern(topk_moe_late_softmax, j)) {
+                continue;
+            }
             bool ok = true;
             for (int c = first_unused; c < j; ++c) {
                 if (!used[c] &&
index 9e56d5f8a3cc1924be9d2f88fbece875e3fc2a5c..bc1c278bf49cd589e4ecc8881ab2ecbd842a381d 100644 (file)
@@ -11,6 +11,8 @@ layout (push_constant) uniform parameter
 {
     uint n_rows;
     uint n_expert_used;
+    float clamp_min;
+    float clamp_max;
 };
 
 layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;
@@ -18,6 +20,7 @@ layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;
 layout(constant_id = 0) const uint WARP_SIZE = 32;
 layout(constant_id = 1) const uint n_experts = 512;
 layout(constant_id = 2) const bool with_norm = true;
+layout(constant_id = 3) const bool late_softmax = false;
 
 const uint experts_per_thread = (n_experts > WARP_SIZE) ? n_experts / WARP_SIZE : 1;
 
@@ -25,53 +28,72 @@ layout (binding = 0, std430) readonly buffer Logits {float logits[];};
 layout (binding = 1, std430) writeonly buffer Weights {float weights[];};
 layout (binding = 2, std430) writeonly buffer Ids {uint ids[];};
 
-void main() {
-    const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_LocalInvocationID.y;
-    if (row >= n_rows) {
-        return;
-    }
+const float INFINITY = 1.0 / 0.0;
 
-    const uint logits_offset = n_experts * row;
-    const uint weights_offset = n_expert_used * row;
-    const uint ids_offset = n_experts * row;
-
-    float logits_r[experts_per_thread];
-
-    const float INFINITY = 1.0 / 0.0;
+// Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.
+void softmax_warp_inplace(inout float vals[experts_per_thread], const uint limit, const uint lane, const bool use_limit) {
+    float max_val = -INFINITY;
 
     [[unroll]]
-    for (uint i = 0; i < n_experts; i += WARP_SIZE) {
-        const uint expert        = i + gl_LocalInvocationID.x;
-        logits_r[i / WARP_SIZE] = n_experts % WARP_SIZE == 0 || expert < n_experts ? logits[logits_offset + expert] : -INFINITY;
+    for (int i = 0; i < experts_per_thread; i++) {
+        const uint idx       = lane + i * WARP_SIZE;
+        const bool is_active = !use_limit || (idx < limit);
+        if (is_active) {
+            max_val = max(max_val, vals[i]);
+        }
     }
 
-    float max_val = logits_r[0];
+    max_val = subgroupMax(max_val);
+
+    float sum = 0.f;
 
     [[unroll]]
-    for (int i = 1; i < experts_per_thread; i++) {
-        const float val = logits_r[i];
-        max_val         = max(val, max_val);
+    for (int i = 0; i < experts_per_thread; i++) {
+        const uint idx       = lane + i * WARP_SIZE;
+        const bool is_active = !use_limit || (idx < limit);
+        if (is_active) {
+            const float val = exp(vals[i] - max_val);
+            vals[i]         = val;
+            sum += val;
+        } else {
+            vals[i] = 0.f;
+        }
     }
 
-    max_val = subgroupMax(max_val);
+    sum = subgroupAdd(sum);
 
-    float wt[experts_per_thread];
-    float tmp = 0.f;
+    const float inv_sum = 1.0f / sum;
 
     [[unroll]]
     for (int i = 0; i < experts_per_thread; i++) {
-        const float val = logits_r[i];
-        wt[i]           = exp(val - max_val);
-        tmp += wt[i];
+        const uint idx       = lane + i * WARP_SIZE;
+        const bool is_active = !use_limit || (idx < limit);
+        if (is_active) {
+            vals[i] *= inv_sum;
+        }
     }
+}
 
-    tmp = subgroupAdd(tmp);
+void main() {
+    const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_LocalInvocationID.y;
+    if (row >= n_rows) {
+        return;
+    }
 
-    const float inv_sum = 1.0f / tmp;
+    const uint logits_offset = n_experts * row;
+    const uint weights_offset = n_expert_used * row;
+    const uint ids_offset = n_experts * row;
+
+    float wt[experts_per_thread];
 
     [[unroll]]
-    for (int i = 0; i < experts_per_thread; i++) {
-        wt[i] = wt[i] * inv_sum;
+    for (uint i = 0; i < n_experts; i += WARP_SIZE) {
+        const uint expert = i + gl_LocalInvocationID.x;
+        wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[logits_offset + expert] : -INFINITY;
+    }
+
+    if (!late_softmax) {
+        softmax_warp_inplace(wt, n_experts, gl_LocalInvocationID.x, false);
     }
 
     // at this point, each thread holds a portion of softmax,
@@ -82,6 +104,11 @@ void main() {
 
     float output_weights[experts_per_thread];
 
+    [[unroll]]
+    for (int i = 0; i < experts_per_thread; i++) {
+        output_weights[i] = 0.f;
+    }
+
     for (int k = 0; k < n_expert_used; k++) {
         float max_val    = wt[0];
         uint   max_expert = gl_LocalInvocationID.x;
@@ -121,6 +148,7 @@ void main() {
 
     if (with_norm) {
         wt_sum              = subgroupAdd(wt_sum);
+        wt_sum              = clamp(wt_sum, clamp_min, clamp_max);
         const float inv_sum = 1.0f / wt_sum;
 
         [[unroll]]
@@ -129,6 +157,10 @@ void main() {
         }
     }
 
+    if (late_softmax) {
+        softmax_warp_inplace(output_weights, n_expert_used, gl_LocalInvocationID.x, true);
+    }
+
     [[unroll]]
     for (uint i = 0; i < experts_per_thread; ++i) {
         uint idx = i * WARP_SIZE + gl_LocalInvocationID.x;