]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
vulkan: fix rms_norm+mul fusion (llama/14545)
authorJeff Bolz <redacted>
Sun, 6 Jul 2025 08:08:16 +0000 (03:08 -0500)
committerGeorgi Gerganov <redacted>
Sat, 12 Jul 2025 13:05:00 +0000 (16:05 +0300)
The fused operation was grabbing the epsilon value from the wrong place.

Add an env var to disable fusion.

Add some missing checks for supported shapes/types.

Handle fused rms_norm+mul in check_results.

src/ggml-vulkan/ggml-vulkan.cpp
tests/test-backend-ops.cpp

index e8df00d4183acfa132dcf8f5135247050c5cff5d..7b8faf17879d62a0f1a9a0f315e3a10020f5f216 100644 (file)
@@ -501,6 +501,8 @@ struct vk_device_struct {
 
     ggml_backend_buffer_type buffer_type;
 
+    bool disable_fusion;
+
 #ifdef GGML_VULKAN_MEMORY_DEBUG
     std::unique_ptr<vk_memory_logger> memory_logger;
 #endif
@@ -1091,8 +1093,8 @@ static size_t vk_skip_checks;
 static size_t vk_output_tensor;
 
 static void ggml_vk_print_tensor(const ggml_tensor * tensor, const char * name);
-static void ggml_vk_check_results_0(ggml_tensor * tensor);
-static void ggml_vk_check_results_1(ggml_tensor * tensor);
+static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx);
+static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx);
 #endif
 
 typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
@@ -3507,6 +3509,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
 
         device->idx = idx;
 
+        device->disable_fusion = getenv("GGML_VK_DISABLE_FUSION") != nullptr;
+
         return device;
     }
 
@@ -7654,8 +7658,7 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx
     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun);
 }
 
-static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
-    float * op_params = (float *)dst->op_params;
+static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, float * op_params, bool dryrun = false) {
     const uint32_t src0_type_size = ggml_type_size(src0->type);
     const uint32_t src1_type_size = ggml_type_size(src1->type);
     const uint32_t dst_type_size = ggml_type_size(dst->type);
@@ -8885,7 +8888,7 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
     }
 }
 
-static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready);
+static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_cgraph * cgraph, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready);
 
 // Returns true if node has enqueued work into the queue, false otherwise
 // If submit is true the current all operations queued so far are being submitted to Vulkan to overlap cmdlist creation and GPU execution.
@@ -9146,9 +9149,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
             // fused rms_norm + mul
             ggml_tensor *mul = cgraph->nodes[node_idx + 1];
             ggml_tensor *other_src = mul->src[0] == node ? mul->src[1] : mul->src[0];
-            ggml_vk_rms_norm(ctx, compute_ctx, src0, other_src, mul, dryrun);
+            ggml_vk_rms_norm(ctx, compute_ctx, src0, other_src, mul, (float *)node->op_params, dryrun);
         } else {
-            ggml_vk_rms_norm(ctx, compute_ctx, src0, src0, node, dryrun);
+            ggml_vk_rms_norm(ctx, compute_ctx, src0, src0, node, (float *)node->op_params, dryrun);
         }
         break;
     case GGML_OP_RMS_NORM_BACK:
@@ -9308,7 +9311,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
 
         ctx->compute_ctx.reset();
 
-        bool ok = ggml_vk_compute_forward(ctx, node_begin, node_idx_begin, false, almost_ready);
+        bool ok = ggml_vk_compute_forward(ctx, cgraph, node_begin, node_idx_begin, false, almost_ready);
         if (!ok) {
             if (node->op == GGML_OP_UNARY) {
                 std::cerr << __func__ << ": error: op not supported UNARY " << node->name << " (" << ggml_unary_op_name(static_cast<ggml_unary_op>(node->op_params[0])) << ")" << std::endl;
@@ -9323,7 +9326,8 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     return true;
 }
 
-static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * tensor, int tensor_idx, bool use_fence = true, bool almost_ready = false) {
+static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, ggml_tensor * tensor, int tensor_idx, bool use_fence = true, bool almost_ready = false) {
+    GGML_UNUSED(cgraph);
     ggml_backend_buffer * buf = nullptr;
 
     switch (tensor->op) {
@@ -9433,7 +9437,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
     // Only run if ctx hasn't been submitted yet
     if (!subctx->seqs.empty()) {
 #ifdef GGML_VULKAN_CHECK_RESULTS
-        ggml_vk_check_results_0(tensor);
+        ggml_vk_check_results_0(ctx, cgraph, tensor_idx);
         use_fence = true;
 #endif
 
@@ -9453,7 +9457,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
             ggml_vk_wait_for_fence(ctx);
         }
 #ifdef GGML_VULKAN_CHECK_RESULTS
-        ggml_vk_check_results_1(tensor);
+        ggml_vk_check_results_1(ctx, cgraph, tensor_idx);
 #endif
     }
 
@@ -9900,6 +9904,37 @@ static bool ggml_vk_is_empty(ggml_tensor * node) {
     return ggml_is_empty(node) || node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE;
 }
 
+static bool ggml_vk_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
+    if (!ggml_can_fuse(cgraph, node_idx, ops)) {
+        return false;
+    }
+
+    if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
+        // additional constraints specific to this fusion
+        const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
+        const ggml_tensor *mul = cgraph->nodes[node_idx + 1];
+
+        GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
+        GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
+        // rms_norm only supports f32
+        if (mul->src[0]->type != GGML_TYPE_F32 ||
+            mul->src[1]->type != GGML_TYPE_F32 ||
+            mul->type != GGML_TYPE_F32) {
+            return false;
+        }
+        // if rms_norm is the B operand, then we don't handle broadcast
+        if (rms_norm == mul->src[1] &&
+            mul->src[0]->ne[1] != rms_norm->ne[1]) {
+            return false;
+        }
+        // rms_norm shader assumes contiguous rows
+        if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
+            return false;
+        }
+    }
+    return true;
+}
+
 static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
     VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)");
     ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
@@ -9913,7 +9948,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
 
     uint64_t total_mat_mul_bytes = 0;
     for (int i = 0; i < cgraph->n_nodes; i++) {
-        if (ggml_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
+        if (!ctx->device->disable_fusion && ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
             ctx->num_additional_fused_ops = 1;
         }
         ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
@@ -9983,7 +10018,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
             mul_mat_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
         }
 
-        if (ggml_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
+        if (!ctx->device->disable_fusion && ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
             ctx->num_additional_fused_ops = 1;
         }
 
@@ -10760,11 +10795,21 @@ void * comp_result;
 size_t comp_size;
 size_t comp_nb[GGML_MAX_DIMS];
 size_t check_counter = 0;
-static void ggml_vk_check_results_0(ggml_tensor * tensor) {
+static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) {
+    ggml_tensor * tensor = cgraph->nodes[tensor_idx];
     if (tensor->op == GGML_OP_TRANSPOSE) {
         return;
     }
 
+    bool fused_rms_norm_mul = false;
+    int rms_norm_idx = -1;
+    if (ctx->num_additional_fused_ops == 1 &&
+        tensor->op == GGML_OP_RMS_NORM &&
+        cgraph->nodes[tensor_idx + 1]->op == GGML_OP_MUL) {
+        fused_rms_norm_mul = true;
+        tensor = cgraph->nodes[tensor_idx + 1];
+    }
+
     check_counter++;
     if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) {
         return;
@@ -10792,6 +10837,15 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
 
     for (int i = 0; i < 6; i++) {
         ggml_tensor * srci = tensor->src[i];
+        if (fused_rms_norm_mul) {
+            rms_norm_idx = tensor->src[0]->op == GGML_OP_RMS_NORM ? 0 : 1;
+            ggml_tensor *rms_norm = tensor->src[rms_norm_idx];
+            switch (i) {
+            case 0: srci = rms_norm->src[0]; break;
+            case 1: srci = tensor->src[1 - rms_norm_idx]; break;
+            default: continue;
+            }
+        }
         if (srci == nullptr) {
             continue;
         }
@@ -10849,7 +10903,12 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
     } else if (tensor->op == GGML_OP_SUB) {
         tensor_clone = ggml_sub(ggml_ctx, src_clone[0], src_clone[1]);
     } else if (tensor->op == GGML_OP_MUL) {
-        tensor_clone = ggml_mul(ggml_ctx, src_clone[0], src_clone[1]);
+        if (fused_rms_norm_mul) {
+            tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->src[rms_norm_idx]->op_params);
+            tensor_clone = ggml_mul(ggml_ctx, tensor_clone, src_clone[1 - rms_norm_idx]);
+        } else {
+            tensor_clone = ggml_mul(ggml_ctx, src_clone[0], src_clone[1]);
+        }
     } else if (tensor->op == GGML_OP_DIV) {
         tensor_clone = ggml_div(ggml_ctx, src_clone[0], src_clone[1]);
     } else if (tensor->op == GGML_OP_CONCAT) {
@@ -11040,10 +11099,10 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
         GGML_ABORT("fatal error");
     }
 
-    ggml_cgraph * cgraph = ggml_new_graph(ggml_ctx);
-    ggml_build_forward_expand(cgraph, tensor_clone);
+    ggml_cgraph * cgraph_cpu = ggml_new_graph(ggml_ctx);
+    ggml_build_forward_expand(cgraph_cpu, tensor_clone);
 
-    ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 8);
+    ggml_graph_compute_with_ctx(ggml_ctx, cgraph_cpu, 8);
 
     if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
         ggml_vk_print_tensor(tensor_clone, "tensor_clone");
@@ -11066,10 +11125,19 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
     VK_LOG_DEBUG("END ggml_vk_check_results_0(" << tensor->name << ")");
 }
 
-static void ggml_vk_check_results_1(ggml_tensor * tensor) {
+static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) {
+    ggml_tensor * tensor = cgraph->nodes[tensor_idx];
     if (tensor->op == GGML_OP_TRANSPOSE) {
         return;
     }
+    bool fused_rms_norm_mul = false;
+    if (ctx->num_additional_fused_ops == 1 &&
+        tensor->op == GGML_OP_RMS_NORM &&
+        cgraph->nodes[tensor_idx + 1]->op == GGML_OP_MUL) {
+        fused_rms_norm_mul = true;
+        tensor = cgraph->nodes[tensor_idx + 1];
+    }
+
     if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) {
         return;
     }
index 0dc9c09e28ee2b106dc57469131f5ce8c6288dce..652856a35d5e7b96a56ad964054b5c607fc79c37 100644 (file)
@@ -2583,10 +2583,6 @@ struct test_rms_norm_mul : public test_case {
         }
     }
 
-    double max_nmse_err() override {
-        return 1e-6;
-    }
-
     float grad_eps() override {
         return 1.0f;
     }
@@ -5058,7 +5054,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
         test_cases.emplace_back(new test_l2_norm      (GGML_TYPE_F32, {64, 5, 4, 3}, eps));
     }
-    for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f}) {
+    for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f, 1.0f}) {
         test_cases.emplace_back(new test_rms_norm_mul(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
     }