]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
vulkan: extend topk_moe to handle sigmoid w/exp_probs_b for nemotron (llama/18295)
authorJeff Bolz <redacted>
Thu, 1 Jan 2026 07:58:27 +0000 (01:58 -0600)
committerGeorgi Gerganov <redacted>
Wed, 14 Jan 2026 07:11:59 +0000 (09:11 +0200)
* vulkan: extend topk_moe to handle sigmoid w/exp_probs_b for nemotron

Also handle GGML_OP_SCALE at the end (nemotron, deepseek2).

Fewer pipeline variants and spec constants, just use push constants.

In test_topk_moe, change exp_probs_b to be 1D, matching real networks.

Update test-backend-ops and ggml-backend to allow verifying multiple outputs
in a fusion test (topk_moe has two outputs). Previously only the final node
was verified.

* change test_topk_moe to allow results in arbitrary order

* disable sigmoid fusion for moltenvk

ggml/include/ggml-backend.h
ggml/src/ggml-backend.cpp
ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp

index 4ed5f35774ffcbb462cab7d06e28d3ff7d2a49b7..a9d1778641e8ca2c0da88f158dea4ed1787b2cb0 100644 (file)
@@ -358,7 +358,7 @@ extern "C" {
     typedef bool (*ggml_backend_eval_callback)(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data);
 
     // Compare the output of two backends
-    GGML_API bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor * test_node);
+    GGML_API bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor const * const * test_nodes, size_t num_test_nodes);
 
     // Tensor initialization
     GGML_API enum ggml_status ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr);
index 8547ecc849c6fea7c013a332c3fac9ef25bdb4f7..1b59924b8cb67c2b749d84b444ccf7b0031ef5ac 100644 (file)
@@ -2053,7 +2053,7 @@ void ggml_backend_graph_copy_free(struct ggml_backend_graph_copy copy) {
     ggml_free(copy.ctx_unallocated);
 }
 
-bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor * test_node) {
+bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor const * const * test_nodes, size_t num_test_nodes) {
     struct ggml_backend_graph_copy copy = ggml_backend_graph_copy(backend2, graph);
     if (copy.buffer == NULL) {
         return false;
@@ -2064,22 +2064,22 @@ bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t
 
     assert(g1->n_nodes == g2->n_nodes);
 
-    if (test_node != nullptr) {
-        // Compute the whole graph and only test the output for a specific tensor
+    if (num_test_nodes != 0) {
+        GGML_ASSERT(test_nodes);
+        // Compute the whole graph and only test the output for specific tensors
         ggml_backend_graph_compute(backend1, g1);
         ggml_backend_graph_compute(backend2, g2);
 
-        int test_node_idx = -1;
+        bool verified = false;
         for (int i = 0; i < g1->n_nodes; i++) {
-            struct ggml_tensor * t1 = g1->nodes[i];
-            if (t1 == test_node) {
-                test_node_idx = i;
-                break;
+            for (size_t j = 0; j < num_test_nodes; ++j) {
+                if (g1->nodes[i] == test_nodes[j]) {
+                    callback(i, g1->nodes[i], g2->nodes[i], user_data);
+                    verified = true;
+                }
             }
         }
-        GGML_ASSERT(test_node_idx != -1);
-
-        callback(test_node_idx, g1->nodes[test_node_idx], g2->nodes[test_node_idx], user_data);
+        GGML_ASSERT(verified);
     } else {
         for (int i = 0; i < g1->n_nodes; i++) {
             struct ggml_tensor * t1 = g1->nodes[i];
index 493ee9c9a44f78e4487a59389f61b6fd8c8d35cf..541e4a50b705bb6a1274de548a85811d4c32f18d 100644 (file)
@@ -434,8 +434,15 @@ static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax_norm{ GGM
                                                                              GGML_OP_VIEW,     GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
                                                                              GGML_OP_SUM_ROWS, GGML_OP_CLAMP,    GGML_OP_DIV,
                                                                              GGML_OP_RESHAPE };
+
+static constexpr std::initializer_list<ggml_op> topk_moe_sigmoid_norm_bias{ GGML_OP_UNARY,    GGML_OP_RESHAPE,  GGML_OP_ADD,
+                                                                            GGML_OP_ARGSORT,  GGML_OP_VIEW,     GGML_OP_GET_ROWS,
+                                                                            GGML_OP_RESHAPE,  GGML_OP_SUM_ROWS, GGML_OP_CLAMP,
+                                                                            GGML_OP_DIV,      GGML_OP_RESHAPE };
+
 static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax     { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE,  GGML_OP_ARGSORT,
                                                                              GGML_OP_VIEW,     GGML_OP_GET_ROWS };
+
 static constexpr std::initializer_list<ggml_op> topk_moe_late_softmax      { GGML_OP_ARGSORT,  GGML_OP_VIEW,
                                                                              GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
                                                                              GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
@@ -464,6 +471,32 @@ static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softma
     { 9, 0, 8 }, // reshape->src[0]  == div
 };
 
+//node #436 (     UNARY):     ffn_moe_probs-10 ( 256K) [Vulka         ] use=2:    ffn_moe_logits-10 ( 256K) [Vulka         ]
+//node #437 (   RESHAPE): ffn_moe_probs-10 (re ( 256K) [Vulka         ] use=1:     ffn_moe_probs-10 ( 256K) [Vulka         ]
+//node #438 (       ADD): ffn_moe_probs_biased ( 256K) [Vulka         ] use=1:     ffn_moe_probs-10 ( 256K) [Vulka         ] blk.10.exp_probs_b.b (   0K) [Vulka         ]
+//node #439 (   ARGSORT):   ffn_moe_argsort-10 ( 256K) [Vulka         ] use=1: ffn_moe_probs_biased ( 256K) [Vulka         ]
+//node #440 (      VIEW):      ffn_moe_topk-10 ( 255K) [Vulka         ] use=3:   ffn_moe_argsort-10 ( 256K) [Vulka         ]
+//node #441 (  GET_ROWS):   ffn_moe_weights-10 (  12K) [Vulka         ] use=1: ffn_moe_probs-10 (re ( 256K) [Vulka         ]      ffn_moe_topk-10 ( 255K) [Vulka         ]
+//node #442 (   RESHAPE): ffn_moe_weights-10 ( (  12K) [Vulka         ] use=2:   ffn_moe_weights-10 (  12K) [Vulka         ]
+//node #443 (  SUM_ROWS): ffn_moe_weights_sum- (   2K) [Vulka         ] use=1: ffn_moe_weights-10 ( (  12K) [Vulka         ]
+//node #444 (     CLAMP): ffn_moe_weights_sum_ (   2K) [Vulka         ] use=1: ffn_moe_weights_sum- (   2K) [Vulka         ]
+//node #445 (       DIV): ffn_moe_weights_norm (  12K) [Vulka         ] use=1: ffn_moe_weights-10 ( (  12K) [Vulka         ] ffn_moe_weights_sum_ (   2K) [Vulka         ]
+//node #446 (   RESHAPE): ffn_moe_weights_norm (  12K) [Vulka         ] use=1: ffn_moe_weights_norm (  12K) [Vulka         ]
+static constexpr std::initializer_list<std::array<int, 3>> topk_moe_sigmoid_norm_bias_edges {
+    { 1, 0, 0 }, // reshape->src[0]  == sigmoid
+    { 2, 0, 0 }, // add->src[0]      == sigmoid
+    { 3, 0, 2 }, // argsort->src[0]  == add
+    { 4, 0, 3 }, // view->src[0]     == argsort
+    { 5, 0, 1 }, // get_rows->src[0] == reshape
+    { 5, 1, 4 }, // get_rows->src[1] == view
+    { 6, 0, 5 }, // reshape->src[0]  == get_rows
+    { 7, 0, 6 }, // sum_rows->src[0] == reshape
+    { 8, 0, 7 }, // clamp->src[0]    == sum_rows
+    { 9, 0, 6 }, // div->src[0]      == reshape
+    { 9, 1, 8 }, // div->src[1]      == clamp
+    {10, 0, 9 }, // reshape->src[0]  == div
+};
+
 // same as early_softmax_norm but ending after the get_rows
 static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softmax_edges {
     { 1, 0, 0 }, // reshape->src[0]  == softmax
@@ -491,16 +524,10 @@ enum topk_moe_mode {
     TOPK_MOE_EARLY_SOFTMAX,
     TOPK_MOE_EARLY_SOFTMAX_NORM,
     TOPK_MOE_LATE_SOFTMAX,
+    TOPK_MOE_SIGMOID_NORM_BIAS,
     TOPK_MOE_COUNT,
 };
 
-static topk_moe_mode ggml_vk_num_additional_ops_to_topk_moe_mode(uint32_t num) {
-    topk_moe_mode mode = num == topk_moe_early_softmax_norm.size() - 1 ? TOPK_MOE_EARLY_SOFTMAX_NORM :
-                         num == topk_moe_early_softmax.size() - 1      ? TOPK_MOE_EARLY_SOFTMAX :
-                                                                         TOPK_MOE_LATE_SOFTMAX;
-    return mode;
-}
-
 static constexpr std::initializer_list<std::array<int, 3>> rope_view_set_rows_edges {
     { 1, 0, 0 }, // view->src[0]     == rope
     { 2, 0, 1 }, // set_rows->src[0] == view
@@ -766,7 +793,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_count_experts;
 
     // [2] is for whether to take n_experts from spec constant (0) or push constant (1)
-    vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][TOPK_MOE_COUNT][2];
+    vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][2];
 
     std::vector<vk_pipeline_ref> all_pipelines;
 
@@ -1181,6 +1208,11 @@ struct vk_op_topk_moe_push_constants {
     uint32_t n_expert_used;
     float clamp_min;
     float clamp_max;
+    uint32_t gating_func;
+    uint32_t has_bias;
+    uint32_t with_norm;
+    float output_scale;
+    float output_bias;
 };
 
 struct vk_op_add_id_push_constants {
@@ -1771,6 +1803,8 @@ struct ggml_backend_vk_context {
     // Bit 'i' means nodes[start_of_fusion + i] writes to memory.
     // If there's no fusion, bit 0 is still set.
     int fused_ops_write_mask {};
+    topk_moe_mode fused_topk_moe_mode {};
+    bool fused_topk_moe_scale {};
 
     // for GGML_VK_PERF_LOGGER
     std::unique_ptr<vk_perf_logger> perf_logger;
@@ -4291,9 +4325,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     for (uint32_t use_push = 0; use_push < 2; ++use_push) {
         for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) {
-            ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX][use_push],      "topk_moe_f32_early_softmax_"+std::to_string(i),       topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 0, use_push}, 1, true, true, device->subgroup_size);
-            ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX_NORM][use_push], "topk_moe_f32_early_softmax_norm"+std::to_string(i),   topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 1, 0, use_push}, 1, true, true, device->subgroup_size);
-            ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_LATE_SOFTMAX][use_push],       "topk_moe_f32_late_softmax"+std::to_string(i),         topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 1, use_push}, 1, true, true, device->subgroup_size);
+            ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][use_push], "topk_moe_f32_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 4, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, use_push}, 1, true, true, device->subgroup_size);
         }
     }
 
@@ -8684,10 +8716,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         if (ctx->num_additional_fused_ops) {
             uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
             GGML_ASSERT(idx < num_topk_moe_pipelines);
-            topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
             // use n_experts from push constant if it's not equal to the power of two spec constant
             bool use_push = dst->ne[0] != (1u << idx);
-            return ctx->device->pipeline_topk_moe[idx][mode][use_push];
+            return ctx->device->pipeline_topk_moe[idx][use_push];
         }
 
         if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
@@ -10346,14 +10377,16 @@ static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& sub
 }
 
 static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx) {
-    topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
+    topk_moe_mode mode = ctx->fused_topk_moe_mode;
     ggml_tensor * logits = cgraph->nodes[node_idx + 0]->src[0];
-    ggml_tensor * weights = (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) ? cgraph->nodes[node_idx + 9] :
-                            (mode == TOPK_MOE_EARLY_SOFTMAX)      ? cgraph->nodes[node_idx + 4] :
-                                                                    cgraph->nodes[node_idx + 5];
-    ggml_tensor * ids = (mode == TOPK_MOE_LATE_SOFTMAX) ? cgraph->nodes[node_idx + 1] : cgraph->nodes[node_idx + 3];
+    ggml_tensor * bias = (mode == TOPK_MOE_SIGMOID_NORM_BIAS) ? cgraph->nodes[node_idx + 2]->src[1] : logits;
+    ggml_tensor * weights = cgraph->nodes[node_idx + ctx->num_additional_fused_ops];
+    ggml_tensor * ids = (mode == TOPK_MOE_SIGMOID_NORM_BIAS) ? cgraph->nodes[node_idx + 4] :
+                        (mode == TOPK_MOE_LATE_SOFTMAX) ?      cgraph->nodes[node_idx + 1] :
+                                                               cgraph->nodes[node_idx + 3];
 
     GGML_ASSERT(logits->type == GGML_TYPE_F32);
+    GGML_ASSERT(bias->type == GGML_TYPE_F32);
     GGML_ASSERT(weights->type == GGML_TYPE_F32);
     GGML_ASSERT(ids->type == GGML_TYPE_I32);
 
@@ -10368,6 +10401,7 @@ static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx,
     ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
 
     vk_subbuffer logits_buf = ggml_vk_tensor_subbuffer(ctx, logits);
+    vk_subbuffer bias_buf = ggml_vk_tensor_subbuffer(ctx, bias);
     vk_subbuffer weights_buf = ggml_vk_tensor_subbuffer(ctx, weights);
     vk_subbuffer ids_buf = ggml_vk_tensor_subbuffer(ctx, ids);
 
@@ -10375,18 +10409,45 @@ static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx,
     pc.n_rows = n_rows;
     pc.n_experts_push = n_experts;
     pc.n_expert_used = n_expert_used;
+    pc.clamp_min = -std::numeric_limits<float>::infinity();
+    pc.clamp_max = std::numeric_limits<float>::infinity();
     if (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) {
         ggml_tensor * clamp = cgraph->nodes[node_idx + 7];
+        GGML_ASSERT(clamp->op == GGML_OP_CLAMP);
+        pc.clamp_min = ggml_get_op_params_f32(clamp, 0);
+        pc.clamp_max = ggml_get_op_params_f32(clamp, 1);
+    }
+    if (mode == TOPK_MOE_SIGMOID_NORM_BIAS) {
+        ggml_tensor * clamp = cgraph->nodes[node_idx + 8];
+        GGML_ASSERT(clamp->op == GGML_OP_CLAMP);
         pc.clamp_min = ggml_get_op_params_f32(clamp, 0);
         pc.clamp_max = ggml_get_op_params_f32(clamp, 1);
     }
 
+#define GATING_FUNC_SOFTMAX 0
+#define GATING_FUNC_SIGMOID 1
+#define GATING_FUNC_SOFTMAX_WEIGHT 2
+
+    pc.gating_func = mode == TOPK_MOE_SIGMOID_NORM_BIAS ? GATING_FUNC_SIGMOID :
+                     mode == TOPK_MOE_LATE_SOFTMAX ?      GATING_FUNC_SOFTMAX_WEIGHT :
+                                                          GATING_FUNC_SOFTMAX;
+    pc.has_bias = mode == TOPK_MOE_SIGMOID_NORM_BIAS;
+    pc.with_norm = mode == TOPK_MOE_EARLY_SOFTMAX_NORM || mode == TOPK_MOE_SIGMOID_NORM_BIAS;
+    if (ctx->fused_topk_moe_scale) {
+        GGML_ASSERT(weights->op == GGML_OP_SCALE);
+        pc.output_scale = ggml_get_op_params_f32(weights, 0);
+        pc.output_bias = ggml_get_op_params_f32(weights, 1);
+    } else {
+        pc.output_scale = 1.0f;
+        pc.output_bias = 0.0f;
+    }
+
     GGML_ASSERT(n_expert_used <= n_experts);
 
     const uint32_t rows_per_block = 4;
     std::array<uint32_t, 3> elements = { CEIL_DIV(n_rows, rows_per_block), 1, 1 };
 
-    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {logits_buf, weights_buf, ids_buf}, pc, elements);
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {logits_buf, bias_buf, weights_buf, ids_buf}, pc, elements);
 }
 
 static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_cgraph * cgraph, int node_idx, bool backprop) {
@@ -12128,6 +12189,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
 
         break;
     case GGML_OP_UNARY:
+        if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) {
+            ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx);
+            break;
+        }
+
         switch (ggml_get_unary_op(node)) {
         case GGML_UNARY_OP_EXP:
         case GGML_UNARY_OP_SILU:
@@ -12175,7 +12241,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
 
         break;
     case GGML_OP_SOFT_MAX:
-        if (ctx->num_additional_fused_ops) {
+        if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) {
             ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx);
         } else {
             ggml_vk_soft_max(ctx, compute_ctx, src0, src1, src2, node);
@@ -12195,7 +12261,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
 
         break;
     case GGML_OP_ARGSORT:
-        if (ctx->num_additional_fused_ops) {
+        if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) {
             ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx);
         } else {
             ggml_vk_argsort(ctx, compute_ctx, src0, node);
@@ -13048,6 +13114,24 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc
         get_rows = cgraph->nodes[node_idx + 4];
         argsort = cgraph->nodes[node_idx + 2];
         break;
+    case TOPK_MOE_SIGMOID_NORM_BIAS:
+        softmax = cgraph->nodes[node_idx + 0]; // really sigmoid
+        weights = cgraph->nodes[node_idx + 10];
+        get_rows = cgraph->nodes[node_idx + 5];
+        argsort = cgraph->nodes[node_idx + 3];
+        if (ggml_get_unary_op(softmax) != GGML_UNARY_OP_SIGMOID) {
+            return false;
+        }
+        // bias is expected to be 1D
+        if (ggml_nrows(cgraph->nodes[node_idx + 2]->src[1]) != 1 ||
+            !ggml_is_contiguous(cgraph->nodes[node_idx + 2]->src[1])) {
+            return false;
+        }
+        // sigmoid fusion seems to generate infinities on moltenvk
+        if (ctx->device->driver_id == vk::DriverId::eMoltenvk) {
+            return false;
+        }
+        break;
     case TOPK_MOE_EARLY_SOFTMAX:
         softmax = cgraph->nodes[node_idx + 0];
         weights = cgraph->nodes[node_idx + 4];
@@ -13071,26 +13155,28 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc
     probs = probs->src[0];
     ggml_tensor * selection_probs = argsort->src[0];
 
-    if (probs != selection_probs) {
+    if (probs != selection_probs && mode != TOPK_MOE_SIGMOID_NORM_BIAS) {
         return false;
     }
 
-    const float * op_params = (const float *)softmax->op_params;
-
-    float scale = op_params[0];
-    float max_bias = op_params[1];
-
     if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) {
         return false;
     }
 
-    if (scale != 1.0f || max_bias != 0.0f) {
-        return false;
-    }
+    if (softmax->op == GGML_OP_SOFT_MAX) {
+        const float * op_params = (const float *)softmax->op_params;
 
-    // don't fuse when masks or sinks are present
-    if (softmax->src[1] || softmax->src[2]) {
-        return false;
+        float scale = op_params[0];
+        float max_bias = op_params[1];
+
+        if (scale != 1.0f || max_bias != 0.0f) {
+            return false;
+        }
+
+        // don't fuse when masks or sinks are present
+        if (softmax->src[1] || softmax->src[2]) {
+            return false;
+        }
     }
 
     const int n_expert = softmax->ne[0];
@@ -13363,6 +13449,8 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
             total_mul_mat_bytes += bytes;
         }
 
+        ctx->fused_topk_moe_mode = TOPK_MOE_COUNT;
+        ctx->fused_topk_moe_scale = false;
         const char *fusion_string {};
         if (!ctx->device->disable_fusion) {
             uint32_t num_adds = ggml_vk_fuse_multi_add(ctx, cgraph, i);
@@ -13408,13 +13496,23 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
                 ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1;
                 // view of argsort writes to memory
                 ctx->fused_ops_write_mask |= 1 << 3;
+                ctx->fused_topk_moe_mode = TOPK_MOE_EARLY_SOFTMAX_NORM;
                 fusion_string = "TOPK_MOE_EARLY_SOFTMAX_NORM";
+            } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_sigmoid_norm_bias, { i + 4, i + 10 }) &&
+                       ggml_check_edges(cgraph, i, topk_moe_sigmoid_norm_bias_edges) &&
+                       ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_SIGMOID_NORM_BIAS)) {
+                ctx->num_additional_fused_ops = topk_moe_sigmoid_norm_bias.size() - 1;
+                // view of argsort writes to memory
+                ctx->fused_ops_write_mask |= 1 << 4;
+                ctx->fused_topk_moe_mode = TOPK_MOE_SIGMOID_NORM_BIAS;
+                fusion_string = "TOPK_MOE_SIGMOID_NORM_BIAS";
             } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) &&
                        ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) &&
                        ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) {
                 ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1;
                 // view of argsort writes to memory
                 ctx->fused_ops_write_mask |= 1 << 3;
+                ctx->fused_topk_moe_mode = TOPK_MOE_EARLY_SOFTMAX;
                 fusion_string = "TOPK_MOE_EARLY_SOFTMAX";
             } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) &&
                        ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) &&
@@ -13422,8 +13520,17 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
                 ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1;
                 // view of argsort writes to memory
                 ctx->fused_ops_write_mask |= 1 << 1;
+                ctx->fused_topk_moe_mode = TOPK_MOE_LATE_SOFTMAX;
                 fusion_string = "TOPK_MOE_LATE_SOFTMAX";
             }
+            if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) {
+                // Look for an additional scale op to fuse - occurs in deepseek2 and nemotron3 nano.
+                if (ggml_can_fuse_subgraph(cgraph, i + ctx->num_additional_fused_ops - 1, { GGML_OP_DIV, GGML_OP_RESHAPE, GGML_OP_SCALE }, { i + ctx->num_additional_fused_ops + 1 }) ||
+                    ggml_can_fuse_subgraph(cgraph, i + ctx->num_additional_fused_ops, { GGML_OP_GET_ROWS, GGML_OP_SCALE }, { i + ctx->num_additional_fused_ops + 1 })) {
+                    ctx->fused_topk_moe_scale = true;
+                    ctx->num_additional_fused_ops++;
+                }
+            }
         }
         ctx->fused_ops_write_mask |= 1 << ctx->num_additional_fused_ops;
 
@@ -13602,6 +13709,9 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
         if (keep_pattern(topk_moe_early_softmax_norm)) {
             continue;
         }
+        if (keep_pattern(topk_moe_sigmoid_norm_bias)) {
+            continue;
+        }
         if (keep_pattern(topk_moe_early_softmax)) {
             continue;
         }
@@ -13628,6 +13738,7 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
             }
             // Don't pull forward nodes from fusion patterns
             if (match_pattern(topk_moe_early_softmax_norm, j) ||
+                match_pattern(topk_moe_sigmoid_norm_bias, j) ||
                 match_pattern(topk_moe_early_softmax, j) ||
                 match_pattern(topk_moe_late_softmax, j)) {
                 continue;
index b83a2b9d2d4f7e8baa680cca3f673390aa9f3325..4bf6d2bcb03eb6b86918622edc5d58f6224974c2 100644 (file)
@@ -7,6 +7,10 @@
 
 #include "types.glsl"
 
+#define GATING_FUNC_SOFTMAX 0
+#define GATING_FUNC_SIGMOID 1
+#define GATING_FUNC_SOFTMAX_WEIGHT 2
+
 layout (push_constant) uniform parameter
 {
     uint n_rows;
@@ -14,15 +18,18 @@ layout (push_constant) uniform parameter
     uint n_expert_used;
     float clamp_min;
     float clamp_max;
+    uint gating_func;
+    uint has_bias;
+    uint with_norm;
+    float output_scale;
+    float output_bias;
 };
 
 layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;
 
 layout(constant_id = 0) const uint WARP_SIZE = 32;
 layout(constant_id = 1) const uint n_experts_spec = 512;
-layout(constant_id = 2) const bool with_norm = true;
-layout(constant_id = 3) const bool late_softmax = false;
-layout(constant_id = 4) const bool nexperts_use_push = false;
+layout(constant_id = 2) const bool nexperts_use_push = false;
 
 uint n_experts = nexperts_use_push ? n_experts_push : n_experts_spec;
 
@@ -31,8 +38,9 @@ uint n_experts = nexperts_use_push ? n_experts_push : n_experts_spec;
 const uint experts_per_thread = CEIL_DIV(n_experts_spec, WARP_SIZE);
 
 layout (binding = 0, std430) readonly buffer Logits {float logits[];};
-layout (binding = 1, std430) writeonly buffer Weights {float weights[];};
-layout (binding = 2, std430) writeonly buffer Ids {uint ids[];};
+layout (binding = 1, std430) readonly buffer BiasProbs {float bias[];};
+layout (binding = 2, std430) writeonly buffer Weights {float weights[];};
+layout (binding = 3, std430) writeonly buffer Ids {uint ids[];};
 
 const float INFINITY = 1.0 / 0.0;
 
@@ -87,20 +95,40 @@ void main() {
     }
 
     const uint logits_offset = n_experts * row;
+    const uint bias_offset = 0; // 1D
     const uint weights_offset = n_expert_used * row;
     const uint ids_offset = n_experts * row;
     const uint lane = gl_SubgroupInvocationID;
 
-    float wt[experts_per_thread];
+    float probs[experts_per_thread];
 
     [[unroll]]
     for (uint i = 0; i < n_experts; i += WARP_SIZE) {
         const uint expert = i + lane;
-        wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[logits_offset + expert] : -INFINITY;
+        probs[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[logits_offset + expert] : -INFINITY;
+    }
+
+    if (gating_func == GATING_FUNC_SOFTMAX) {
+        softmax_warp_inplace(probs, n_experts, lane, nexperts_use_push);
+    } else if (gating_func == GATING_FUNC_SIGMOID) {
+        [[unroll]]
+        for (int i = 0; i < experts_per_thread; i++) {
+            probs[i] = 1.f / (1.f + exp(-probs[i]));
+        }
     }
 
-    if (!late_softmax) {
-        softmax_warp_inplace(wt, n_experts, lane, nexperts_use_push);
+    float selection_probs[experts_per_thread];
+    if (has_bias != 0) {
+        [[unroll]]
+        for (uint i = 0; i < n_experts; i += WARP_SIZE) {
+            const uint expert = i + lane;
+            selection_probs[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? probs[i / WARP_SIZE] + bias[bias_offset + expert] : -INFINITY;
+        }
+    } else {
+        [[unroll]]
+        for (int i = 0; i < experts_per_thread; i++) {
+            selection_probs[i] = probs[i];
+        }
     }
 
     // at this point, each thread holds a portion of softmax,
@@ -117,14 +145,16 @@ void main() {
     }
 
     for (int k = 0; k < n_expert_used; k++) {
-        float max_val    = wt[0];
+        float max_val    = probs[0];
+        float max_val_s  = selection_probs[0];
         uint   max_expert = lane;
 
         [[unroll]]
         for (int i = 1; i < experts_per_thread; i++) {
             const uint expert = lane + i * WARP_SIZE;
-            if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) {
-                max_val    = wt[i];
+            if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && selection_probs[i] > max_val_s) {
+                max_val    = probs[i];
+                max_val_s  = selection_probs[i];
                 max_expert = expert;
             }
         }
@@ -132,9 +162,11 @@ void main() {
         [[unroll]]
         for (uint mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
             const float val    = subgroupShuffleXor(max_val, mask);
+            const float val_s  = subgroupShuffleXor(max_val_s, mask);
             const uint  expert = subgroupShuffleXor(max_expert, mask);
-            if (val > max_val || (val == max_val && expert < max_expert)) {
+            if (val_s > max_val_s || (val_s == max_val_s && expert < max_expert)) {
                 max_val    = val;
+                max_val_s  = val_s;
                 max_expert = expert;
             }
         }
@@ -144,16 +176,14 @@ void main() {
         }
 
         if ((max_expert & (WARP_SIZE - 1)) == lane) {
-            wt[max_expert / WARP_SIZE] = -INFINITY;
+            selection_probs[max_expert / WARP_SIZE] = -INFINITY;
 
             ids[ids_offset + k] = max_expert;
-            if (with_norm) {
-                wt_sum += max_val;
-            }
+            wt_sum += max_val;
         }
     }
 
-    if (with_norm) {
+    if (with_norm != 0) {
         wt_sum              = subgroupAdd(wt_sum);
         wt_sum              = clamp(wt_sum, clamp_min, clamp_max);
         const float inv_sum = 1.0f / wt_sum;
@@ -164,7 +194,7 @@ void main() {
         }
     }
 
-    if (late_softmax) {
+    if (gating_func == GATING_FUNC_SOFTMAX_WEIGHT) {
         softmax_warp_inplace(output_weights, n_expert_used, lane, true);
     }
 
@@ -172,7 +202,7 @@ void main() {
     for (uint i = 0; i < experts_per_thread; ++i) {
         uint idx = i * WARP_SIZE + lane;
         if (idx < n_expert_used) {
-            weights[weights_offset + idx] = output_weights[i];
+            weights[weights_offset + idx] = output_scale * output_weights[i] + output_bias;
         }
     }
 }