]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
metal : allow ops to run concurrently (llama/15929)
authorGeorgi Gerganov <redacted>
Sat, 13 Sep 2025 10:54:28 +0000 (13:54 +0300)
committerGeorgi Gerganov <redacted>
Sat, 20 Sep 2025 10:33:50 +0000 (13:33 +0300)
* metal : run graphs ops concurrently

ggml-ci

* cont : add flags for debugging and disabling concurrency

ggml-ci

* cont : refactor and handle fusing

ggml-ci

* cont : simplify - no need to use GPU address

ggml-ci

* cont : prepare mem ranges for reuse + add ggml-metal-common.cpp

ggml-ci

* cont : avoid redundant keywords in cpp [no ci]

* metal : reorder graph for better concurrency

ggml-ci

* metal : fix race on mem pool buffers

ggml-ci

* cont : add env GGML_METAL_GRAPH_OPTIMIZE_DISABLE

ggml-ci

* cont : refactor, optimize, add comments

ggml-ci

* cont : refactor ggml-metal.m

ggml-ci

* minor : update logs [no ci]

src/ggml-metal/CMakeLists.txt
src/ggml-metal/ggml-metal-common.cpp [new file with mode: 0644]
src/ggml-metal/ggml-metal-common.h [new file with mode: 0644]
src/ggml-metal/ggml-metal.m

index 0ca8a3c55ec4424f6d433d91f351bda398f9c3c9..65c131b6216874581dca766b2a03796c57b28907 100644 (file)
@@ -6,6 +6,7 @@ message(STATUS "Metal framework found")
 
 ggml_add_backend_library(ggml-metal
                          ggml-metal.m
+                         ggml-metal-common.cpp
                         )
 
 target_link_libraries(ggml-metal PRIVATE
diff --git a/src/ggml-metal/ggml-metal-common.cpp b/src/ggml-metal/ggml-metal-common.cpp
new file mode 100644 (file)
index 0000000..6a869ff
--- /dev/null
@@ -0,0 +1,445 @@
+#include "ggml-metal-common.h"
+
+#include "ggml-impl.h"
+
+#include <vector>
+
+struct ggml_mem_range {
+    uint64_t pb; // buffer id
+
+    uint64_t p0; // begin
+    uint64_t p1; // end
+
+    ggml_mem_range_type pt;
+};
+
+struct ggml_mem_ranges {
+    std::vector<ggml_mem_range> ranges;
+
+    int debug = 0;
+};
+
+struct ggml_mem_ranges * ggml_mem_ranges_init(int debug) {
+    auto * res = new ggml_mem_ranges;
+
+    res->ranges.reserve(256);
+    res->debug = debug;
+
+    return res;
+}
+
+void ggml_mem_ranges_free(ggml_mem_ranges * mrs) {
+    delete mrs;
+}
+
+void ggml_mem_ranges_reset(ggml_mem_ranges * mrs) {
+    mrs->ranges.clear();
+}
+
+static bool ggml_mem_ranges_add(ggml_mem_ranges * mrs, ggml_mem_range mrp) {
+    mrs->ranges.push_back(mrp);
+
+    return true;
+}
+
+static ggml_mem_range ggml_mem_range_from_tensor(const ggml_tensor * tensor, ggml_mem_range_type pt) {
+    // always use the base tensor
+    tensor = tensor->view_src ? tensor->view_src : tensor;
+
+    GGML_ASSERT(!tensor->view_src);
+
+    ggml_mem_range mrp;
+
+    if (tensor->buffer) {
+        // when the tensor is allocated, use the actual memory address range of the buffer
+        mrp = {
+            /*.pb =*/ (uint64_t) tensor->buffer,
+            /*.p0 =*/ (uint64_t) tensor->data,
+            /*.p1 =*/ (uint64_t) tensor->data + ggml_nbytes(tensor),
+            /*.pt =*/ pt,
+        };
+    } else {
+        // otherwise, the tensor ptr is used as an unique id of the memory ranges
+        //   that the tensor will be using when it is allocated
+        mrp = {
+            /*.pb =*/ (uint64_t) tensor,
+            /*.p0 =*/ 0,    //
+            /*.p1 =*/ 1024, // [0, 1024) is a dummy range, not used
+            /*.pt =*/ pt,
+        };
+    };
+
+    return mrp;
+}
+
+static ggml_mem_range ggml_mem_range_from_tensor_src(const ggml_tensor * tensor) {
+    return ggml_mem_range_from_tensor(tensor, MEM_RANGE_TYPE_SRC);
+}
+
+static ggml_mem_range ggml_mem_range_from_tensor_dst(const ggml_tensor * tensor) {
+    return ggml_mem_range_from_tensor(tensor, MEM_RANGE_TYPE_DST);
+}
+
+static bool ggml_mem_ranges_add_src(ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
+    GGML_ASSERT(tensor);
+
+    ggml_mem_range mrp = ggml_mem_range_from_tensor_src(tensor);
+
+    if (mrs->debug > 2) {
+        GGML_LOG_DEBUG("%s: add src range buf=%lld, [%lld, %lld)\n", __func__, mrp.pb, mrp.p0, mrp.p1);
+    }
+
+    return ggml_mem_ranges_add(mrs, mrp);
+}
+
+static bool ggml_mem_ranges_add_dst(ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
+    GGML_ASSERT(tensor);
+
+    ggml_mem_range mrp = ggml_mem_range_from_tensor_dst(tensor);
+
+    if (mrs->debug > 2) {
+        GGML_LOG_DEBUG("%s: add dst range buf=%lld, [%lld, %lld)\n", __func__, mrp.pb, mrp.p0, mrp.p1);
+    }
+
+    return ggml_mem_ranges_add(mrs, mrp);
+}
+
+bool ggml_mem_ranges_add(ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
+    for (int i = 0; i < GGML_MAX_DIMS; i++) {
+        if (tensor->src[i]) {
+            ggml_mem_ranges_add_src(mrs, tensor->src[i]);
+        }
+    }
+
+    return ggml_mem_ranges_add_dst(mrs, tensor);
+}
+
+static bool ggml_mem_ranges_check(const ggml_mem_ranges * mrs, ggml_mem_range mrp) {
+    for (size_t i = 0; i < mrs->ranges.size(); i++) {
+        const auto & cmp = mrs->ranges[i];
+
+        if (mrp.pb != cmp.pb) {
+            continue;
+        }
+
+        if (mrp.pt == MEM_RANGE_TYPE_SRC && cmp.pt == MEM_RANGE_TYPE_SRC) {
+            continue;
+        }
+
+        if (mrp.p0 < cmp.p1 && mrp.p1 >= cmp.p0) {
+            if (mrs->debug > 2) {
+                GGML_LOG_DEBUG("%s: the %s range buf=%lld, [%lld, %lld) overlaps with a previous %s range buf=%lld, [%lld, %lld)\n",
+                        __func__,
+                        mrp.pt == MEM_RANGE_TYPE_SRC ? "src" : "dst",
+                        mrp.pb, mrp.p0, mrp.p1,
+                        cmp.pt == MEM_RANGE_TYPE_SRC ? "src" : "dst",
+                        cmp.pb, cmp.p0, cmp.p1);
+            }
+
+            return false;
+        }
+    }
+
+    return true;
+}
+
+static bool ggml_mem_ranges_check_src(const ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
+    GGML_ASSERT(tensor);
+
+    ggml_mem_range mrp = ggml_mem_range_from_tensor_src(tensor);
+
+    const bool res = ggml_mem_ranges_check(mrs, mrp);
+
+    return res;
+}
+
+static bool ggml_mem_ranges_check_dst(const ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
+    GGML_ASSERT(tensor);
+
+    ggml_mem_range mrp = ggml_mem_range_from_tensor_dst(tensor);
+
+    const bool res = ggml_mem_ranges_check(mrs, mrp);
+
+    return res;
+}
+
+bool ggml_mem_ranges_check(const ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
+    for (int i = 0; i < GGML_MAX_DIMS; i++) {
+        if (tensor->src[i]) {
+            if (!ggml_mem_ranges_check_src(mrs, tensor->src[i])) {
+                return false;
+            }
+        }
+    }
+
+    return ggml_mem_ranges_check_dst(mrs, tensor);
+}
+
+// TODO: move to ggml.h?
+static bool is_empty(ggml_op op) {
+    switch (op) {
+        case GGML_OP_NONE:
+        case GGML_OP_RESHAPE:
+        case GGML_OP_TRANSPOSE:
+        case GGML_OP_VIEW:
+        case GGML_OP_PERMUTE:
+            return true;
+        default:
+            return false;
+    }
+}
+
+struct node_info {
+    ggml_tensor * node;
+
+    std::vector<ggml_tensor *> fused;
+
+    ggml_op op() const {
+        return node->op;
+    }
+
+    const ggml_tensor * dst() const {
+        return fused.empty() ? node : fused.back();
+    }
+
+    bool is_empty() const {
+        return ::is_empty(node->op);
+    }
+
+    void add_fused(ggml_tensor * t) {
+        fused.push_back(t);
+    }
+};
+
+static std::vector<int> ggml_metal_graph_optimize_reorder(const std::vector<node_info> & nodes) {
+    // helper to add node src and dst ranges
+    const auto & h_add = [](ggml_mem_ranges * mrs, const node_info & node) {
+        for (int i = 0; i < GGML_MAX_SRC; i++) {
+            if (node.node->src[i]) {
+                if (!ggml_mem_ranges_add_src(mrs, node.node->src[i])) {
+                    return false;
+                }
+            }
+        }
+
+        for (const auto * fused : node.fused) {
+            for (int i = 0; i < GGML_MAX_SRC; i++) {
+                if (fused->src[i]) {
+                    if (!ggml_mem_ranges_add_src(mrs, fused->src[i])) {
+                        return false;
+                    }
+                }
+            }
+        }
+
+        return ggml_mem_ranges_add_dst(mrs, node.dst());
+    };
+
+    // helper to check if a node can run concurrently with the existing set of nodes
+    const auto & h_check = [](const ggml_mem_ranges * mrs, const node_info & node) {
+        for (int i = 0; i < GGML_MAX_SRC; i++) {
+            if (node.node->src[i]) {
+                if (!ggml_mem_ranges_check_src(mrs, node.node->src[i])) {
+                    return false;
+                }
+            }
+        }
+
+        for (const auto * fused : node.fused) {
+            for (int i = 0; i < GGML_MAX_SRC; i++) {
+                if (fused->src[i]) {
+                    if (!ggml_mem_ranges_check_src(mrs, fused->src[i])) {
+                        return false;
+                    }
+                }
+            }
+        }
+
+        return ggml_mem_ranges_check_dst(mrs, node.dst());
+    };
+
+    // perform reorders only across these types of ops
+    // can be expanded when needed
+    // IMPORTANT: do not add ops such as GGML_OP_CPY or GGML_OP_SET_ROWS
+    //            the dependencies from such ops are not always represented in the graph
+    const auto & h_safe = [](ggml_op op) {
+        switch (op) {
+            case GGML_OP_MUL_MAT:
+            case GGML_OP_MUL_MAT_ID:
+            case GGML_OP_ROPE:
+            case GGML_OP_NORM:
+            case GGML_OP_RMS_NORM:
+            case GGML_OP_GROUP_NORM:
+            case GGML_OP_SUM_ROWS:
+            case GGML_OP_MUL:
+            case GGML_OP_ADD:
+            case GGML_OP_DIV:
+            case GGML_OP_GLU:
+            case GGML_OP_SCALE:
+            case GGML_OP_GET_ROWS:
+                return true;
+            default:
+                return is_empty(op);
+        }
+    };
+
+    const int n = nodes.size();
+
+    std::vector<int> res;
+    res.reserve(n);
+
+    std::vector<bool> used(n, false);
+
+    ggml_mem_ranges * mrs0 = ggml_mem_ranges_init(0);
+    ggml_mem_ranges * mrs1 = ggml_mem_ranges_init(0);
+
+    for (int i0 = 0; i0 < n; i0++) {
+        if (used[i0]) {
+            continue;
+        }
+
+        const auto & node0 = nodes[i0];
+
+        // the node is not concurrent with the existing concurrent set, so we have to "put a barrier" (i.e reset mrs0)
+        // but before we do that, look forward for some other nodes that can be added to the concurrent set mrs0
+        //
+        // note: we can always add empty nodes to the concurrent set as they don't read nor write anything
+        if (!node0.is_empty() && !h_check(mrs0, node0)) {
+            // this will hold the set of memory ranges from the nodes that haven't been processed yet
+            // if a node is not concurrent with this set, we cannot reorder it
+            ggml_mem_ranges_reset(mrs1);
+
+            // initialize it with the current node
+            h_add(mrs1, node0);
+
+            // that many nodes forward to search for a concurrent node
+            constexpr int N_FORWARD = 8;
+
+            for (int i1 = i0 + 1; i1 < i0 + N_FORWARD && i1 < n; i1++) {
+                if (used[i1]) {
+                    continue;
+                }
+
+                const auto & node1 = nodes[i1];
+
+                // disallow reordering of certain ops
+                if (!h_safe(node1.op())) {
+                    break;
+                }
+
+                const bool is_empty = node1.is_empty();
+
+                // to add a concurrent node, it has to be:
+                //   + empty or concurrent with all nodes in the existing concurrent set (mrs0)
+                //   + concurrent with all nodes prior to it that haven't been processed yet (mrs1)
+                if ((is_empty || h_check(mrs0, node1)) && h_check(mrs1, node1)) {
+                    // add the node to the existing concurrent set (i.e. reorder it for early execution)
+                    h_add(mrs0, node1);
+                    res.push_back(i1);
+
+                    // mark as used, so we skip re-processing it later
+                    used[i1] = true;
+                } else {
+                    // expand the set of nodes that haven't been processed yet
+                    h_add(mrs1, node1);
+                }
+            }
+
+            // finalize the concurrent set and begin a new one
+            ggml_mem_ranges_reset(mrs0);
+        }
+
+        // expand the concurrent set with the current node
+        {
+            h_add(mrs0, node0);
+            res.push_back(i0);
+        }
+    }
+
+    ggml_mem_ranges_free(mrs0);
+    ggml_mem_ranges_free(mrs1);
+
+    return res;
+}
+
+void ggml_metal_graph_optimize(ggml_cgraph * gf) {
+    constexpr int MAX_FUSE = 16;
+
+    const int n = gf->n_nodes;
+
+    enum ggml_op ops[MAX_FUSE];
+
+    std::vector<node_info> nodes;
+    nodes.reserve(gf->n_nodes);
+
+    // fuse nodes:
+    // we don't want to make reorders that break fusing, so we first pack all fusable tensors
+    //   and perform the reorder over the fused nodes. after the reorder is done, we unfuse
+    for (int i = 0; i < n; i++) {
+        node_info node = {
+            /*.node =*/ gf->nodes[i],
+            /*.fused =*/ {},
+        };
+
+        // fuse only ops that start with these operations
+        // can be expanded when needed
+        if (node.op() == GGML_OP_ADD ||
+            node.op() == GGML_OP_RMS_NORM) {
+            ops[0] = node.op();
+
+            int f = i + 1;
+            while (f < n && f < i + MAX_FUSE) {
+                // conservatively allow fusing only these ops
+                // can be expanded when needed
+                if (gf->nodes[f]->op != GGML_OP_ADD &&
+                    gf->nodes[f]->op != GGML_OP_MUL &&
+                    gf->nodes[f]->op != GGML_OP_RMS_NORM) {
+                    break;
+                }
+                ops[f - i] = gf->nodes[f]->op;
+                f++;
+            }
+
+            f -= i;
+            for (; f > 1; f--) {
+                if (ggml_can_fuse(gf, i, ops, f)) {
+                    break;
+                }
+            }
+
+            // add the fused tensors into the node info so we can unfuse them later
+            for (int k = 1; k < f; k++) {
+                ++i;
+
+                // the .dst() becomes the last fused tensor
+                node.add_fused(gf->nodes[i]);
+            }
+        }
+
+        nodes.push_back(std::move(node));
+    }
+
+    // reorder to improve concurrency
+#if 1
+    const auto order = ggml_metal_graph_optimize_reorder(nodes);
+#else
+    std::vector<int> order(nodes.size());
+    for (size_t i = 0; i < nodes.size(); i++) {
+        order[i] = i;
+    }
+#endif
+
+    // unfuse
+    {
+        int j = 0;
+        for (const auto i : order) {
+            const auto & node = nodes[i];
+
+            gf->nodes[j++] = node.node;
+
+            for (auto * fused : node.fused) {
+                gf->nodes[j++] = fused;
+            }
+        }
+    }
+}
diff --git a/src/ggml-metal/ggml-metal-common.h b/src/ggml-metal/ggml-metal-common.h
new file mode 100644 (file)
index 0000000..c140289
--- /dev/null
@@ -0,0 +1,52 @@
+// helper functions for ggml-metal that are too difficult to implement in Objective-C
+
+#pragma once
+
+#include <stdbool.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+struct ggml_tensor;
+struct ggml_cgraph;
+
+enum ggml_mem_range_type {
+    MEM_RANGE_TYPE_SRC = 0,
+    MEM_RANGE_TYPE_DST = 1,
+};
+
+// a helper object that can be used for reordering operations to improve concurrency
+//
+// the fundamental idea is that a set of tasks (either ggml ops, or something else) can run concurrently if they
+//   don't write to a memory that is being read by another task or written to by another task in the set
+//
+// with this structure, we can add tasks to the set, setting memory constraints. we can also check if a new task
+//   can be added to the set without violating the constraints (i.e. if it can be executed concurrently with the
+//   tasks already in the set)
+//
+struct ggml_mem_ranges;
+
+struct ggml_mem_ranges * ggml_mem_ranges_init(int debug);
+void ggml_mem_ranges_free(struct ggml_mem_ranges * mrs);
+
+// remove all ranges from the set
+void ggml_mem_ranges_reset(struct ggml_mem_ranges * mrs);
+
+// add src or dst ranges to track
+bool ggml_mem_ranges_add(struct ggml_mem_ranges * mrs, const struct ggml_tensor * tensor);
+
+// return false if:
+// - new src range overlaps with any existing dst range
+// - new dst range overlaps with any existing range (src or dst)
+bool ggml_mem_ranges_check(const struct ggml_mem_ranges * mrs, const struct ggml_tensor * tensor);
+
+// reorder the nodes in the graph to improve concurrency, while respecting fusion
+//
+// note: this implementation is generic and not specific to metal
+//       if it proves to work well, we can start using it for other backends in the future
+void ggml_metal_graph_optimize(struct ggml_cgraph * gf);
+
+#ifdef __cplusplus
+}
+#endif
index 6a42b3d7bc165c18285af9c9a81688998282892e..b8e06aa6f965eb1a40e827b967c0157c2072660e 100644 (file)
@@ -3,6 +3,7 @@
 #import "ggml-impl.h"
 #import "ggml-backend-impl.h"
 #import "ggml-metal-impl.h"
+#import "ggml-metal-common.h"
 
 #import <Foundation/Foundation.h>
 
@@ -61,8 +62,11 @@ static struct ggml_backend_metal_device_context {
     bool has_bfloat;
     bool use_bfloat;
     bool use_fusion;
+    bool use_concurrency;
     bool use_shared_buffers;
+    bool use_graph_optimize;
 
+    int debug_graph;
     int debug_fusion;
 
     // how many times a given op was fused
@@ -83,7 +87,10 @@ static struct ggml_backend_metal_device_context {
     /*.has_bfloat              =*/ false,
     /*.use_bfloat              =*/ false,
     /*.use_fusion              =*/ true,
+    /*.use_concurrency         =*/ true,
     /*.use_shared_buffers      =*/ true,
+    /*.use_graph_optimize      =*/ true,
+    /*.debug_graph             =*/ 0,
     /*.debug_fusion            =*/ 0,
     /*.fuse_cnt                =*/ { 0 },
     /*.max_size                =*/ 0,
@@ -124,7 +131,14 @@ static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_dev
 #else
             ctx->use_bfloat = false;
 #endif
-            ctx->use_fusion = getenv("GGML_METAL_FUSION_DISABLE") == nil;
+
+            ctx->use_fusion      = getenv("GGML_METAL_FUSION_DISABLE") == nil;
+            ctx->use_concurrency = getenv("GGML_METAL_CONCURRENCY_DISABLE") == nil;
+
+            {
+                const char * val = getenv("GGML_METAL_GRAPH_DEBUG");
+                ctx->debug_graph = val ? atoi(val) : 0;
+            }
 
             {
                 const char * val = getenv("GGML_METAL_FUSION_DEBUG");
@@ -137,6 +151,12 @@ static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_dev
                 ctx->use_shared_buffers = false;
             }
 
+            ctx->use_graph_optimize = true;
+
+            if (getenv("GGML_METAL_GRAPH_OPTIMIZE_DISABLE") != NULL) {
+                ctx->use_graph_optimize = false;
+            }
+
             memset(ctx->fuse_cnt, 0, sizeof(ctx->fuse_cnt));
 
             ctx->max_size = ctx->mtl_device.maxBufferLength;
@@ -628,7 +648,7 @@ static void ggml_metal_heap_free(struct ggml_metal_heap * heap) {
 @end
 
 //
-// ggml_metal_mem_pool
+// ggml_metal_mem_pool [TAG_MEM_POOL_REMOVE]
 //
 
 struct ggml_metal_mem_pool {
@@ -791,6 +811,9 @@ struct ggml_metal_command_buffer {
 
     // each command buffer has a memory pool from which it can allocate temporary buffers during the compute
     struct ggml_metal_mem_pool * mem_pool;
+
+    // used to enable concurrent execution of ops in the command buffers
+    struct ggml_mem_ranges * mem_ranges;
 };
 
 struct ggml_backend_metal_context {
@@ -1091,7 +1114,9 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
     GGML_LOG_INFO("%s: has bfloat            = %s\n", __func__, ctx_dev->has_bfloat                  ? "true" : "false");
     GGML_LOG_INFO("%s: use bfloat            = %s\n", __func__, ctx_dev->use_bfloat                  ? "true" : "false");
     GGML_LOG_INFO("%s: use fusion            = %s\n", __func__, ctx_dev->use_fusion                  ? "true" : "false");
+    GGML_LOG_INFO("%s: use concurrency       = %s\n", __func__, ctx_dev->use_concurrency             ? "true" : "false");
     GGML_LOG_INFO("%s: use shared buffers    = %s\n", __func__, ctx_dev->use_shared_buffers          ? "true" : "false");
+    GGML_LOG_INFO("%s: use graph optimize    = %s\n", __func__, ctx_dev->use_graph_optimize          ? "true" : "false");
     GGML_LOG_INFO("%s: hasUnifiedMemory      = %s\n", __func__, ctx_dev->mtl_device.hasUnifiedMemory ? "true" : "false");
 
     ctx->capture_next_compute = false;
@@ -1105,6 +1130,10 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
 
         ctx->cmd_bufs[i].mem_pool = ggml_metal_mem_pool_init();
         ctx->cmd_bufs[i].mem_pool->device = device;
+
+        if (ctx_dev->use_concurrency) {
+            ctx->cmd_bufs[i].mem_ranges = ggml_mem_ranges_init(ctx_dev->debug_graph);
+        }
     }
 
     ctx->cmd_bufs_ext = [[NSMutableArray alloc] init];
@@ -1715,6 +1744,10 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
         }
 
         ggml_metal_mem_pool_free(ctx->cmd_bufs[i].mem_pool);
+
+        if (ctx->cmd_bufs[i].mem_ranges) {
+            ggml_mem_ranges_free(ctx->cmd_bufs[i].mem_ranges);
+        }
     }
 
     [ctx->cmd_bufs_ext removeAllObjects];
@@ -2071,12 +2104,51 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
     }
 }
 
-static int ggml_metal_encode_node(
-                        ggml_backend_t   backend,
-                                   int   idx,
-                                   int   idx_end,
-          id<MTLComputeCommandEncoder>   encoder,
-            struct ggml_metal_mem_pool * mem_pool) {
+struct ggml_metal_encode_context {
+    ggml_backend_t backend;
+
+    id<MTLComputeCommandEncoder> encoder;
+
+    struct ggml_metal_mem_pool * mem_pool;
+
+    struct ggml_mem_ranges * mem_ranges;
+};
+
+static bool ggml_metal_encode_concurrency_reset(struct ggml_metal_encode_context * ctx) {
+    if (!ctx->mem_ranges) {
+        return true;
+    }
+
+    [ctx->encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
+
+    ggml_mem_ranges_reset(ctx->mem_ranges);
+
+    return true;
+}
+
+static bool ggml_metal_encode_concurrency_check(struct ggml_metal_encode_context * ctx, const struct ggml_tensor * node) {
+    if (!ctx->mem_ranges) {
+        return false;
+    }
+
+    return ggml_mem_ranges_check(ctx->mem_ranges, node);
+}
+
+static bool ggml_metal_encode_concurrency_add(struct ggml_metal_encode_context * ctx, const struct ggml_tensor * node) {
+    if (!ctx->mem_ranges) {
+        return true;
+    }
+
+    return ggml_mem_ranges_add(ctx->mem_ranges, node);
+}
+
+static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, int idx, int idx_end) {
+    ggml_backend_t backend = ctx_enc->backend;
+
+    id<MTLComputeCommandEncoder> encoder = ctx_enc->encoder;
+
+    struct ggml_metal_mem_pool * mem_pool = ctx_enc->mem_pool;
+
     struct ggml_backend_metal_context        * ctx     = backend->context;
     struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
 
@@ -2159,38 +2231,71 @@ static int ggml_metal_encode_node(
     const uint64_t nb2  =  dst ?  dst->nb[2] : 0;
     const uint64_t nb3  =  dst ?  dst->nb[3] : 0;
 
+    size_t offs_src[GGML_MAX_SRC];
+
+    id<MTLBuffer> id_src[GGML_MAX_SRC];
+
+    enum ggml_type srct[GGML_MAX_SRC];
+
+    for (int i = 0; i < GGML_MAX_SRC; i++) {
+        offs_src[i] = 0;
+        id_src[i] = node->src[i] ? ggml_metal_get_buffer(node->src[i], &offs_src[i]) : nil;
+        srct[i]   = node->src[i] ? node->src[i]->type : GGML_TYPE_COUNT;
+    }
+
+    // TODO: tmp shorthands - remove
+    size_t offs_src0 = offs_src[0];
+    size_t offs_src1 = offs_src[1];
+    size_t offs_src2 = offs_src[2];
+
+    id<MTLBuffer> id_src0 = id_src[0];
+    id<MTLBuffer> id_src1 = id_src[1];
+    id<MTLBuffer> id_src2 = id_src[2];
+
     const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
     const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
     const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT;
     const enum ggml_type dstt  = dst  ? dst->type  : GGML_TYPE_COUNT;
 
-    size_t offs_src0 = 0;
-    size_t offs_src1 = 0;
-    size_t offs_src2 = 0;
-    size_t offs_dst  = 0;
+    size_t offs_dst = 0;
 
-    id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(src0, &offs_src0) : nil;
-    id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(src1, &offs_src1) : nil;
-    id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil;
-    id<MTLBuffer> id_dst  = dst  ? ggml_metal_get_buffer(dst,  &offs_dst)  : nil;
+    id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil;
 
     int n_fuse = 1;
 
-#if 0
-    GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
-    if (src0) {
-        GGML_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03,
-                ggml_is_contiguous(src0), src0->name);
-    }
-    if (src1) {
-        GGML_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
-                ggml_is_contiguous(src1), src1->name);
-    }
-    if (dst) {
-        GGML_LOG_INFO("%s: dst  - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
-                dst->name);
+    // check if the current node can run concurrently with other nodes before it
+    // the condition is that:
+    //  - the current node cannot write to any previous src or dst ranges
+    //  - the current node cannot read from any previous dst ranges
+    //
+    // if the condition is not satisfied, we put a memory barrier and clear all ranges
+    // otherwise, we add the new ranges to the encoding context and process the node concurrently
+    //
+    {
+        const bool is_concurrent = ggml_metal_encode_concurrency_check(ctx_enc, node);
+
+        if (!is_concurrent) {
+            ggml_metal_encode_concurrency_reset(ctx_enc);
+        }
+
+        if (ctx_dev->debug_graph > 0) {
+            GGML_LOG_DEBUG("%s: node[%5d] - %-12s %s\n", __func__, idx, ggml_op_name(dst->op), is_concurrent ? "(concurrent)" : "");
+        }
+        if (ctx_dev->debug_graph > 1) {
+            if (src0) {
+                GGML_LOG_DEBUG("%s: src0 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03,
+                        ggml_is_contiguous(src0), src0->name);
+            }
+            if (src1) {
+                GGML_LOG_DEBUG("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
+                        ggml_is_contiguous(src1), src1->name);
+            }
+            if (dst) {
+                GGML_LOG_DEBUG("%s: dst  - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
+                        dst->name);
+            }
+        }
     }
-#endif
 
     id<MTLDevice> device = ctx_dev->mtl_device;
 
@@ -2389,6 +2494,14 @@ static int ggml_metal_encode_node(
 
                 if (n_fuse > 1) {
                     id_dst = ggml_metal_get_buffer(nodes[n_fuse - 1], &offs_dst);
+
+                    for (int i = 1; i < n_fuse; ++i) {
+                        if (!ggml_metal_encode_concurrency_check(ctx_enc, nodes[i])) {
+                            ggml_metal_encode_concurrency_reset(ctx_enc);
+
+                            break;
+                        }
+                    }
                 }
 
                 [encoder setComputePipelineState:pipeline];
@@ -2533,6 +2646,8 @@ static int ggml_metal_encode_node(
                     const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
 
                     [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+
+                    ggml_metal_encode_concurrency_reset(ctx_enc);
                 }
 
                 const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline;
@@ -3997,6 +4112,12 @@ static int ggml_metal_encode_node(
                         default: break;
                     }
 
+                    // TODO: using mem pool allocations with enabled concurrency is not safe because the mem pool
+                    // reuses buffers. this can result in 2 concurrent MUL_MAT_ID ops using the same mem pool buffer.
+                    // so we add this extra barrier to prevent the race.
+                    // the correct solution is to remove mem pools and then remove this barrier [TAG_MEM_POOL_REMOVE]
+                    ggml_metal_encode_concurrency_reset(ctx_enc);
+
                     // tokens per expert
                     const size_t s_tpe = ggml_type_size(GGML_TYPE_I32)*ne02;
                     id<MTLBuffer> h_tpe = ggml_metal_mem_pool_alloc(mem_pool, s_tpe);
@@ -4057,6 +4178,9 @@ static int ggml_metal_encode_node(
                         [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(ne02, 1, 1)];
                     }
 
+                    // this barrier is always needed because the next kernel has to wait for the id maps to be computed
+                    ggml_metal_encode_concurrency_reset(ctx_enc);
+
                     {
                         id<MTLComputePipelineState> pipeline = nil;
 
@@ -4525,6 +4649,14 @@ static int ggml_metal_encode_node(
 
                 if (n_fuse > 1) {
                     id_dst = ggml_metal_get_buffer(nodes[n_fuse - 1], &offs_dst);
+
+                    for (int i = 1; i < n_fuse; ++i) {
+                        if (!ggml_metal_encode_concurrency_check(ctx_enc, nodes[i])) {
+                            ggml_metal_encode_concurrency_reset(ctx_enc);
+
+                            break;
+                        }
+                    }
                 }
 
                 id<MTLComputePipelineState> pipeline;
@@ -4668,7 +4800,6 @@ static int ggml_metal_encode_node(
             } break;
         case GGML_OP_ROPE:
             {
-
                 // make sure we have one or more position id(ne10) per token(ne02)
                 GGML_ASSERT(ne10 % ne02 == 0);
                 GGML_ASSERT(ne10 >= ne02);
@@ -5427,6 +5558,10 @@ static int ggml_metal_encode_node(
                         GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3);
                         GGML_ASSERT(ne1*ne2*ne3 <= (1u << 31));
 
+                        // using mem pool allocations with enabled concurrency is not safe [TAG_MEM_POOL_REMOVE]
+                        // still, we assume that concurrent FA won't happen before we do the refactor
+                        //ggml_metal_encode_concurrency_reset(ctx_enc);
+
                         const int32_t nrows = ne1*ne2*ne3;
 
                         // temp buffer for writing the results from each workgroup
@@ -5447,6 +5582,8 @@ static int ggml_metal_encode_node(
                         [encoder setThreadgroupMemoryLength:smem atIndex:0];
                         [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
 
+                        ggml_metal_encode_concurrency_reset(ctx_enc);
+
                         // reduce the results from the workgroups
                         {
                             ggml_metal_kargs_flash_attn_ext_vec_reduce args0 = {
@@ -5677,7 +5814,7 @@ static int ggml_metal_encode_node(
 
                 [encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
             } break;
-            case GGML_OP_ARGMAX:
+        case GGML_OP_ARGMAX:
             {
                 GGML_ASSERT(src0->type == GGML_TYPE_F32);
                 GGML_ASSERT(ggml_is_contiguous_1(src0));
@@ -5709,6 +5846,19 @@ static int ggml_metal_encode_node(
             }
     }
 
+    if (ctx_dev->debug_graph > 0) {
+        if (n_fuse > 1) {
+            GGML_LOG_DEBUG("%s:               fuse %d ops\n", __func__, n_fuse);
+        }
+    }
+
+    // update the mem ranges in the encoding context
+    for (int i = 0; i < n_fuse; ++i) {
+        if (!ggml_metal_encode_concurrency_add(ctx_enc, nodes[i])) {
+            ggml_metal_encode_concurrency_reset(ctx_enc);
+        }
+    }
+
     return n_fuse;
 }
 
@@ -5719,7 +5869,7 @@ static enum ggml_status ggml_metal_graph_compute(
     struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
 
     // number of nodes encoded by the main thread (empirically determined)
-    const int n_main = 128;
+    const int n_main = 64;
 
     // number of threads in addition to the main thread
     const int n_cb = ctx->n_cb;
@@ -5774,6 +5924,7 @@ static enum ggml_status ggml_metal_graph_compute(
             // cannot use commandBufferWithUnretainedReferences because the buffers from the memory pool can get destroyed
             // TODO: when the memory pools are removed, we can again use commandBufferWithUnretainedReferences
             //       https://github.com/ggml-org/llama.cpp/pull/15832#discussion_r2334215009
+            // [TAG_MEM_POOL_REMOVE]
             //id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
             id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBuffer];
             [cmd_buf retain];
@@ -6547,6 +6698,18 @@ static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend,
     return ggml_metal_graph_compute(backend, cgraph);
 }
 
+static void ggml_backend_metal_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
+    struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
+
+    //const int64_t t_start = ggml_time_us();
+
+    if (ctx_dev->use_graph_optimize) {
+        ggml_metal_graph_optimize(cgraph);
+    }
+
+    //printf("%s: graph optimize took %.3f ms\n", __func__, (ggml_time_us() - t_start) / 1000.0);
+}
+
 static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
     GGML_ASSERT(ggml_backend_is_metal(backend));
 
@@ -6573,12 +6736,25 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
 
         const int n_nodes_per_cb = ctx->n_nodes_per_cb;
 
-        id<MTLCommandBuffer>         cmd_buf  = ctx->cmd_bufs[cb_idx].obj;
-        struct ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs[cb_idx].mem_pool;
+        id<MTLCommandBuffer>         cmd_buf    = ctx->cmd_bufs[cb_idx].obj;
+        struct ggml_metal_mem_pool * mem_pool   = ctx->cmd_bufs[cb_idx].mem_pool;
+        struct ggml_mem_ranges     * mem_ranges = ctx->cmd_bufs[cb_idx].mem_ranges;
 
         ggml_metal_mem_pool_reset(mem_pool);
 
-        id<MTLComputeCommandEncoder> encoder = [cmd_buf computeCommandEncoder];
+        if (mem_ranges) {
+            ggml_mem_ranges_reset(mem_ranges);
+        }
+
+        id<MTLComputeCommandEncoder> encoder;
+
+        struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
+
+        if (ctx_dev->use_concurrency) {
+            encoder = [cmd_buf computeCommandEncoderWithDispatchType: MTLDispatchTypeConcurrent];
+        } else {
+            encoder = [cmd_buf computeCommandEncoder];
+        }
 
         int node_start = 0;
         int node_end   = n_nodes_0;
@@ -6590,12 +6766,19 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
 
         const bool should_capture = ctx->capture_next_compute;
 
+        struct ggml_metal_encode_context ctx_enc = {
+            /*.backend    =*/ backend,
+            /*.encoder    =*/ encoder,
+            /*.mem_pool   =*/ mem_pool,
+            /*.mem_ranges =*/ mem_ranges,
+        };
+
         for (int idx = node_start; idx < node_end;) {
             if (should_capture) {
                 [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
             }
 
-            const int res = ggml_metal_encode_node(backend, idx, node_end, encoder, mem_pool);
+            const int res = ggml_metal_encode_node(&ctx_enc, idx, node_end);
             if (idx + res > node_end) {
                 GGML_ABORT("fusion error: nodes spanning multiple encoders have been fused. this indicates a bug in the fusion logic %s",
                         "https://github.com/ggml-org/llama.cpp/pull/14849");
@@ -6638,7 +6821,7 @@ static struct ggml_backend_i ggml_backend_metal_i = {
     // https://developer.apple.com/documentation/metal/mtlcommandbuffer#Synchronizing-Passes-with-Events
     /* .event_record            = */ NULL,
     /* .event_wait              = */ NULL,
-    /* .optimize_graph          = */ NULL,
+    /* .optimize_graph          = */ ggml_backend_metal_graph_optimize,
 };
 
 static ggml_guid_t ggml_backend_metal_guid(void) {