]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
metal : localized logic in `ggml_metal_graph_compute` (#4924)
authorPaul Tsochantaris <redacted>
Tue, 16 Jan 2024 17:05:19 +0000 (17:05 +0000)
committerGitHub <redacted>
Tue, 16 Jan 2024 17:05:19 +0000 (19:05 +0200)
* Metal: Localized logic in `ggml_metal_graph_compute`, minor performance improvement

* Whitespace

* Collecting command buffer completions on single thread

* Whitespace

* Reduce diff noise

ggml-metal.h
ggml-metal.m

index 8b0bfc5f10329babbcab0e8c7e2985fadb5af3f0..df83a1807c6b2a0dc375d2852654426811fd5298 100644 (file)
@@ -27,7 +27,6 @@
 
 // max memory buffers that can be mapped to the device
 #define GGML_METAL_MAX_BUFFERS 64
-#define GGML_METAL_MAX_COMMAND_BUFFERS 32
 
 struct ggml_tensor;
 struct ggml_cgraph;
index c21dc465ae50cfdd655025c24f5a064c2428a697..a549e6713e9bcb9c5d718311c7096f0975585885 100644 (file)
@@ -170,9 +170,6 @@ struct ggml_metal_context {
     id<MTLCommandQueue> queue;
     id<MTLLibrary>      library;
 
-    id<MTLCommandBuffer>         command_buffers [GGML_METAL_MAX_COMMAND_BUFFERS];
-    id<MTLComputeCommandEncoder> command_encoders[GGML_METAL_MAX_COMMAND_BUFFERS];
-
     dispatch_queue_t d_queue;
 
     int n_buffers;
@@ -719,25 +716,25 @@ static bool ggml_metal_graph_compute(
     @autoreleasepool {
 
     MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
-
-    const int n_nodes  = gf->n_nodes;
     edesc.dispatchType = MTLDispatchTypeSerial;
 
     // create multiple command buffers and enqueue them
     // then, we encode the graph into the command buffers in parallel
 
+    const int n_nodes  = gf->n_nodes;
     const int n_cb = ctx->n_cb;
+    const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
 
-    for (int i = 0; i < n_cb; ++i) {
-        ctx->command_buffers[i] = [ctx->queue commandBuffer];
+    id<MTLCommandBuffer> command_buffer_builder[n_cb];
+    for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
+        id<MTLCommandBuffer> command_buffer  = [ctx->queue commandBufferWithUnretainedReferences];
+        command_buffer_builder[cb_idx] = command_buffer;
 
         // enqueue the command buffers in order to specify their execution order
-        [ctx->command_buffers[i] enqueue];
-
-        ctx->command_encoders[i] = [ctx->command_buffers[i] computeCommandEncoderWithDescriptor: edesc];
+        [command_buffer enqueue];
     }
+    const id<MTLCommandBuffer> *command_buffers = command_buffer_builder;
 
-    const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
     dispatch_apply(n_cb, ctx->d_queue, ^(size_t iter) {
         const int cb_idx = iter;
 
@@ -745,15 +742,13 @@ static bool ggml_metal_graph_compute(
         size_t offs_src1 = 0;
         size_t offs_dst  = 0;
 
-        id<MTLCommandBuffer> command_buffer  = ctx->command_buffers[cb_idx];
-        id<MTLComputeCommandEncoder> encoder = ctx->command_encoders[cb_idx];
+        id<MTLCommandBuffer> command_buffer  = command_buffers[cb_idx];
+        id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
 
         const int node_start =                                      (cb_idx + 0) * n_nodes_per_cb;
         const int node_end   = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
 
-        for (int ind = node_start; ind < node_end; ++ind) {
-            const int i = ind;
-
+        for (int i = node_start; i < node_end; ++i) {
             if (i == -1) {
                 [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
                 continue;
@@ -2249,12 +2244,14 @@ static bool ggml_metal_graph_compute(
         [command_buffer commit];
     });
 
-    // check status of command buffers
+    // Wait for completion and check status of each command buffer
     // needed to detect if the device ran out-of-memory for example (#1881)
-    for (int i = 0; i < n_cb; i++) {
-        [ctx->command_buffers[i] waitUntilCompleted];
 
-        MTLCommandBufferStatus status = (MTLCommandBufferStatus) [ctx->command_buffers[i] status];
+    for (int i = 0; i < n_cb; ++i) {
+        id<MTLCommandBuffer> command_buffer = command_buffers[i];
+        [command_buffer waitUntilCompleted];
+
+        MTLCommandBufferStatus status = [command_buffer status];
         if (status != MTLCommandBufferStatusCompleted) {
             GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
             return false;