]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
vulkan: Add fusion support for RMS_NORM+MUL (llama/14366)
authorJeff Bolz <redacted>
Sun, 29 Jun 2025 07:43:36 +0000 (02:43 -0500)
committerGeorgi Gerganov <redacted>
Tue, 1 Jul 2025 14:54:53 +0000 (17:54 +0300)
* vulkan: Add fusion support for RMS_NORM+MUL

- Add a use_count to ggml_tensor, so we can detect if an output is used more than once.
- Change the ggml-vulkan rms_norm shader to optionally multiply by another tensor.
- Add detection logic and basic fusion logic in ggml-vulkan.
- Add some testing support for fusion. Rather than computing one node at a time, allow
for computing the whole graph and just testing one node's results. Add rms_norm_mul tests
and enable a llama test.

* extract some common fusion logic

* fix -Winconsistent-missing-override

* move ggml_can_fuse to a common function

* build fix

* C and C++ versions of can_fuse

* move use count to the graph to avoid data races and double increments when used in multiple threads

* use hash table lookup to find node index

* change use_counts to be indexed by hash table slot

* minimize hash lookups

style fixes

* last node doesn't need single use.
fix type.
handle mul operands being swapped.

* remove redundant parameter

---------

Co-authored-by: slaren <redacted>
ggml/include/ggml-backend.h
ggml/src/ggml-backend.cpp
ggml/src/ggml-impl.h
ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
ggml/src/ggml.c

index 778927f68217ae2eaad89b080a3f03829b4321e0..a2977ea2e56d935100b03d5f4e183647a6f1e4d2 100644 (file)
@@ -339,7 +339,7 @@ extern "C" {
     typedef bool (*ggml_backend_eval_callback)(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data);
 
     // Compare the output of two backends
-    GGML_API bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data);
+    GGML_API bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor * test_node);
 
     // Tensor initialization
     GGML_API enum ggml_status ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr);
index b1050ad59c26a8d931dbeb8031787a3bcf1c7a07..788861a365fab65e234d469f3730b7bf970d6b7e 100644 (file)
@@ -817,8 +817,9 @@ static void ggml_backend_sched_print_assignments(ggml_backend_sched_t sched, str
         }
         if (sched->debug > 1) {
             ggml_backend_t tensor_backend = ggml_backend_sched_get_tensor_backend(sched, node);
-            GGML_LOG_DEBUG("node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s]:", i, ggml_op_name(node->op), node->name,
-                fmt_size(ggml_nbytes(node)), tensor_backend ? ggml_backend_name(tensor_backend) : "NULL", GET_CAUSE(node));
+            GGML_LOG_DEBUG("node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s] use=%d:", i, ggml_op_name(node->op), node->name,
+                fmt_size(ggml_nbytes(node)), tensor_backend ? ggml_backend_name(tensor_backend) : "NULL", GET_CAUSE(node),
+                graph->use_counts[ggml_hash_find(&graph->visited_hash_set, node)]);
             for (int j = 0; j < GGML_MAX_SRC; j++) {
                 struct ggml_tensor * src = node->src[j];
                 if (src == NULL) {
@@ -1826,7 +1827,7 @@ void ggml_backend_graph_copy_free(struct ggml_backend_graph_copy copy) {
     ggml_free(copy.ctx_unallocated);
 }
 
-bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data) {
+bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor * test_node) {
     struct ggml_backend_graph_copy copy = ggml_backend_graph_copy(backend2, graph);
     if (copy.buffer == NULL) {
         return false;
@@ -1837,28 +1838,45 @@ bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t
 
     assert(g1->n_nodes == g2->n_nodes);
 
-    for (int i = 0; i < g1->n_nodes; i++) {
-        struct ggml_tensor * t1 = g1->nodes[i];
-        struct ggml_tensor * t2 = g2->nodes[i];
+    if (test_node != nullptr) {
+        // Compute the whole graph and only test the output for a specific tensor
+        ggml_backend_graph_compute(backend1, g1);
+        ggml_backend_graph_compute(backend2, g2);
 
-        assert(t1->op == t2->op && ggml_are_same_layout(t1, t2));
+        int test_node_idx = -1;
+        for (int i = 0; i < g1->n_nodes; i++) {
+            struct ggml_tensor * t1 = g1->nodes[i];
+            if (t1 == test_node) {
+                test_node_idx = i;
+                break;
+            }
+        }
+        GGML_ASSERT(test_node_idx != -1);
 
-        struct ggml_cgraph g1v = ggml_graph_view(g1, i, i + 1);
-        struct ggml_cgraph g2v = ggml_graph_view(g2, i, i + 1);
+        callback(test_node_idx, g1->nodes[test_node_idx], g2->nodes[test_node_idx], user_data);
+    } else {
+        for (int i = 0; i < g1->n_nodes; i++) {
+            struct ggml_tensor * t1 = g1->nodes[i];
+            struct ggml_tensor * t2 = g2->nodes[i];
 
-        ggml_backend_graph_compute(backend1, &g1v);
-        ggml_backend_graph_compute(backend2, &g2v);
+            assert(t1->op == t2->op && ggml_are_same_layout(t1, t2));
 
-        if (ggml_is_view_op(t1->op)) {
-            continue;
-        }
+            struct ggml_cgraph g1v = ggml_graph_view(g1, i, i + 1);
+            struct ggml_cgraph g2v = ggml_graph_view(g2, i, i + 1);
 
-        // compare results, calculate rms etc
-        if (!callback(i, t1, t2, user_data)) {
-            break;
+            ggml_backend_graph_compute(backend1, &g1v);
+            ggml_backend_graph_compute(backend2, &g2v);
+
+            if (ggml_is_view_op(t1->op)) {
+                continue;
+            }
+
+            // compare results, calculate rms etc
+            if (!callback(i, t1, t2, user_data)) {
+                break;
+            }
         }
     }
-
     ggml_backend_graph_copy_free(copy);
 
     return true;
index 57761644f431a531f3c321eecbf7e725ec0c7b03..4972558c98b81b46865cf2d708469a1243181863 100644 (file)
@@ -301,6 +301,7 @@ struct ggml_cgraph {
     struct ggml_tensor ** grads;     // the outputs of these tensors are the gradients of the nodes
     struct ggml_tensor ** grad_accs; // accumulators for node gradients
     struct ggml_tensor ** leafs;     // tensors with constant data
+    int32_t             * use_counts;// number of uses of each tensor, indexed by hash table slot
 
     struct ggml_hash_set visited_hash_set;
 
@@ -467,13 +468,76 @@ static inline ggml_bf16_t ggml_compute_fp32_to_bf16(float s) {
 #define GGML_FP32_TO_BF16(x) ggml_compute_fp32_to_bf16(x)
 #define GGML_BF16_TO_FP32(x) ggml_compute_bf16_to_fp32(x)
 
+// return true if the node's results are only used by N other nodes
+// and can be fused into their calculations.
+static inline bool ggml_node_has_n_uses(const struct ggml_cgraph * cgraph, int node_idx, int32_t n_uses) {
+    const struct ggml_tensor * node = cgraph->nodes[node_idx];
+
+    // check the use count against how many we're replacing
+    size_t hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node);
+    if (!ggml_bitset_get(cgraph->visited_hash_set.used, hash_pos) || cgraph->use_counts[hash_pos] != n_uses) {
+        return false;
+    }
+
+    // if node is a view, some other node might be using the intermediate result
+    // via the view source.
+    if (node->view_src) {
+        return false;
+    }
+
+    // If the user requested output for the node, can't fuse
+    if (node->flags & GGML_TENSOR_FLAG_OUTPUT) {
+        return false;
+    }
+
+    return true;
+}
+
+// Returns true if nodes [i, i+ops.size()) are the sequence of ggml_ops in ops[]
+// and are fusable. Nodes are considered fusable according to this function if:
+// - all nodes except the last have only one use and are not views/outputs (see ggml_node_has_N_uses).
+// - all nodes except the last are a src of the following node.
+// - all nodes are the same shape.
+// TODO: Consider allowing GGML_OP_NONE nodes in between
+static inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, const enum ggml_op * ops, int num_ops) {
+    if (node_idx + num_ops > cgraph->n_nodes) {
+        return false;
+    }
+
+    for (int i = 0; i < num_ops; ++i) {
+        struct ggml_tensor * node = cgraph->nodes[node_idx + i];
+        if (node->op != ops[i]) {
+            return false;
+        }
+        if (i < num_ops - 1 && !ggml_node_has_n_uses(cgraph, node_idx + i, 1)) {
+            return false;
+        }
+        if (i > 0) {
+            struct ggml_tensor * prev = cgraph->nodes[node_idx + i - 1];
+            if (node->src[0] != prev && node->src[1] != prev) {
+                return false;
+            }
+            if (!ggml_are_same_shape(node, prev)) {
+                return false;
+            }
+        }
+    }
+    return true;
+}
+
 #ifdef __cplusplus
 }
 #endif
 
 #ifdef __cplusplus
+#include <initializer_list>
 #include <vector>
 
+// nicer C++ syntax for ggml_can_fuse
+inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
+    return ggml_can_fuse(cgraph, node_idx, ops.begin(), (int)ops.size());
+}
+
 // expose GGUF internals for test code
 GGML_API size_t gguf_type_size(enum gguf_type type);
 GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params);
index 996ccbf66b016f199601ba9301422e0a87374a15..aebcc03915f5f71145d66fe159902485ef3df8d7 100644 (file)
@@ -425,6 +425,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_norm_f32;
     vk_pipeline pipeline_group_norm_f32;
     vk_pipeline pipeline_rms_norm_f32;
+    vk_pipeline pipeline_rms_norm_mul_f32;
     vk_pipeline pipeline_rms_norm_back_f32;
     vk_pipeline pipeline_l2_norm_f32;
 
@@ -978,6 +979,10 @@ struct ggml_backend_vk_context {
 
     vk_command_pool compute_cmd_pool;
     vk_command_pool transfer_cmd_pool;
+
+    // number of additional consecutive nodes that are being fused with the
+    // node currently being processed
+    uint32_t num_additional_fused_ops {};
 };
 
 static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000;  // NOLINT
@@ -2655,7 +2660,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_f32, "rms_norm_mul_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
 
@@ -6430,7 +6436,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         return nullptr;
     case GGML_OP_RMS_NORM:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-            return ctx->device->pipeline_rms_norm_f32;
+            return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_f32 : ctx->device->pipeline_rms_norm_f32;
         }
         return nullptr;
     case GGML_OP_RMS_NORM_BACK:
@@ -7530,18 +7536,19 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx
     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun);
 }
 
-static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
+static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
     float * op_params = (float *)dst->op_params;
     const uint32_t src0_type_size = ggml_type_size(src0->type);
+    const uint32_t src1_type_size = ggml_type_size(src1->type);
     const uint32_t dst_type_size = ggml_type_size(dst->type);
 
-    ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, {
+    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM, {
         (uint32_t)ggml_nelements(src0),
-        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
-        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
+        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+        (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
+        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
         0,
-        op_params[0], 0.0f,
-        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+        op_params[0], 0.0f, 0,
     }, dryrun);
 }
 
@@ -8736,7 +8743,8 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_tensor* t
 
 // Returns true if node has enqueued work into the queue, false otherwise
 // If submit is true the current all operations queued so far are being submitted to Vulkan to overlap cmdlist creation and GPU execution.
-static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * node, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool almost_ready, bool submit){
+static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool almost_ready, bool submit){
+    ggml_tensor * node = cgraph->nodes[node_idx];
     if (ggml_is_empty(node) || !node->buffer) {
         return false;
     }
@@ -8974,8 +8982,14 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
 
         break;
     case GGML_OP_RMS_NORM:
-        ggml_vk_rms_norm(ctx, compute_ctx, src0, node, dryrun);
-
+        if (ctx->num_additional_fused_ops > 0) {
+            // fused rms_norm + mul
+            ggml_tensor *mul = cgraph->nodes[node_idx + 1];
+            ggml_tensor *other_src = mul->src[0] == node ? mul->src[1] : mul->src[0];
+            ggml_vk_rms_norm(ctx, compute_ctx, src0, other_src, mul, dryrun);
+        } else {
+            ggml_vk_rms_norm(ctx, compute_ctx, src0, src0, node, dryrun);
+        }
         break;
     case GGML_OP_RMS_NORM_BACK:
         ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -9710,10 +9724,15 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
 
     uint64_t total_mat_mul_bytes = 0;
     for (int i = 0; i < cgraph->n_nodes; i++) {
-        ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false, false);
+        if (ggml_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
+            ctx->num_additional_fused_ops = 1;
+        }
+        ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
         if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
             total_mat_mul_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
         }
+        i += ctx->num_additional_fused_ops;
+        ctx->num_additional_fused_ops = 0;
     }
     if (ctx->device->need_compiles) {
         ggml_vk_load_shaders(ctx->device);
@@ -9775,14 +9794,18 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
             mul_mat_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
         }
 
+        if (ggml_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
+            ctx->num_additional_fused_ops = 1;
+        }
+
         // Signal the almost_ready fence when the graph is mostly complete (< 20% remaining)
         bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5;
         bool submit = (submitted_nodes >= nodes_per_submit) ||
                       (mul_mat_bytes >= mul_mat_bytes_per_submit) ||
-                      (i == last_node) ||
+                      (i + ctx->num_additional_fused_ops == last_node) ||
                       (almost_ready && !ctx->almost_ready_fence_pending);
 
-        bool enqueued = ggml_vk_build_graph(ctx, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i == last_node, almost_ready, submit);
+        bool enqueued = ggml_vk_build_graph(ctx, cgraph, i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i + ctx->num_additional_fused_ops == last_node, almost_ready, submit);
 
         if (vk_perf_logger_enabled) {
             if (ctx->compute_ctx.expired()) {
@@ -9792,7 +9815,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
             } else {
                 compute_ctx = ctx->compute_ctx.lock();
             }
-            compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, i+1);
+            // If there are fused ops, just write out timestamps for all nodes to keep the accounting simple
+            for (int j = 0; j < ctx->num_additional_fused_ops + 1; ++j) {
+                compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, i+j+1);
+            }
         }
 
         if (enqueued) {
@@ -9814,6 +9840,8 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
             }
             submit_count++;
         }
+        i += ctx->num_additional_fused_ops;
+        ctx->num_additional_fused_ops = 0;
     }
 
     if (vk_perf_logger_enabled) {
index deb8ee9960f58240fb5ea33b06809ffb2f4632e9..6428ca7ba330098cc916e1b74733b22608edf634 100644 (file)
@@ -1,11 +1,13 @@
 #version 450
 
-#include "generic_unary_head.comp"
+#include "generic_binary_head.comp"
 #include "types.comp"
 
 #extension GL_EXT_control_flow_attributes : enable
 #define BLOCK_SIZE 512
 
+layout (constant_id = 1) const bool do_multiply = false;
+
 layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
 
 shared FLOAT_TYPE sum[BLOCK_SIZE];
@@ -25,6 +27,7 @@ void main() {
     const uint stride_sample    = p.nb03;
 
     uint32_t a_offset = samp*stride_sample + channel*stride_channel + row*stride_row + get_aoffset();
+    uint32_t b_offset = src1_idx(0, row, channel, samp) + get_boffset();
     uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset();
 
     sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
@@ -46,7 +49,13 @@ void main() {
     const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(ncols);
     const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1));
 
-    [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
-        data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]));
+    if (do_multiply) {
+        [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
+            data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col]));
+        }
+    } else {
+        [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
+            data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]));
+        }
     }
 }
index c63345ec8b4b673ba581f8ad1c7cfe64900e9df0..a207b98c60e510ef55c1a448bf4fe888bb03e90c 100644 (file)
@@ -497,7 +497,7 @@ void process_shaders() {
     // Norms
     string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
     string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
-    string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
+    string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
     string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
     string_to_spv("l2_norm_f32", "l2_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
 
index be8a9c4488b14648640b7e8abd02529ea705ad94..e70d9d57b3bbea428a39f3a7afc85579e782cd8a 100644 (file)
@@ -5850,19 +5850,32 @@ static void ggml_compute_backward(
     GGML_ASSERT(!src2_needs_grads || ggml_are_same_shape(src2, cgraph->grads[isrc2]));
 }
 
-static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) {
+static size_t ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) {
     // check if already visited
-    if (ggml_hash_insert(&cgraph->visited_hash_set, node) == GGML_HASHSET_ALREADY_EXISTS) {
-        return;
+    size_t node_hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node);
+    GGML_ASSERT(node_hash_pos != GGML_HASHSET_FULL);
+    if (!ggml_bitset_get(cgraph->visited_hash_set.used, node_hash_pos)) {
+        // This is the first time we see this node in the current graph.
+        cgraph->visited_hash_set.keys[node_hash_pos] = node;
+        ggml_bitset_set(cgraph->visited_hash_set.used, node_hash_pos);
+        cgraph->use_counts[node_hash_pos] = 0;
+    } else {
+        // already visited
+        return node_hash_pos;
     }
 
     for (int i = 0; i < GGML_MAX_SRC; ++i) {
         const int k =
             (cgraph->order == GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT) ? i :
             (cgraph->order == GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT) ? (GGML_MAX_SRC-1-i) :
-            /* unknown order, just fall back to using i*/ i;
-        if (node->src[k]) {
-            ggml_visit_parents(cgraph, node->src[k]);
+            /* unknown order, just fall back to using i */ i;
+
+        struct ggml_tensor * src = node->src[k];
+        if (src) {
+            size_t src_hash_pos = ggml_visit_parents(cgraph, src);
+
+            // Update the use count for this operand.
+            cgraph->use_counts[src_hash_pos]++;
         }
     }
 
@@ -5886,6 +5899,8 @@ static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor *
         cgraph->nodes[cgraph->n_nodes] = node;
         cgraph->n_nodes++;
     }
+
+    return node_hash_pos;
 }
 
 static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor, bool expand) {
@@ -6023,6 +6038,7 @@ static size_t ggml_graph_nbytes(size_t size, bool grads) {
     incr_ptr_aligned(&p, sizeof(struct ggml_cgraph), 1);
     incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // nodes
     incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // leafs
+    incr_ptr_aligned(&p, hash_size * sizeof(int32_t), sizeof(int32_t)); // use_counts
     incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // hash keys
     if (grads) {
         incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // grads
@@ -6052,11 +6068,12 @@ struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t siz
 
     void * p = cgraph + 1;
 
-    struct ggml_tensor ** nodes_ptr     =         incr_ptr_aligned(&p, size      * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
-    struct ggml_tensor ** leafs_ptr     =         incr_ptr_aligned(&p, size      * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
-    struct ggml_tensor ** hash_keys_ptr =         incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
-    struct ggml_tensor ** grads_ptr     = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL;
-    struct ggml_tensor ** grad_accs_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL;
+    struct ggml_tensor ** nodes_ptr      =         incr_ptr_aligned(&p, size      * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
+    struct ggml_tensor ** leafs_ptr      =         incr_ptr_aligned(&p, size      * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
+    int32_t             * use_counts_ptr =         incr_ptr_aligned(&p, hash_size * sizeof(int32_t), sizeof(int32_t));
+    struct ggml_tensor ** hash_keys_ptr  =         incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
+    struct ggml_tensor ** grads_ptr      = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL;
+    struct ggml_tensor ** grad_accs_ptr  = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL;
 
     ggml_bitset_t * hash_used = incr_ptr_aligned(&p, ggml_bitset_size(hash_size) * sizeof(ggml_bitset_t), sizeof(ggml_bitset_t));
 
@@ -6071,6 +6088,7 @@ struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t siz
         /*.grads        =*/ grads_ptr,
         /*.grad_accs    =*/ grad_accs_ptr,
         /*.leafs        =*/ leafs_ptr,
+        /*.use_counts   =*/ use_counts_ptr,
         /*.hash_table   =*/ { hash_size, hash_used, hash_keys_ptr },
         /*.order        =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT,
     };
@@ -6097,7 +6115,8 @@ struct ggml_cgraph ggml_graph_view(struct ggml_cgraph * cgraph0, int i0, int i1)
         /*.grads            =*/ NULL, // gradients would need visited_hash_set
         /*.grad_accs        =*/ NULL,
         /*.leafs            =*/ NULL,
-        /*.visited_hash_set =*/ { 0, NULL, NULL },
+        /*.use_counts       =*/ cgraph0->use_counts,
+        /*.visited_hash_set =*/ cgraph0->visited_hash_set,
         /*.order            =*/ cgraph0->order,
     };
 
@@ -6124,7 +6143,8 @@ void ggml_graph_cpy(struct ggml_cgraph * src, struct ggml_cgraph * dst) {
     for (size_t i = 0; i < src->visited_hash_set.size; ++i) {
         // copy all hashset keys (tensors) that are in use
         if (ggml_bitset_get(src->visited_hash_set.used, i)) {
-            ggml_hash_insert(&dst->visited_hash_set, src->visited_hash_set.keys[i]);
+            size_t new_hash_pos = ggml_hash_insert(&dst->visited_hash_set, src->visited_hash_set.keys[i]);
+            dst->use_counts[new_hash_pos] = src->use_counts[i];
         }
     }