]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
metal : remove memory pools (#15966)
authorGeorgi Gerganov <redacted>
Sun, 14 Sep 2025 19:02:32 +0000 (22:02 +0300)
committerGitHub <redacted>
Sun, 14 Sep 2025 19:02:32 +0000 (22:02 +0300)
* metal : remove mem pool usage

ggml-ci

* metal : remove mem pool implementation

ggml-ci

* metal : take into account the actual allocated memory of the tensor

ggml-ci

* cont : use ggml_backend_buft_get_alloc_size

ggml-ci

* cont : improve, comments

ggml-ci

* cont : add functions for the extra tensor sizes

* metal : add comments

ggml-ci

* metal : implement .get_alloc_size for the rest of the buffer types

ggml-ci

* metal : remove ggml_metal_heap

ggml-ci

ggml/src/ggml-metal/ggml-metal-common.cpp
ggml/src/ggml-metal/ggml-metal.m

index 6a869ff24cd8db08409ab001bb5c64bebc1cbd19..cb39e5b2ab5bbf1636b5b1b20d2d18d92096b31c 100644 (file)
@@ -1,9 +1,12 @@
 #include "ggml-metal-common.h"
 
 #include "ggml-impl.h"
+#include "ggml-backend-impl.h"
 
 #include <vector>
 
+// represents a memory range (i.e. an interval from a starting address p0 to an ending address p1 in a given buffer pb)
+// the type indicates whether it is a source range (i.e. ops read data from it) or a destination range (i.e. ops write data to it)
 struct ggml_mem_range {
     uint64_t pb; // buffer id
 
@@ -36,8 +39,8 @@ void ggml_mem_ranges_reset(ggml_mem_ranges * mrs) {
     mrs->ranges.clear();
 }
 
-static bool ggml_mem_ranges_add(ggml_mem_ranges * mrs, ggml_mem_range mrp) {
-    mrs->ranges.push_back(mrp);
+static bool ggml_mem_ranges_add(ggml_mem_ranges * mrs, ggml_mem_range mr) {
+    mrs->ranges.push_back(mr);
 
     return true;
 }
@@ -48,20 +51,24 @@ static ggml_mem_range ggml_mem_range_from_tensor(const ggml_tensor * tensor, ggm
 
     GGML_ASSERT(!tensor->view_src);
 
-    ggml_mem_range mrp;
+    ggml_mem_range mr;
 
     if (tensor->buffer) {
-        // when the tensor is allocated, use the actual memory address range of the buffer
-        mrp = {
+        // when the tensor is allocated, use the actual memory address range in the buffer
+        //
+        // take the actual allocated size with ggml_backend_buft_get_alloc_size()
+        // this can be larger than the tensor size if the buffer type allocates extra memory
+        // ref: https://github.com/ggml-org/llama.cpp/pull/15966
+        mr = {
             /*.pb =*/ (uint64_t) tensor->buffer,
             /*.p0 =*/ (uint64_t) tensor->data,
-            /*.p1 =*/ (uint64_t) tensor->data + ggml_nbytes(tensor),
+            /*.p1 =*/ (uint64_t) tensor->data + ggml_backend_buft_get_alloc_size(tensor->buffer->buft, tensor),
             /*.pt =*/ pt,
         };
     } else {
-        // otherwise, the tensor ptr is used as an unique id of the memory ranges
+        // otherwise, the pointer address is used as an unique id of the memory ranges
         //   that the tensor will be using when it is allocated
-        mrp = {
+        mr = {
             /*.pb =*/ (uint64_t) tensor,
             /*.p0 =*/ 0,    //
             /*.p1 =*/ 1024, // [0, 1024) is a dummy range, not used
@@ -69,7 +76,7 @@ static ggml_mem_range ggml_mem_range_from_tensor(const ggml_tensor * tensor, ggm
         };
     };
 
-    return mrp;
+    return mr;
 }
 
 static ggml_mem_range ggml_mem_range_from_tensor_src(const ggml_tensor * tensor) {
@@ -83,25 +90,25 @@ static ggml_mem_range ggml_mem_range_from_tensor_dst(const ggml_tensor * tensor)
 static bool ggml_mem_ranges_add_src(ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
     GGML_ASSERT(tensor);
 
-    ggml_mem_range mrp = ggml_mem_range_from_tensor_src(tensor);
+    ggml_mem_range mr = ggml_mem_range_from_tensor_src(tensor);
 
     if (mrs->debug > 2) {
-        GGML_LOG_DEBUG("%s: add src range buf=%lld, [%lld, %lld)\n", __func__, mrp.pb, mrp.p0, mrp.p1);
+        GGML_LOG_DEBUG("%s: add src range buf=%lld, [%lld, %lld)\n", __func__, mr.pb, mr.p0, mr.p1);
     }
 
-    return ggml_mem_ranges_add(mrs, mrp);
+    return ggml_mem_ranges_add(mrs, mr);
 }
 
 static bool ggml_mem_ranges_add_dst(ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
     GGML_ASSERT(tensor);
 
-    ggml_mem_range mrp = ggml_mem_range_from_tensor_dst(tensor);
+    ggml_mem_range mr = ggml_mem_range_from_tensor_dst(tensor);
 
     if (mrs->debug > 2) {
-        GGML_LOG_DEBUG("%s: add dst range buf=%lld, [%lld, %lld)\n", __func__, mrp.pb, mrp.p0, mrp.p1);
+        GGML_LOG_DEBUG("%s: add dst range buf=%lld, [%lld, %lld)\n", __func__, mr.pb, mr.p0, mr.p1);
     }
 
-    return ggml_mem_ranges_add(mrs, mrp);
+    return ggml_mem_ranges_add(mrs, mr);
 }
 
 bool ggml_mem_ranges_add(ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
@@ -114,24 +121,26 @@ bool ggml_mem_ranges_add(ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
     return ggml_mem_ranges_add_dst(mrs, tensor);
 }
 
-static bool ggml_mem_ranges_check(const ggml_mem_ranges * mrs, ggml_mem_range mrp) {
+static bool ggml_mem_ranges_check(const ggml_mem_ranges * mrs, ggml_mem_range mr) {
     for (size_t i = 0; i < mrs->ranges.size(); i++) {
         const auto & cmp = mrs->ranges[i];
 
-        if (mrp.pb != cmp.pb) {
+        // two memory ranges cannot intersect if they are in different buffers
+        if (mr.pb != cmp.pb) {
             continue;
         }
 
-        if (mrp.pt == MEM_RANGE_TYPE_SRC && cmp.pt == MEM_RANGE_TYPE_SRC) {
+        // intersecting source ranges are allowed
+        if (mr.pt == MEM_RANGE_TYPE_SRC && cmp.pt == MEM_RANGE_TYPE_SRC) {
             continue;
         }
 
-        if (mrp.p0 < cmp.p1 && mrp.p1 >= cmp.p0) {
+        if (mr.p0 < cmp.p1 && mr.p1 >= cmp.p0) {
             if (mrs->debug > 2) {
                 GGML_LOG_DEBUG("%s: the %s range buf=%lld, [%lld, %lld) overlaps with a previous %s range buf=%lld, [%lld, %lld)\n",
                         __func__,
-                        mrp.pt == MEM_RANGE_TYPE_SRC ? "src" : "dst",
-                        mrp.pb, mrp.p0, mrp.p1,
+                        mr.pt == MEM_RANGE_TYPE_SRC ? "src" : "dst",
+                        mr.pb, mr.p0, mr.p1,
                         cmp.pt == MEM_RANGE_TYPE_SRC ? "src" : "dst",
                         cmp.pb, cmp.p0, cmp.p1);
             }
@@ -146,9 +155,9 @@ static bool ggml_mem_ranges_check(const ggml_mem_ranges * mrs, ggml_mem_range mr
 static bool ggml_mem_ranges_check_src(const ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
     GGML_ASSERT(tensor);
 
-    ggml_mem_range mrp = ggml_mem_range_from_tensor_src(tensor);
+    ggml_mem_range mr = ggml_mem_range_from_tensor_src(tensor);
 
-    const bool res = ggml_mem_ranges_check(mrs, mrp);
+    const bool res = ggml_mem_ranges_check(mrs, mr);
 
     return res;
 }
@@ -156,9 +165,9 @@ static bool ggml_mem_ranges_check_src(const ggml_mem_ranges * mrs, const ggml_te
 static bool ggml_mem_ranges_check_dst(const ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
     GGML_ASSERT(tensor);
 
-    ggml_mem_range mrp = ggml_mem_range_from_tensor_dst(tensor);
+    ggml_mem_range mr = ggml_mem_range_from_tensor_dst(tensor);
 
-    const bool res = ggml_mem_ranges_check(mrs, mrp);
+    const bool res = ggml_mem_ranges_check(mrs, mr);
 
     return res;
 }
@@ -222,6 +231,7 @@ static std::vector<int> ggml_metal_graph_optimize_reorder(const std::vector<node
             }
         }
 
+        // keep track of the sources of the fused nodes as well
         for (const auto * fused : node.fused) {
             for (int i = 0; i < GGML_MAX_SRC; i++) {
                 if (fused->src[i]) {
@@ -290,7 +300,10 @@ static std::vector<int> ggml_metal_graph_optimize_reorder(const std::vector<node
 
     std::vector<bool> used(n, false);
 
+    // the memory ranges for the set of currently concurrent nodes
     ggml_mem_ranges * mrs0 = ggml_mem_ranges_init(0);
+
+    // the memory ranges for the set of nodes that haven't been processed yet, when looking forward for a node to reorder
     ggml_mem_ranges * mrs1 = ggml_mem_ranges_init(0);
 
     for (int i0 = 0; i0 < n; i0++) {
@@ -329,7 +342,7 @@ static std::vector<int> ggml_metal_graph_optimize_reorder(const std::vector<node
 
                 const bool is_empty = node1.is_empty();
 
-                // to add a concurrent node, it has to be:
+                // to reorder a node and add it to the concurrent set, it has to be:
                 //   + empty or concurrent with all nodes in the existing concurrent set (mrs0)
                 //   + concurrent with all nodes prior to it that haven't been processed yet (mrs1)
                 if ((is_empty || h_check(mrs0, node1)) && h_check(mrs1, node1)) {
@@ -419,8 +432,8 @@ void ggml_metal_graph_optimize(ggml_cgraph * gf) {
         nodes.push_back(std::move(node));
     }
 
-    // reorder to improve concurrency
 #if 1
+    // reorder to improve concurrency
     const auto order = ggml_metal_graph_optimize_reorder(nodes);
 #else
     std::vector<int> order(nodes.size());
index 13f9de297eae480d05f9cefdcbc5330177093315..2243c174fb71320a239c9675729f8b5d909b27b9 100644 (file)
@@ -532,261 +532,9 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_COUNT
 };
 
-//
-// ggml_metal_heap
-//
-
-struct ggml_metal_heap {
-    // number of times the heap was unused
-    int n_unused;
-
-    // total number of buffer allocations in this heap across all computes
-    int64_t n_alloc;
-
-    // current offset in the heap - we reset this after each node in order to reuse the memory
-    size_t offs;
-
-    // the currently allocated MTLBuffer objects in this heap
-    id<MTLHeap> obj;
-
-    NSMutableArray * bufs;
-};
-
-static struct ggml_metal_heap * ggml_metal_heap_init(id<MTLDevice> device, size_t size) {
-    struct ggml_metal_heap * heap = calloc(1, sizeof(struct ggml_metal_heap));
-
-    MTLHeapDescriptor * desc = [[MTLHeapDescriptor alloc] init];
-    desc.storageMode  = MTLStorageModePrivate;
-    desc.cpuCacheMode = MTLCPUCacheModeDefaultCache;
-    desc.type         = MTLHeapTypePlacement;
-    desc.size         = size;
-
-    heap->n_unused = 0;
-    heap->n_alloc = 0;
-
-    heap->obj = [device newHeapWithDescriptor:desc];
-    if (!heap->obj) {
-        GGML_LOG_ERROR("%s: error: failed to create MTLHeap with size %zu\n", __func__, size);
-
-        free(heap);
-
-        return false;
-    }
-
-    [desc release];
-
-    heap->bufs = [[NSMutableArray alloc] init];
-
-    return heap;
-}
-
-static void ggml_metal_heap_reset(struct ggml_metal_heap * heap) {
-    heap->offs = 0;
-
-    // count how many graph computes the heap ended up being unused
-    if ([heap->bufs count] > 0) {
-        heap->n_unused = 0;
-    } else {
-        heap->n_unused++;
-    }
-
-    for (id<MTLBuffer> buf in heap->bufs) {
-        [buf release];
-    }
-    [heap->bufs removeAllObjects];
-
-    // tell the OS that it can reuse this memory if needed
-    // ref: https://developer.apple.com/documentation/metal/mtlpurgeablestate?language=objc
-    [heap->obj setPurgeableState:MTLPurgeableStateVolatile];
-}
-
-static void ggml_metal_heap_free(struct ggml_metal_heap * heap) {
-    if (heap == nil) {
-        return;
-    }
-
-    ggml_metal_heap_reset(heap);
-
-    [heap->obj  release];
-    [heap->bufs release];
-
-    free(heap);
-}
-
-@interface ggml_metal_heap_ptr : NSObject
-
-@property (nonatomic, assign) struct ggml_metal_heap * data;
-
-@end
-
-@implementation ggml_metal_heap_ptr
-@end
-
-//
-// ggml_metal_mem_pool [TAG_MEM_POOL_REMOVE]
-//
-
-struct ggml_metal_mem_pool {
-    id<MTLDevice> device;
-
-    int n_heaps; // total number of heaps ever created (including those that were removed)
-
-    NSMutableArray * heaps;
-    NSMutableArray * heaps_to_remove;
-};
-
-static struct ggml_metal_mem_pool * ggml_metal_mem_pool_init(void) {
-    struct ggml_metal_mem_pool * mem_pool = calloc(1, sizeof(struct ggml_metal_mem_pool));
-
-    mem_pool->n_heaps = 0;
-
-    mem_pool->heaps           = [[NSMutableArray alloc] init];
-    mem_pool->heaps_to_remove = [[NSMutableArray alloc] init];
-
-    return mem_pool;
-}
-
-static void ggml_metal_mem_pool_free(struct ggml_metal_mem_pool * mem_pool) {
-    GGML_LOG_DEBUG("%s: freeing memory pool, num heaps = %zu (total = %d)\n", __func__, [mem_pool->heaps count], mem_pool->n_heaps);
-
-    size_t size_all = 0;
-    size_t size_cur = 0;
-
-    for (ggml_metal_heap_ptr * ptr in mem_pool->heaps) {
-        GGML_LOG_DEBUG("%s:   heap: %p\n",                __func__, (void *) ptr.data);
-        GGML_LOG_DEBUG("%s:     n_alloc:  %" PRId64 "\n", __func__, ptr.data->n_alloc);
-        GGML_LOG_DEBUG("%s:     n_unused: %d\n",          __func__, ptr.data->n_unused);
-        GGML_LOG_DEBUG("%s:     size:     %.2f MiB\n",    __func__, [ptr.data->obj size] / 1024.0 / 1024.0);
-        GGML_LOG_DEBUG("%s:     bufs:     %zu\n",         __func__, [ptr.data->bufs count]);
-
-        if ([ptr.data->bufs count] > 0) {
-            size_cur += [ptr.data->obj size];
-        }
-        size_all += [ptr.data->obj size];
-
-        ggml_metal_heap_free(ptr.data);
-        [ptr release];
-    }
-    [mem_pool->heaps           release];
-    [mem_pool->heaps_to_remove release];
-
-    if (size_all > 0) {
-        GGML_LOG_DEBUG("%s:   size_all: %.2f MiB\n", __func__, size_all / 1024.0 / 1024.0);
-        GGML_LOG_DEBUG("%s:   size_cur: %.2f MiB\n", __func__, size_cur / 1024.0 / 1024.0);
-    }
-
-    free(mem_pool);
-}
-
-static void ggml_metal_mem_pool_reset(struct ggml_metal_mem_pool * mem_pool) {
-    for (NSUInteger i = 0; i < [mem_pool->heaps count]; i++) {
-        ggml_metal_heap_ptr * ptr = [mem_pool->heaps objectAtIndex:i];
-
-        struct ggml_metal_heap * heap = ptr.data;
-        ggml_metal_heap_reset(heap);
-
-        // if the heap hasn't been used for a while, remove it
-        if (heap->n_unused >= 128) {
-            [mem_pool->heaps_to_remove addObject:@(i)];
-        }
-    }
-
-    if (mem_pool->heaps_to_remove.count > 0) {
-        // remove in reverse order
-        for (NSUInteger i = [mem_pool->heaps_to_remove count] - 1; ; --i) {
-            NSUInteger index = [[mem_pool->heaps_to_remove objectAtIndex:i] intValue];
-            ggml_metal_heap_ptr * ptr = [mem_pool->heaps objectAtIndex:index];
-
-            struct ggml_metal_heap * heap = ptr.data;
-            ggml_metal_heap_free(heap);
-
-            [mem_pool->heaps removeObjectAtIndex:index];
-            [ptr release];
-
-            if (i == 0) {
-                break;
-            }
-        }
-
-        [mem_pool->heaps_to_remove removeAllObjects];
-    }
-}
-
-static void ggml_metal_mem_pool_clear(struct ggml_metal_mem_pool * mem_pool) {
-    for (ggml_metal_heap_ptr * ptr in mem_pool->heaps) {
-        ptr.data->offs = 0;
-    }
-}
-
-static id<MTLBuffer> ggml_metal_mem_pool_alloc(struct ggml_metal_mem_pool * mem_pool, size_t size) {
-    const size_t alignment = 256;
-
-    const size_t size_aligned = GGML_PAD(size, alignment);
-
-    // try one of the existing heaps
-    for (ggml_metal_heap_ptr * ptr in mem_pool->heaps) {
-        struct ggml_metal_heap * heap = ptr.data;
-        if (heap->offs + size_aligned <= [heap->obj size]) {
-            // if this is the first buffer in the heap for the current command buffer, tell the OS that
-            //   it cannot free the memory used by the heap
-            // ref: https://developer.apple.com/documentation/metal/mtlpurgeablestate?language=objc
-            if ([heap->bufs count] == 0) {
-                [heap->obj setPurgeableState:MTLPurgeableStateNonVolatile];
-            }
-
-            id<MTLBuffer> buf = [heap->obj newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate offset:heap->offs];
-            if (buf == nil) {
-                GGML_LOG_ERROR("%s: error: failed to create MTLBuffer with size %zu\n", __func__, size_aligned);
-                return nil;
-            }
-
-            heap->n_alloc++;
-            heap->offs += size_aligned;
-
-            [heap->bufs addObject:buf];
-
-            return buf;
-        }
-    }
-
-    // create a new heap that can fit this buffer
-    ggml_metal_heap_ptr * heap_ptr = [ggml_metal_heap_ptr new];
-
-    struct ggml_metal_heap * heap = ggml_metal_heap_init(mem_pool->device, size_aligned);
-    if (heap == NULL) {
-        GGML_LOG_ERROR("%s: error: failed to create heap of size %zu\n", __func__, size_aligned);
-        return NULL;
-    }
-
-    //GGML_LOG_DEBUG("%s: creating new heap of size %zu, got %zu\n", __func__, size_aligned, [heap->obj size]);
-
-    heap_ptr.data = heap;
-    ggml_metal_heap_reset(heap);
-
-    [heap->obj setPurgeableState:MTLPurgeableStateNonVolatile];
-    id<MTLBuffer> buf = [heap->obj newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate offset:heap->offs];
-    if (buf == nil) {
-        GGML_LOG_ERROR("%s: error: failed to create MTLBuffer with size %zu\n", __func__, size_aligned);
-        return NULL;
-    }
-
-    heap->n_alloc++;
-    heap->offs += size_aligned;
-
-    [heap->bufs addObject:buf];
-
-    [mem_pool->heaps addObject:heap_ptr];
-    mem_pool->n_heaps++;
-
-    return buf;
-}
-
 struct ggml_metal_command_buffer {
     id<MTLCommandBuffer> obj;
 
-    // each command buffer has a memory pool from which it can allocate temporary buffers during the compute
-    struct ggml_metal_mem_pool * mem_pool;
-
     // used to enable concurrent execution of ops in the command buffers
     struct ggml_mem_ranges * mem_ranges;
 };
@@ -1103,9 +851,6 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
     for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
         ctx->cmd_bufs[i].obj = nil;
 
-        ctx->cmd_bufs[i].mem_pool = ggml_metal_mem_pool_init();
-        ctx->cmd_bufs[i].mem_pool->device = device;
-
         if (ctx_dev->use_concurrency) {
             ctx->cmd_bufs[i].mem_ranges = ggml_mem_ranges_init(ctx_dev->debug_graph);
         }
@@ -1510,6 +1255,52 @@ static id<MTLComputePipelineState> ggml_metal_compile_kernel(ggml_backend_t back
     return res;
 }
 
+// tokens per expert
+static size_t ggml_metal_mul_mat_id_extra_tpe(const struct ggml_tensor * op) {
+    assert(op->op == GGML_OP_MUL_MAT_ID);
+
+    const int64_t ne02 = op->src[0]->ne[2]; // n_expert
+
+    return ggml_type_size(GGML_TYPE_I32)*ne02;
+}
+
+// id map [n_tokens, n_expert]
+static size_t ggml_metal_mul_mat_id_extra_ids(const struct ggml_tensor * op) {
+    assert(op->op == GGML_OP_MUL_MAT_ID);
+
+    const int64_t ne02 = op->src[0]->ne[2]; // n_expert
+    const int64_t ne21 = op->src[2]->ne[1]; // n_token
+
+    return ggml_type_size(GGML_TYPE_I32)*ne02*ne21;
+}
+
+// return true if we should use the FA vector kernel for this op
+static bool ggml_metal_flash_attn_ext_use_vec(const struct ggml_tensor * op) {
+    assert(op->op == GGML_OP_FLASH_ATTN_EXT);
+
+    const int64_t ne00 = op->src[0]->ne[0]; // head size
+    const int64_t ne01 = op->src[0]->ne[1]; // batch size
+
+    // use vec kernel if the batch size is small and if the head size is supported
+    return (ne01 < 20) && (ne00 % 32 == 0);
+}
+
+static size_t ggml_metal_flash_attn_ext_extra_tmp(const struct ggml_tensor * op) {
+    assert(op->op == GGML_OP_FLASH_ATTN_EXT);
+
+    const int64_t nwg = 32;
+
+    const int64_t ne01 = op->src[0]->ne[1];
+    const int64_t ne02 = op->src[0]->ne[2];
+    const int64_t ne03 = op->src[0]->ne[3];
+    const int64_t ne20 = op->src[2]->ne[0];
+
+    // temp buffer for writing the results from each workgroup
+    // - ne20: the size of the Value head
+    // -  + 2: the S and M values for each intermediate result
+    return ggml_type_size(GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2));
+}
+
 static id<MTLComputePipelineState> ggml_metal_get_pipeline_flash_attn_ext(
         ggml_backend_t backend, struct ggml_tensor * op,
         bool    has_mask,
@@ -1760,8 +1551,6 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
             [ctx->cmd_bufs[i].obj release];
         }
 
-        ggml_metal_mem_pool_free(ctx->cmd_bufs[i].mem_pool);
-
         if (ctx->cmd_bufs[i].mem_ranges) {
             ggml_mem_ranges_free(ctx->cmd_bufs[i].mem_ranges);
         }
@@ -2127,8 +1916,6 @@ struct ggml_metal_encode_context {
 
     id<MTLComputeCommandEncoder> encoder;
 
-    struct ggml_metal_mem_pool * mem_pool;
-
     struct ggml_mem_ranges * mem_ranges;
 };
 
@@ -2165,8 +1952,6 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
 
     id<MTLComputeCommandEncoder> encoder = ctx_enc->encoder;
 
-    struct ggml_metal_mem_pool * mem_pool = ctx_enc->mem_pool;
-
     struct ggml_backend_metal_context        * ctx     = backend->context;
     struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
 
@@ -2207,8 +1992,6 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
         GGML_ABORT("unsupported op");
     }
 
-    ggml_metal_mem_pool_clear(mem_pool);
-
     const int64_t  ne00 = src0 ? src0->ne[0] : 0;
     const int64_t  ne01 = src0 ? src0->ne[1] : 0;
     const int64_t  ne02 = src0 ? src0->ne[2] : 0;
@@ -2522,7 +2305,6 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
                     /*.nb02 =*/ nb02,
                     /*.nb11 =*/ nb11,
                     /*.nb21 =*/ nb21,
-
                 };
 
                 [encoder setComputePipelineState:pipeline];
@@ -3167,54 +2949,8 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
                 const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
                 const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
-// use this branch to test the ggml_metal_mem_pool functionality
-#if 0
-                // cpy to tmp buffer in MTLHeap
-
-                id<MTLBuffer> h_src0 = h_src0 = ggml_metal_mem_pool_alloc(mem_pool, ggml_nbytes(src0));
-                if (!h_src0) {
-                    GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, ggml_nbytes(src0));
-                    return 0;
-                }
-
-                offs_src0 = 0;
-
-                ggml_metal_kargs_cpy args_cpy = {
-                    /*.ne00 =*/ ne00,
-                    /*.ne01 =*/ ne01,
-                    /*.ne02 =*/ ne02,
-                    /*.ne03 =*/ ne03,
-                    /*.nb00 =*/ nb00,
-                    /*.nb01 =*/ nb01,
-                    /*.nb02 =*/ nb02,
-                    /*.nb03 =*/ nb03,
-                    /*.ne0  =*/ ne00,
-                    /*.ne1  =*/ ne01,
-                    /*.ne2  =*/ ne02,
-                    /*.ne3  =*/ ne03,
-                    /*.nb0  =*/ nb00,
-                    /*.nb1  =*/ nb01,
-                    /*.nb2  =*/ nb02,
-                    /*.nb3  =*/ nb03,
-                };
-
-                if (src0->type == GGML_TYPE_F16) {
-                    [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline];
-                } else {
-                    [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline];
-                }
-                [encoder setBytes:&args_cpy length:sizeof(args_cpy) atIndex:0];
-                [encoder setBuffer:id_src0  offset:offs_src0        atIndex:1];
-                [encoder setBuffer:h_src0   offset:0                atIndex:2];
-
-                GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
-                int nth_cpy = MIN(1024, ne00 / ggml_blck_size(src0->type));
-
-                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth_cpy, 1, 1)];
-
-#else
                 id<MTLBuffer> h_src0 = id_src0;
-#endif
+
                 // softmax
 
                 ggml_metal_kargs_soft_max args = {
@@ -4093,28 +3829,9 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
                         default: break;
                     }
 
-                    // TODO: using mem pool allocations with enabled concurrency is not safe because the mem pool
-                    // reuses buffers. this can result in 2 concurrent MUL_MAT_ID ops using the same mem pool buffer.
-                    // so we add this extra barrier to prevent the race.
-                    // the correct solution is to remove mem pools and then remove this barrier [TAG_MEM_POOL_REMOVE]
-                    ggml_metal_encode_concurrency_reset(ctx_enc);
-
-                    // tokens per expert
-                    const size_t s_tpe = ggml_type_size(GGML_TYPE_I32)*ne02;
-                    id<MTLBuffer> h_tpe = ggml_metal_mem_pool_alloc(mem_pool, s_tpe);
-                    if (!h_tpe) {
-                        GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_tpe);
-                        return 0;
-                    }
-
-                    // id map
-                    // [n_tokens, n_expert]
-                    const size_t s_ids = ggml_type_size(GGML_TYPE_I32)*ne21*ne02;
-                    id<MTLBuffer> h_ids = ggml_metal_mem_pool_alloc(mem_pool, s_ids);
-                    if (!h_ids) {
-                        GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_ids);
-                        return 0;
-                    }
+                    // extra buffers for intermediate id mapping
+                    size_t offs_tpe = offs_dst + ggml_nbytes(dst);
+                    size_t offs_ids = offs_tpe + ggml_metal_mul_mat_id_extra_tpe(dst);
 
                     {
                         ggml_metal_kargs_mul_mm_id_map0 args = {
@@ -4152,8 +3869,8 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
                         [encoder setComputePipelineState:pipeline];
                         [encoder setBytes:&args    length:sizeof(args) atIndex:0];
                         [encoder setBuffer:id_src2 offset:offs_src2    atIndex:1];
-                        [encoder setBuffer: h_tpe  offset:0            atIndex:2];
-                        [encoder setBuffer: h_ids  offset:0            atIndex:3];
+                        [encoder setBuffer:id_dst  offset:offs_tpe     atIndex:2];
+                        [encoder setBuffer:id_dst  offset:offs_ids     atIndex:3];
                         [encoder setThreadgroupMemoryLength:smem atIndex:0];
 
                         [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(ne02, 1, 1)];
@@ -4215,8 +3932,8 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
                         [encoder setBytes:&args    length:sizeof(args) atIndex:0];
                         [encoder setBuffer:id_src0 offset:offs_src0    atIndex:1];
                         [encoder setBuffer:id_src1 offset:offs_src1    atIndex:2];
-                        [encoder setBuffer: h_tpe  offset:0            atIndex:3];
-                        [encoder setBuffer: h_ids  offset:0            atIndex:4];
+                        [encoder setBuffer:id_dst  offset:offs_tpe     atIndex:3];
+                        [encoder setBuffer:id_dst  offset:offs_ids     atIndex:4];
                         [encoder setBuffer:id_dst  offset:offs_dst     atIndex:5];
 
                         [encoder setThreadgroupMemoryLength:8192 atIndex:0];
@@ -5306,8 +5023,7 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
 
                 GGML_ASSERT(ne01 < 65536);
 
-                // use non-vec kernel if the batch size is large or if the vec-kernel is not supported for this head size
-                if (ne01 >= 20 || (ne00 % 32 != 0)) {
+                if (!ggml_metal_flash_attn_ext_use_vec(dst)) {
                     // half8x8 kernel
                     const int64_t nqptg = 8;  // queries per threadgroup    !! sync with kernel template arguments !!
                     const int64_t ncpsg = 64; // cache values per simdgroup !! sync with kernel template arguments !!
@@ -5532,34 +5248,20 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
                         GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3);
                         GGML_ASSERT(ne1*ne2*ne3 <= (1u << 31));
 
-                        // using mem pool allocations with enabled concurrency is not safe [TAG_MEM_POOL_REMOVE]
-                        // still, we assume that concurrent FA won't happen before we do the refactor
-                        //ggml_metal_encode_concurrency_reset(ctx_enc);
-
-                        const int32_t nrows = ne1*ne2*ne3;
-
-                        // temp buffer for writing the results from each workgroup
-                        // - ne20: the size of the head vector
-                        // -  + 2: the S and M values for each intermediate result
-                        const size_t s_tmp = ggml_type_size(GGML_TYPE_F32)*(nrows*nwg*(ne20 + 2));
-                        id<MTLBuffer> h_tmp = ggml_metal_mem_pool_alloc(mem_pool, s_tmp);
-                        if (!h_tmp) {
-                            GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_tmp);
-                            return 0;
-                        }
-
-                        //printf("ne01 = %d, ne02 = %d, ne03 = %d, ne20 = %d\n", ne01, ne02, ne03, ne20);
-                        //printf("needed memory: %.3f MiB\n", (float) (ne01*ne02*ne03*ne20*sizeof(float))/1024.0f/1024.0f);
-
-                        [encoder setBuffer:h_tmp offset:0 atIndex:6];
+                        // write the results from each workgroup into a temp buffer
+                        const size_t offs_tmp = offs_dst + ggml_nbytes(dst);
+                        [encoder setBuffer:id_dst offset:offs_tmp atIndex:6];
 
                         [encoder setThreadgroupMemoryLength:smem atIndex:0];
                         [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
 
+                        // sync the 2 kernels
                         ggml_metal_encode_concurrency_reset(ctx_enc);
 
                         // reduce the results from the workgroups
                         {
+                            const int32_t nrows = ne1*ne2*ne3;
+
                             ggml_metal_kargs_flash_attn_ext_vec_reduce args0 = {
                                 nrows,
                             };
@@ -5568,7 +5270,7 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
 
                             [encoder setComputePipelineState:pipeline0];
                             [encoder setBytes:&args0   length:sizeof(args0) atIndex:0];
-                            [encoder setBuffer:h_tmp   offset:0             atIndex:1];
+                            [encoder setBuffer:id_dst  offset:offs_tmp      atIndex:1];
                             [encoder setBuffer:id_dst  offset:offs_dst      atIndex:2];
 
                             //printf("ne1 = %d, ne2 = %d, ne3 = %d, ne20 = %d\n", ne1, ne2, ne3, ne20);
@@ -5895,12 +5597,7 @@ static enum ggml_status ggml_metal_graph_compute(
         // the main thread commits the first few commands immediately
         // cmd_buf[n_cb]
         {
-            // cannot use commandBufferWithUnretainedReferences because the buffers from the memory pool can get destroyed
-            // TODO: when the memory pools are removed, we can again use commandBufferWithUnretainedReferences
-            //       https://github.com/ggml-org/llama.cpp/pull/15832#discussion_r2334215009
-            // [TAG_MEM_POOL_REMOVE]
-            //id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
-            id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBuffer];
+            id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
             [cmd_buf retain];
 
             if (ctx->cmd_bufs[n_cb].obj) {
@@ -5919,8 +5616,7 @@ static enum ggml_status ggml_metal_graph_compute(
         // prepare the rest of the command buffers asynchronously (optional)
         // cmd_buf[0.. n_cb)
         for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
-            //id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
-            id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBuffer];
+            id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
             [cmd_buf retain];
 
             if (ctx->cmd_bufs[cb_idx].obj) {
@@ -6377,6 +6073,31 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
     return ggml_backend_buffer_init(buft, buf_i, ctx, size);
 }
 
+static size_t ggml_backend_metal_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
+    size_t res = ggml_nbytes(tensor);
+
+    // some operations require additional memory for fleeting data:
+    switch (tensor->op) {
+        case GGML_OP_MUL_MAT_ID:
+            {
+                res += ggml_metal_mul_mat_id_extra_tpe(tensor);
+                res += ggml_metal_mul_mat_id_extra_ids(tensor);
+            } break;
+        case GGML_OP_FLASH_ATTN_EXT:
+            {
+                if (ggml_metal_flash_attn_ext_use_vec(tensor)) {
+                    res += ggml_metal_flash_attn_ext_extra_tmp(tensor);
+                }
+            } break;
+        default:
+            break;
+    }
+
+    return res;
+
+    GGML_UNUSED(buft);
+}
+
 // default (shared) buffer type
 
 static const char * ggml_backend_metal_buffer_type_shared_get_name(ggml_backend_buffer_type_t buft) {
@@ -6401,6 +6122,10 @@ static size_t ggml_backend_metal_buffer_type_shared_get_max_size(ggml_backend_bu
     return max_size;
 }
 
+static size_t ggml_backend_metal_buffer_type_shared_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
+    return ggml_backend_metal_buffer_type_get_alloc_size(buft, tensor);
+}
+
 static bool ggml_backend_metal_buffer_type_shared_is_host(ggml_backend_buffer_type_t buft) {
     return false;
 
@@ -6414,7 +6139,7 @@ static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_shared(void) {
             /* .alloc_buffer     = */ ggml_backend_metal_buffer_type_shared_alloc_buffer,
             /* .get_alignment    = */ ggml_backend_metal_buffer_type_shared_get_alignment,
             /* .get_max_size     = */ ggml_backend_metal_buffer_type_shared_get_max_size,
-            /* .get_alloc_size   = */ NULL, // defaults to ggml_nbytes
+            /* .get_alloc_size   = */ ggml_backend_metal_buffer_type_shared_get_alloc_size,
             /* .is_host          = */ ggml_backend_metal_buffer_type_shared_is_host,
         },
         /* .device  = */ &g_ggml_backend_metal_device,
@@ -6448,6 +6173,10 @@ static size_t ggml_backend_metal_buffer_type_private_get_max_size(ggml_backend_b
     return max_size;
 }
 
+static size_t ggml_backend_metal_buffer_type_private_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
+    return ggml_backend_metal_buffer_type_get_alloc_size(buft, tensor);
+}
+
 static bool ggml_backend_metal_buffer_type_private_is_host(ggml_backend_buffer_type_t buft) {
     return false;
 
@@ -6461,7 +6190,7 @@ static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_private(void) {
             /* .alloc_buffer     = */ ggml_backend_metal_buffer_type_private_alloc_buffer,
             /* .get_alignment    = */ ggml_backend_metal_buffer_type_private_get_alignment,
             /* .get_max_size     = */ ggml_backend_metal_buffer_type_private_get_max_size,
-            /* .get_alloc_size   = */ NULL, // defaults to ggml_nbytes
+            /* .get_alloc_size   = */ ggml_backend_metal_buffer_type_private_get_alloc_size,
             /* .is_host          = */ ggml_backend_metal_buffer_type_private_is_host,
         },
         /* .device  = */ &g_ggml_backend_metal_device,
@@ -6496,6 +6225,10 @@ static size_t ggml_backend_metal_buffer_type_mapped_get_max_size(ggml_backend_bu
     return max_size;
 }
 
+static size_t ggml_backend_metal_buffer_type_mapped_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
+    return ggml_backend_metal_buffer_type_get_alloc_size(buft, tensor);
+}
+
 static bool ggml_backend_metal_buffer_type_mapped_is_host(ggml_backend_buffer_type_t buft) {
     return false;
 
@@ -6511,7 +6244,7 @@ static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_mapped(void) {
             /* .alloc_buffer     = */ ggml_backend_metal_buffer_type_mapped_alloc_buffer,
             /* .get_alignment    = */ ggml_backend_metal_buffer_type_mapped_get_alignment,
             /* .get_max_size     = */ ggml_backend_metal_buffer_type_mapped_get_max_size,
-            /* .get_alloc_size   = */ NULL, // defaults to ggml_nbytes
+            /* .get_alloc_size   = */ ggml_backend_metal_buffer_type_mapped_get_alloc_size,
             /* .is_host          = */ ggml_backend_metal_buffer_type_mapped_is_host,
         },
         /* .device  = */ &g_ggml_backend_metal_device,
@@ -6711,11 +6444,8 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
         const int n_nodes_per_cb = ctx->n_nodes_per_cb;
 
         id<MTLCommandBuffer>         cmd_buf    = ctx->cmd_bufs[cb_idx].obj;
-        struct ggml_metal_mem_pool * mem_pool   = ctx->cmd_bufs[cb_idx].mem_pool;
         struct ggml_mem_ranges     * mem_ranges = ctx->cmd_bufs[cb_idx].mem_ranges;
 
-        ggml_metal_mem_pool_reset(mem_pool);
-
         if (mem_ranges) {
             ggml_mem_ranges_reset(mem_ranges);
         }
@@ -6743,7 +6473,6 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
         struct ggml_metal_encode_context ctx_enc = {
             /*.backend    =*/ backend,
             /*.encoder    =*/ encoder,
-            /*.mem_pool   =*/ mem_pool,
             /*.mem_ranges =*/ mem_ranges,
         };