]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
metal : add memory pool for temp allocs (#12850)
authorGeorgi Gerganov <redacted>
Tue, 22 Apr 2025 13:15:51 +0000 (16:15 +0300)
committerGitHub <redacted>
Tue, 22 Apr 2025 13:15:51 +0000 (16:15 +0300)
* metal : add memory pool for temp allocs (wip) [no ci]

* cont : free buffers from the heap

* cont : resize heap [no ci]

* cont : refactor heap [no ci]

* cont : heap for each cmd buffer [no ci]

* cont : fix free

* wip

* cont : fix alignment [no ci]

* cont : not working .. [no ci]

* cont : heap allocation now works [no ci]

* cont : use MTLHeapTypePlacement

ggml-ci

* metal : use dynamic MTLHeap allocations

ggml-ci

* metal : add comments

* metal : disable softmax use of mem_pool

ggml-ci

* metal : final touches

ggml/src/ggml-metal/ggml-metal.m

index 266d8af4693c211526cd8da2d3ff581abe061f32..d92392edb7eb1a0ce8095074b40309079492982f 100644 (file)
@@ -44,8 +44,8 @@ static struct ggml_backend_device g_ggml_backend_metal_device;
 // note: assumes single GPU device - the default one
 // TODO: support multiple GPU devices
 static struct ggml_backend_metal_device_context {
-    id<MTLDevice> mtl_device;
-    int           mtl_device_ref_count;
+    id<MTLDevice>  mtl_device;
+    int            mtl_device_ref_count;
     id<MTLLibrary> mtl_library;
 
     bool has_simdgroup_reduction;
@@ -490,7 +490,259 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_COUNT
 };
 
+//
+// ggml_metal_heap
+//
+
+struct ggml_metal_heap {
+    // number of times the heap was unused
+    int n_unused;
+
+    // total number of buffer allocations in this heap across all computes
+    int64_t n_alloc;
+
+    // current offset in the heap - we reset this after each node in order to reuse the memory
+    size_t offs;
+
+    // the currently allocated MTLBuffer objects in this heap
+    id<MTLHeap> obj;
+
+    NSMutableArray * bufs;
+};
+
+static struct ggml_metal_heap * ggml_metal_heap_init(id<MTLDevice> device, size_t size) {
+    struct ggml_metal_heap * heap = calloc(1, sizeof(struct ggml_metal_heap));
+
+    MTLHeapDescriptor * desc = [[MTLHeapDescriptor alloc] init];
+    desc.storageMode  = MTLStorageModePrivate;
+    desc.cpuCacheMode = MTLCPUCacheModeDefaultCache;
+    desc.type         = MTLHeapTypePlacement;
+    desc.size         = size;
+
+    heap->n_unused = 0;
+    heap->n_alloc = 0;
+
+    heap->obj = [device newHeapWithDescriptor:desc];
+    if (!heap->obj) {
+        GGML_LOG_ERROR("%s: error: failed to create MTLHeap with size %zu\n", __func__, size);
+
+        free(heap);
+
+        return false;
+    }
+
+    [desc release];
+
+    heap->bufs = [[NSMutableArray alloc] init];
+
+    return heap;
+}
+
+static void ggml_metal_heap_reset(struct ggml_metal_heap * heap) {
+    heap->offs = 0;
+
+    // count how many graph computes the heap ended up being unused
+    if ([heap->bufs count] > 0) {
+        heap->n_unused = 0;
+    } else {
+        heap->n_unused++;
+    }
+
+    for (id<MTLBuffer> buf in heap->bufs) {
+        [buf release];
+    }
+    [heap->bufs removeAllObjects];
+
+    // tell the OS that it can reuse this memory if needed
+    // ref: https://developer.apple.com/documentation/metal/mtlpurgeablestate?language=objc
+    [heap->obj setPurgeableState:MTLPurgeableStateVolatile];
+}
+
+static void ggml_metal_heap_free(struct ggml_metal_heap * heap) {
+    if (heap == nil) {
+        return;
+    }
+
+    ggml_metal_heap_reset(heap);
+
+    [heap->obj  release];
+    [heap->bufs release];
+
+    free(heap);
+}
+
+@interface ggml_metal_heap_ptr : NSObject
+
+@property (nonatomic, assign) struct ggml_metal_heap * data;
+
+@end
+
+@implementation ggml_metal_heap_ptr
+@end
+
+//
+// ggml_metal_mem_pool
+//
+
+struct ggml_metal_mem_pool {
+    id<MTLDevice> device;
+
+    int n_heaps; // total number of heaps ever created (including those that were removed)
+
+    NSMutableArray * heaps;
+    NSMutableArray * heaps_to_remove;
+};
+
+static struct ggml_metal_mem_pool * ggml_metal_mem_pool_init(void) {
+    struct ggml_metal_mem_pool * mem_pool = calloc(1, sizeof(struct ggml_metal_mem_pool));
+
+    mem_pool->n_heaps = 0;
+
+    mem_pool->heaps           = [[NSMutableArray alloc] init];
+    mem_pool->heaps_to_remove = [[NSMutableArray alloc] init];
+
+    return mem_pool;
+}
+
+static void ggml_metal_mem_pool_free(struct ggml_metal_mem_pool * mem_pool) {
+    GGML_LOG_DEBUG("%s: freeing memory pool, num heaps = %zu (total = %d)\n", __func__, [mem_pool->heaps count], mem_pool->n_heaps);
+
+    size_t size_all = 0;
+    size_t size_cur = 0;
+
+    for (ggml_metal_heap_ptr * ptr in mem_pool->heaps) {
+        GGML_LOG_DEBUG("%s:   heap: %p\n",                __func__, (void *) ptr.data);
+        GGML_LOG_DEBUG("%s:     n_alloc:  %" PRId64 "\n", __func__, ptr.data->n_alloc);
+        GGML_LOG_DEBUG("%s:     n_unused: %d\n",          __func__, ptr.data->n_unused);
+        GGML_LOG_DEBUG("%s:     size:     %.2f MiB\n",    __func__, [ptr.data->obj size] / 1024.0 / 1024.0);
+        GGML_LOG_DEBUG("%s:     bufs:     %zu\n",         __func__, [ptr.data->bufs count]);
+
+        if ([ptr.data->bufs count] > 0) {
+            size_cur += [ptr.data->obj size];
+        }
+        size_all += [ptr.data->obj size];
+
+        ggml_metal_heap_free(ptr.data);
+        [ptr release];
+    }
+    [mem_pool->heaps           release];
+    [mem_pool->heaps_to_remove release];
+
+    if (size_all > 0) {
+        GGML_LOG_DEBUG("%s:   size_all: %.2f MiB\n", __func__, size_all / 1024.0 / 1024.0);
+        GGML_LOG_DEBUG("%s:   size_cur: %.2f MiB\n", __func__, size_cur / 1024.0 / 1024.0);
+    }
+
+    free(mem_pool);
+}
+
+static void ggml_metal_mem_pool_reset(struct ggml_metal_mem_pool * mem_pool) {
+    for (NSUInteger i = 0; i < [mem_pool->heaps count]; i++) {
+        ggml_metal_heap_ptr * ptr = [mem_pool->heaps objectAtIndex:i];
+
+        struct ggml_metal_heap * heap = ptr.data;
+        ggml_metal_heap_reset(heap);
+
+        // if the heap hasn't been used for a while, remove it
+        if (heap->n_unused >= 128) {
+            [mem_pool->heaps_to_remove addObject:@(i)];
+        }
+    }
+
+    if (mem_pool->heaps_to_remove.count > 0) {
+        for (NSUInteger i = 0; i < [mem_pool->heaps_to_remove count]; i++) {
+            NSUInteger index = [[mem_pool->heaps_to_remove objectAtIndex:i] intValue];
+            ggml_metal_heap_ptr * ptr = [mem_pool->heaps objectAtIndex:index];
+
+            struct ggml_metal_heap * heap = ptr.data;
+            ggml_metal_heap_free(heap);
+
+            [mem_pool->heaps removeObjectAtIndex:index];
+            [ptr release];
+        }
+
+        [mem_pool->heaps_to_remove removeAllObjects];
+    }
+}
+
+static void ggml_metal_mem_pool_clear(struct ggml_metal_mem_pool * mem_pool) {
+    for (ggml_metal_heap_ptr * ptr in mem_pool->heaps) {
+        ptr.data->offs = 0;
+    }
+}
+
+static id<MTLBuffer> ggml_metal_mem_pool_alloc(struct ggml_metal_mem_pool * mem_pool, size_t size) {
+    const size_t alignment = 32;
+
+    const size_t size_aligned = GGML_PAD(size, alignment);
+
+    // try one of the existing heaps
+    for (ggml_metal_heap_ptr * ptr in mem_pool->heaps) {
+        struct ggml_metal_heap * heap = ptr.data;
+        if (heap->offs + size_aligned <= [heap->obj size]) {
+            // if this is the first buffer in the heap for the current command buffer, tell the OS that
+            //   it cannot free the memory used by the heap
+            // ref: https://developer.apple.com/documentation/metal/mtlpurgeablestate?language=objc
+            if ([heap->bufs count] == 0) {
+                [heap->obj setPurgeableState:MTLPurgeableStateNonVolatile];
+            }
+
+            id<MTLBuffer> buf = [heap->obj newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate offset:heap->offs];
+            if (buf == nil) {
+                GGML_LOG_ERROR("%s: error: failed to create MTLBuffer with size %zu\n", __func__, size_aligned);
+                return nil;
+            }
+
+            heap->n_alloc++;
+            heap->offs += size_aligned;
+
+            [heap->bufs addObject:buf];
+
+            return buf;
+        }
+    }
+
+    // create a new heap that can fit this buffer
+    ggml_metal_heap_ptr * heap_ptr = [ggml_metal_heap_ptr new];
+
+    struct ggml_metal_heap * heap = ggml_metal_heap_init(mem_pool->device, size_aligned);
+    if (heap == NULL) {
+        GGML_LOG_ERROR("%s: error: failed to create heap of size %zu\n", __func__, size_aligned);
+        return NULL;
+    }
+
+    //GGML_LOG_DEBUG("%s: creating new heap of size %zu, got %zu\n", __func__, size_aligned, [heap->obj size]);
+
+    heap_ptr.data = heap;
+    ggml_metal_heap_reset(heap);
+
+    [heap->obj setPurgeableState:MTLPurgeableStateNonVolatile];
+    id<MTLBuffer> buf = [heap->obj newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate offset:heap->offs];
+    if (buf == nil) {
+        GGML_LOG_ERROR("%s: error: failed to create MTLBuffer with size %zu\n", __func__, size_aligned);
+        return NULL;
+    }
+
+    heap->n_alloc++;
+    heap->offs += size_aligned;
+
+    [heap->bufs addObject:buf];
+
+    [mem_pool->heaps addObject:heap_ptr];
+    mem_pool->n_heaps++;
+
+    return buf;
+}
+
+struct ggml_metal_command_buffer {
+    id<MTLCommandBuffer> obj;
+
+    // each command buffer has a memory pool from which it can allocate temporary buffers during the compute
+    struct ggml_metal_mem_pool * mem_pool;
+};
+
 struct ggml_backend_metal_context {
+    id<MTLDevice>       device;
     id<MTLCommandQueue> queue;
 
     dispatch_queue_t d_queue;
@@ -515,7 +767,7 @@ struct ggml_backend_metal_context {
     void (^encode_async)(size_t ith);
 
     // n_cb command buffers + 1 used by the main thread
-    id<MTLCommandBuffer> command_buffers[GGML_METAL_MAX_COMMAND_BUFFERS + 1];
+    struct ggml_metal_command_buffer cmd_bufs[GGML_METAL_MAX_COMMAND_BUFFERS + 1];
 
     // abort ggml_metal_graph_compute if callback returns true
     ggml_abort_callback abort_callback;
@@ -705,9 +957,11 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
     struct ggml_backend_metal_device_context * ctx_dev = dev->context;
 
     id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
+
     GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);
 
-    ctx->queue  = [device newCommandQueue];
+    ctx->device = device;
+    ctx->queue = [device newCommandQueue];
     if (ctx->queue == nil) {
         GGML_LOG_ERROR("%s: error: failed to create command queue\n", __func__);
         return NULL;
@@ -768,7 +1022,10 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
     ctx->gf = nil;
     ctx->encode_async = nil;
     for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
-        ctx->command_buffers[i] = nil;
+        ctx->cmd_bufs[i].obj = nil;
+
+        ctx->cmd_bufs[i].mem_pool = ggml_metal_mem_pool_init();
+        ctx->cmd_bufs[i].mem_pool->device = device;
     }
 
 #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
@@ -1181,6 +1438,12 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
 
     [ctx->queue release];
 
+    for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
+        // ctx->cmd_bufs[i].obj is auto released
+
+        ggml_metal_mem_pool_free(ctx->cmd_bufs[i].mem_pool);
+    }
+
     dispatch_release(ctx->d_queue);
 
     free(ctx);
@@ -1486,10 +1749,11 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
     }
 }
 
-static void ggml_metal_encode_node(
+static bool ggml_metal_encode_node(
                         ggml_backend_t   backend,
                                    int   idx,
-          id<MTLComputeCommandEncoder>   encoder) {
+          id<MTLComputeCommandEncoder>   encoder,
+            struct ggml_metal_mem_pool * mem_pool) {
     struct ggml_backend_metal_context        * ctx     = backend->context;
     struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
 
@@ -1505,7 +1769,7 @@ static void ggml_metal_encode_node(
     struct ggml_tensor * dst  = node;
 
     if (ggml_is_empty(dst)) {
-        return;
+        return true;
     }
 
     switch (dst->op) {
@@ -1516,7 +1780,7 @@ static void ggml_metal_encode_node(
         case GGML_OP_PERMUTE:
             {
                 // noop -> next node
-            } return;
+            } return true;
         default:
             {
             } break;
@@ -1527,6 +1791,8 @@ static void ggml_metal_encode_node(
         GGML_ABORT("unsupported op");
     }
 
+    ggml_metal_mem_pool_clear(mem_pool);
+
     const int64_t  ne00 = src0 ? src0->ne[0] : 0;
     const int64_t  ne01 = src0 ? src0->ne[1] : 0;
     const int64_t  ne02 = src0 ? src0->ne[2] : 0;
@@ -2173,26 +2439,76 @@ static void ggml_metal_encode_node(
                 const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
                 const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
-                ggml_metal_kargs_soft_max args = {
+// use this branch to test the ggml_metal_mem_pool functionality
+#if 0
+                // cpy to tmp buffer in MTLHeap
+
+                id<MTLBuffer> h_src0 = h_src0 = ggml_metal_mem_pool_alloc(mem_pool, ggml_nbytes(src0));
+                if (!h_src0) {
+                    GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, ggml_nbytes(src0));
+                    return false;
+                }
+
+                offs_src0 = 0;
+
+                ggml_metal_kargs_cpy args_cpy = {
                     /*.ne00 =*/ ne00,
                     /*.ne01 =*/ ne01,
                     /*.ne02 =*/ ne02,
-                    /*.scale =*/ scale,
-                    /*.max_bias =*/ max_bias,
-                    /*.m0 =*/ m0,
-                    /*.m1 =*/ m1,
+                    /*.ne03 =*/ ne03,
+                    /*.nb00 =*/ nb00,
+                    /*.nb01 =*/ nb01,
+                    /*.nb02 =*/ nb02,
+                    /*.nb03 =*/ nb03,
+                    /*.ne0  =*/ ne00,
+                    /*.ne1  =*/ ne01,
+                    /*.ne2  =*/ ne02,
+                    /*.ne3  =*/ ne03,
+                    /*.nb0  =*/ nb00,
+                    /*.nb1  =*/ nb01,
+                    /*.nb2  =*/ nb02,
+                    /*.nb3  =*/ nb03,
+                };
+
+                if (src0->type == GGML_TYPE_F16) {
+                    [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline];
+                } else {
+                    [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline];
+                }
+                [encoder setBytes:&args_cpy length:sizeof(args_cpy) atIndex:0];
+                [encoder setBuffer:id_src0  offset:offs_src0        atIndex:1];
+                [encoder setBuffer:h_src0   offset:0                atIndex:2];
+
+                GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
+                int nth_cpy = MIN(1024, ne00 / ggml_blck_size(src0->type));
+
+                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth_cpy, 1, 1)];
+
+#else
+                id<MTLBuffer> h_src0 = id_src0;
+#endif
+                // softmax
+
+                ggml_metal_kargs_soft_max args = {
+                    /*.ne00        =*/ ne00,
+                    /*.ne01        =*/ ne01,
+                    /*.ne02        =*/ ne02,
+                    /*.scale       =*/ scale,
+                    /*.max_bias    =*/ max_bias,
+                    /*.m0          =*/ m0,
+                    /*.m1          =*/ m1,
                     /*.n_head_log2 =*/ n_head_log2,
                 };
 
                 [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0 offset:offs_src0   atIndex:0];
+                [encoder setBuffer:h_src0 offset:offs_src0      atIndex:0];
                 if (id_src1) {
-                    [encoder setBuffer:id_src1 offset:offs_src1   atIndex:1];
+                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                 } else {
-                    [encoder setBuffer:id_src0 offset:offs_src0   atIndex:1];
+                    [encoder setBuffer:h_src0 offset:offs_src0  atIndex:1];
                 }
-                [encoder setBuffer:id_dst      offset:offs_dst            atIndex:2];
-                [encoder setBytes:&args        length:sizeof(args)        atIndex:3];
+                [encoder setBuffer:id_dst offset:offs_dst       atIndex:2];
+                [encoder setBytes:&args   length:sizeof(args)   atIndex:3];
 
                 [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
 
@@ -4601,6 +4917,8 @@ static void ggml_metal_encode_node(
                 GGML_ABORT("fatal error");
             }
     }
+
+    return true;
 }
 
 static enum ggml_status ggml_metal_graph_compute(
@@ -4654,25 +4972,25 @@ static enum ggml_status ggml_metal_graph_compute(
         }
 
         // the main thread commits the first few commands immediately
-        // command_buffer[n_cb]
+        // cmd_buf[n_cb]
         {
-            id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
-            ctx->command_buffers[n_cb] = command_buffer;
+            id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
+            ctx->cmd_bufs[n_cb].obj = cmd_buf;
 
-            [command_buffer enqueue];
+            [cmd_buf enqueue];
             ctx->encode_async(n_cb);
         }
 
         // prepare the rest of the command buffers asynchronously
-        // command_buffer[0.. n_cb)
+        // cmd_buf[0.. n_cb)
         for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
-            id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
-            ctx->command_buffers[cb_idx] = command_buffer;
+            id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
+            ctx->cmd_bufs[cb_idx].obj = cmd_buf;
 
             // always enqueue the first two command buffers
             // enqueue all of the command buffers if we don't need to abort
             if (cb_idx < 2 || ctx->abort_callback == NULL) {
-                [command_buffer enqueue];
+                [cmd_buf enqueue];
             }
         }
 
@@ -4681,14 +4999,14 @@ static enum ggml_status ggml_metal_graph_compute(
         // wait for completion and check status of each command buffer
         // needed to detect if the device ran out-of-memory for example (#1881)
         {
-            id<MTLCommandBuffer> command_buffer = ctx->command_buffers[n_cb];
-            [command_buffer waitUntilCompleted];
+            id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[n_cb].obj;
+            [cmd_buf waitUntilCompleted];
 
-            MTLCommandBufferStatus status = [command_buffer status];
+            MTLCommandBufferStatus status = [cmd_buf status];
             if (status != MTLCommandBufferStatusCompleted) {
                 GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status);
                 if (status == MTLCommandBufferStatusError) {
-                    GGML_LOG_INFO("error: %s\n", [[command_buffer error].localizedDescription UTF8String]);
+                    GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
                 }
 
                 return GGML_STATUS_FAILED;
@@ -4696,20 +5014,20 @@ static enum ggml_status ggml_metal_graph_compute(
         }
 
         for (int i = 0; i < n_cb; ++i) {
-            id<MTLCommandBuffer> command_buffer = ctx->command_buffers[i];
-            [command_buffer waitUntilCompleted];
+            id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[i].obj;
+            [cmd_buf waitUntilCompleted];
 
-            MTLCommandBufferStatus status = [command_buffer status];
+            MTLCommandBufferStatus status = [cmd_buf status];
             if (status != MTLCommandBufferStatusCompleted) {
                 GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
                 if (status == MTLCommandBufferStatusError) {
-                    GGML_LOG_INFO("error: %s\n", [[command_buffer error].localizedDescription UTF8String]);
+                    GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
                 }
 
                 return GGML_STATUS_FAILED;
             }
 
-            id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->command_buffers[i + 1] : nil);
+            id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->cmd_bufs[i + 1].obj : nil);
             if (!next_buffer) {
                 continue;
             }
@@ -5092,8 +5410,9 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
 
         const int n_nodes_per_cb = ctx->n_nodes_per_cb;
 
-        id<MTLCommandBuffer> command_buffer  = ctx->command_buffers[cb_idx];
-        id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoder];
+        id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[cb_idx].obj;
+
+        id<MTLComputeCommandEncoder> encoder = [cmd_buf computeCommandEncoder];
 
         int node_start = 0;
         int node_end   = n_nodes_0;
@@ -5105,22 +5424,29 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
 
         const bool should_capture = ctx->capture_next_compute;
 
+        struct ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs[cb_idx].mem_pool;
+        ggml_metal_mem_pool_reset(mem_pool);
+
         for (int idx = node_start; idx < node_end; ++idx) {
             if (should_capture) {
                 [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
             }
 
-            ggml_metal_encode_node(backend, idx, encoder);
+            const bool res = ggml_metal_encode_node(backend, idx, encoder, mem_pool);
 
             if (should_capture) {
                 [encoder popDebugGroup];
             }
+
+            if (!res) {
+                break;
+            }
         }
 
         [encoder endEncoding];
 
         if (cb_idx < 2 || ctx->abort_callback == NULL) {
-            [command_buffer commit];
+            [cmd_buf commit];
         }
     });
 }