]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
vulkan: fuse mul_mat_id + mul (llama/17095)
authorJeff Bolz <redacted>
Sun, 9 Nov 2025 08:48:42 +0000 (02:48 -0600)
committerGeorgi Gerganov <redacted>
Sun, 9 Nov 2025 16:30:22 +0000 (18:30 +0200)
* vulkan: fuse mul_mat_id + mul

This comes up in qwen3 moe.

* split mul_mat_id fusion tests into a separate class

src/ggml-vulkan/ggml-vulkan.cpp
src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl
tests/test-backend-ops.cpp

index 6da7bbd2f611df43e63cecbf97a459b715af757a..054e8cbdb8b74b2559540e69f37f06e7d5592275 100644 (file)
@@ -830,6 +830,7 @@ struct vk_mat_vec_push_constants {
     uint32_t batch_stride_b;
     uint32_t batch_stride_d;
     uint32_t enable_bias;
+    uint32_t enable_scale;
     uint32_t ne02;
     uint32_t ne12;
     uint32_t broadcast2;
@@ -852,6 +853,7 @@ struct vk_mat_vec_id_push_constants {
     uint32_t batch_stride_b;
     uint32_t batch_stride_d;
     uint32_t enable_bias;
+    uint32_t enable_scale;
     uint32_t nei0;
     uint32_t ne11;
 };
@@ -6863,7 +6865,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
     // compute
     const vk_mat_vec_push_constants pc = {
         (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
-        stride_batch_x, stride_batch_y, stride_batch_d, enable_bias,
+        stride_batch_x, stride_batch_y, stride_batch_d, enable_bias, 0,
         (uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
     };
     ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
@@ -7684,13 +7686,22 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
         groups_x = CEIL_DIV(groups_x, groups_z);
     }
 
-    uint32_t enable_bias = ctx->num_additional_fused_ops > 0;
+    uint32_t enable_bias = 0;
+    uint32_t enable_scale = 0;
+    if (ctx->num_additional_fused_ops > 0) {
+        if (cgraph->nodes[node_idx + 1]->op == GGML_OP_MUL) {
+            enable_scale = 1;
+        } else {
+            GGML_ASSERT(cgraph->nodes[node_idx + 1]->op == GGML_OP_ADD_ID);
+            enable_bias = 1;
+        }
+    }
 
     vk_buffer d_B = d_D;
     size_t b_buf_offset = 0;
     uint64_t b_sz = 0;
 
-    if (enable_bias) {
+    if (enable_bias || enable_scale) {
         const ggml_tensor * bias = cgraph->nodes[node_idx + 1]->src[1];
 
         bool b_uma = false;
@@ -7712,7 +7723,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
         (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
         (uint32_t)x_ne, stride_batch_y, (uint32_t)(ne20*ne21),
 
-        enable_bias,
+        enable_bias, enable_scale,
 
         (uint32_t)nei0, (uint32_t)ne11,
     };
@@ -12490,6 +12501,40 @@ static bool ggml_vk_can_fuse(const ggml_backend_vk_context * ctx, const struct g
         }
     }
 
+    if (ops.size() == 2 && ops.begin()[0] == GGML_OP_MUL_MAT_ID && ops.begin()[1] == GGML_OP_MUL) {
+        // additional constraints specific to this fusion
+        const ggml_tensor *mmid = cgraph->nodes[node_idx];
+        const ggml_tensor *mul = cgraph->nodes[node_idx + 1];
+        const ggml_tensor *scale = mul->src[1];
+
+        if (mmid != mul->src[0]) {
+            return false;
+        }
+        // mat-vec only
+        if (!ggml_vk_use_mul_mat_vec_id(cgraph, node_idx)) {
+            return false;
+        }
+        // shaders assume the types match
+        if (mmid->type != scale->type) {
+            return false;
+        }
+        // shaders assume the bias is contiguous
+        if (!ggml_is_contiguous(scale)) {
+            return false;
+        }
+        // unaligned bias isn't handled
+        if (get_misalign_bytes(ctx, scale) != 0) {
+            return false;
+        }
+        // shader only indexes by expert index
+        if (scale->ne[0] != 1 ||
+            scale->ne[1] != mul->ne[1] ||
+            scale->ne[2] != 1 ||
+            scale->ne[3] != 1) {
+            return false;
+        }
+    }
+
     return true;
 }
 
@@ -12798,6 +12843,8 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
                 ctx->num_additional_fused_ops = 1;
             } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID })) {
                 ctx->num_additional_fused_ops = 1;
+            } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_MUL })) {
+                ctx->num_additional_fused_ops = 1;
             } else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 4 }) &&
                        ggml_check_edges(cgraph, i, rms_norm_mul_rope_view_set_rows_edges) &&
                        ggml_vk_can_fuse_rms_norm_mul_rope(ctx, cgraph, i) &&
@@ -13033,7 +13080,8 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
                     is_src_of(graph->nodes[j], graph->nodes[c]) &&
                     !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_RMS_NORM && graph->nodes[j]->op == GGML_OP_MUL) &&
                     !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT && graph->nodes[j]->op == GGML_OP_ADD) &&
-                    !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_ADD_ID)) {
+                    !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_ADD_ID) &&
+                    !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_MUL)) {
                     ok = false;
                     break;
                 }
index bbb4d1206b7e44ab37724500de28d0f9a640d842..eb8fa6dc09fb1c324e257b2f574dc94c169748d2 100644 (file)
@@ -49,6 +49,7 @@ layout (push_constant) uniform parameter
     uint batch_stride_d;
 
     uint enable_bias;
+    uint enable_scale;
 
 #ifdef MUL_MAT_ID
     uint nei0;
@@ -129,6 +130,12 @@ void reduce_result(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t
                     temp[j][n] += FLOAT_TYPE(data_bias[j*p.batch_stride_d + d_offset + first_row + n]);
 #endif
                 }
+#ifdef MUL_MAT_ID
+                if (p.enable_scale != 0) {
+                    const uint expert_idx = gl_GlobalInvocationID.y;
+                    temp[j][n] *= FLOAT_TYPE(data_bias[expert_idx]);
+                }
+#endif
                 data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]);
             }
         }
@@ -171,6 +178,12 @@ void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offs
                     temp[j][n] += FLOAT_TYPE(data_bias[j*p.batch_stride_d + d_offset + first_row + n]);
 #endif
                 }
+#ifdef MUL_MAT_ID
+                if (p.enable_scale != 0) {
+                    const uint expert_idx = gl_GlobalInvocationID.y;
+                    temp[j][n] *= FLOAT_TYPE(data_bias[expert_idx]);
+                }
+#endif
                 data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]);
             }
         }
@@ -203,6 +216,12 @@ void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offs
                     tmpsh[j][n][0] += FLOAT_TYPE(data_bias[j*p.batch_stride_d + d_offset + first_row + n]);
 #endif
                 }
+#ifdef MUL_MAT_ID
+                if (p.enable_scale != 0) {
+                    const uint expert_idx = gl_GlobalInvocationID.y;
+                    tmpsh[j][n][0] *= FLOAT_TYPE(data_bias[expert_idx]);
+                }
+#endif
                 data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(tmpsh[j][n][0]);
             }
         }
index 2470c148d66685348ae7446e7cf67c1a8223de23..21c7e3a8cffc140bd0b2d033ddf7ecc6a14f50cd 100644 (file)
@@ -3557,6 +3557,27 @@ struct test_mul_mat : public test_case {
     }
 };
 
+static void init_mul_mat_id_tensors(ggml_context * ctx, int n_mats) {
+    std::random_device rd;
+    std::default_random_engine rng(rd());
+    for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+        if (t->type == GGML_TYPE_I32) {
+            if (ggml_is_view_op(t->op)) { continue; }
+            // ids
+            for (int64_t r = 0; r < ggml_nrows(t); r++) {
+                std::vector<int32_t> data(t->ne[0]);
+                for (int i = 0; i < t->ne[0]; i++) {
+                    data[i] = i % n_mats;
+                }
+                std::shuffle(data.begin(), data.end(), rng);
+                ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(int32_t));
+            }
+        } else {
+            init_tensor_uniform(t);
+        }
+    }
+}
+
 // GGML_OP_MUL_MAT_ID
 struct test_mul_mat_id : public test_case {
     const ggml_type type_a;
@@ -3567,10 +3588,9 @@ struct test_mul_mat_id : public test_case {
     const int64_t m;
     const int64_t n;
     const int64_t k;
-    const uint32_t o; // number of outputs
 
     std::string vars() override {
-        return VARS_TO_STR9(type_a, type_b, n_mats, n_used, b, m, n, k, o);
+        return VARS_TO_STR8(type_a, type_b, n_mats, n_used, b, m, n, k);
     }
 
     double max_nmse_err() override {
@@ -3584,9 +3604,69 @@ struct test_mul_mat_id : public test_case {
 
     test_mul_mat_id(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
             int n_mats = 8, int n_used = 2, bool b = false,
-            int64_t m = 32, int64_t n = 32, int64_t k = 32, uint32_t o = 1)
+            int64_t m = 32, int64_t n = 32, int64_t k = 32)
         : type_a(type_a), type_b(type_b), n_mats(n_mats), n_used(n_used), b(b),
-            m(m), n(n), k(k), o(o) {
+            m(m), n(n), k(k) {
+            GGML_ASSERT(n_used <= n_mats);
+        }
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        // C^T = A * B^T: (k, m) * (k, n) => (m, n)
+        ggml_tensor * as = ggml_new_tensor_3d(ctx, type_a, k, m, n_mats);
+        ggml_set_name(as, "as");
+
+        ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_mats, n);
+        ggml_set_name(ids, "ids");
+        if (n_used != n_mats) {
+            ids = ggml_view_2d(ctx, ids, n_used, n, ids->nb[1], 0);
+            ggml_set_name(ids, "view_of_ids");
+        }
+
+        ggml_tensor * b = ggml_new_tensor_3d(ctx, type_b, k, this->b ? 1 : n_used, n);
+        ggml_set_name(b, "b");
+
+        ggml_tensor * out = ggml_mul_mat_id(ctx, as, b, ids);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        init_mul_mat_id_tensors(ctx, n_mats);
+    }
+};
+
+// GGML_OP_MUL_MAT_ID + GGML_OP_ADD or GGML_OP_MUL
+struct test_mul_mat_id_fusion : public test_case {
+    const ggml_type type_a;
+    const ggml_type type_b;
+    const int n_mats;
+    const int n_used;
+    const bool b; // broadcast b matrix
+    const int64_t m;
+    const int64_t n;
+    const int64_t k;
+    const uint32_t o; // number of outputs
+    const bool mul;
+
+    std::string vars() override {
+        return VARS_TO_STR10(type_a, type_b, n_mats, n_used, b, m, n, k, o, mul);
+    }
+
+    double max_nmse_err() override {
+        return 5e-4;
+    }
+
+    uint64_t op_flops(ggml_tensor * t) override {
+        GGML_UNUSED(t);
+        return 2 * m * k * n * n_used;
+    }
+
+    test_mul_mat_id_fusion(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
+            int n_mats = 8, int n_used = 2, bool b = false,
+            int64_t m = 32, int64_t n = 32, int64_t k = 32, uint32_t o = 1, bool mul = false)
+        : type_a(type_a), type_b(type_b), n_mats(n_mats), n_used(n_used), b(b),
+            m(m), n(n), k(k), o(o), mul(mul) {
             GGML_ASSERT(n_used <= n_mats);
         }
 
@@ -3615,35 +3695,25 @@ struct test_mul_mat_id : public test_case {
             out = ggml_add(ctx, out, out2);
         }
 
+        if (mul) {
+            std::array<int64_t, 4> ne { 1, out->ne[1], out->ne[2], out->ne[3] };
+            ne[0] = 1;
+            ggml_tensor * m = ggml_new_tensor(ctx, out->type, 4, ne.data());
+            out = ggml_mul(ctx, out, m);
+        }
+
         return out;
     }
 
     void initialize_tensors(ggml_context * ctx) override {
-        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
-            if (t->type == GGML_TYPE_I32) {
-                if (ggml_is_view_op(t->op)) { continue; }
-                std::random_device rd;
-                std::default_random_engine rng(rd());
-                // ids
-                for (int64_t r = 0; r < ggml_nrows(t); r++) {
-                    std::vector<int32_t> data(t->ne[0]);
-                    for (int i = 0; i < t->ne[0]; i++) {
-                        data[i] = i % n_mats;
-                    }
-                    std::shuffle(data.begin(), data.end(), rng);
-                    ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(int32_t));
-                }
-            } else {
-                init_tensor_uniform(t);
-            }
-        }
+        init_mul_mat_id_tensors(ctx, n_mats);
     }
 
-    bool run_whole_graph() override { return o > 1; }
+    bool run_whole_graph() override { return true; }
 
     std::string op_desc(ggml_tensor * t) override {
         GGML_UNUSED(t);
-        return ggml_op_name(GGML_OP_MUL_MAT_ID);
+        return "MUL_MAT_ID_FUSION";
     }
 };
 
@@ -4992,24 +5062,7 @@ struct test_mul_mat_vec_fusion : public test_case {
                 init_tensor_uniform(t);
             }
         } else {
-            std::random_device rd;
-            std::default_random_engine rng(rd());
-            for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
-                if (t->type == GGML_TYPE_I32) {
-                    if (ggml_is_view_op(t->op)) { continue; }
-                    // ids
-                    for (int64_t r = 0; r < ggml_nrows(t); r++) {
-                        std::vector<int32_t> data(t->ne[0]);
-                        for (int i = 0; i < t->ne[0]; i++) {
-                            data[i] = i % n_mats;
-                        }
-                        std::shuffle(data.begin(), data.end(), rng);
-                        ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(int32_t));
-                    }
-                } else {
-                    init_tensor_uniform(t);
-                }
-            }
+            init_mul_mat_id_tensors(ctx, n_mats);
         }
     }
 
@@ -6979,7 +7032,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     }
 
     test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 1, 1, false, 8, 16, 1));
-    test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, false, 32, 32, 32, 3));
+    test_cases.emplace_back(new test_mul_mat_id_fusion(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, false, 32, 32, 32, 3));
 
     // gpt-oss issue with Vulkan mmq_id
     test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_MXFP4, GGML_TYPE_F32, 32, 2, false, 2880, 32, 2880));
@@ -7016,6 +7069,15 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         }
     }
 
+    for (int bs : {1, 4, 512}) {
+        for (ggml_type type_a : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_Q4_0, GGML_TYPE_Q4_K}) {
+            for (ggml_type type_b : {GGML_TYPE_F32}) {
+                // test with mul after (ffn_moe_weighted)
+                test_cases.emplace_back(new test_mul_mat_id_fusion(type_a, type_b, 128, 8, false, 768, bs, 2048, 1, true));
+            }
+        }
+    }
+
     for (ggml_type type_a : base_types) {
         for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
             for (int n : {1, 16}) {
@@ -7472,7 +7534,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
     for (int bs : {1, 4, 8, 32, 64, 128, 256, 512}) {
         for (ggml_type type_a : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0, GGML_TYPE_Q4_K, GGML_TYPE_Q6_K, GGML_TYPE_IQ2_XS}) {
             for (ggml_type type_b : {GGML_TYPE_F32}) {
-                test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, 128, 8, false, 768, bs, 2048, 1));
+                test_cases.emplace_back(new test_mul_mat_id_fusion(type_a, type_b, 128, 8, false, 768, bs, 2048, 1));
             }
         }
     }
@@ -7480,7 +7542,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
     for (int bs : {1, 4, 8, 32, 64, 128, 256, 512}) {
         for (ggml_type type_a : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0, GGML_TYPE_Q4_K, GGML_TYPE_Q6_K, GGML_TYPE_IQ2_XS}) {
             for (ggml_type type_b : {GGML_TYPE_F32}) {
-                test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, 32, 4, false, 1792, bs, 2048, 1));
+                test_cases.emplace_back(new test_mul_mat_id_fusion(type_a, type_b, 32, 4, false, 1792, bs, 2048, 1));
             }
         }
     }
@@ -7490,7 +7552,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
     for (int bs : {1, 4, 8, 512}) {
         for (ggml_type type_a : {GGML_TYPE_MXFP4}) {
             for (ggml_type type_b : {GGML_TYPE_F32}) {
-                test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, 32, 4, false, 2880, bs, 2880, 1));
+                test_cases.emplace_back(new test_mul_mat_id_fusion(type_a, type_b, 32, 4, false, 2880, bs, 2880, 1));
             }
         }
     }