]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
metal : add abort callback (ggml/905)
authorConrad Kramer <redacted>
Wed, 7 Aug 2024 06:55:49 +0000 (02:55 -0400)
committerGeorgi Gerganov <redacted>
Thu, 8 Aug 2024 10:19:30 +0000 (13:19 +0300)
ggml/include/ggml-metal.h
ggml/src/ggml-metal.m

index 6c3226c37e0ef48072ddf146d73a392cf1d9fab1..d483cf1ac40c6e91590599cb1196e68ad56a2615 100644 (file)
@@ -50,6 +50,8 @@ GGML_API GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void
 
 GGML_API void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb);
 
+GGML_API void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data);
+
 GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void);
 
 // helper to check if the device supports a specific family
index b512eb0be132e7dfd5dac5d7e4dcee0f1fce5fe9..c19274176cc1e02b045bcefe41a2c65832e35133 100644 (file)
@@ -224,6 +224,10 @@ struct ggml_metal_context {
     bool support_simdgroup_mm;
 
     bool should_capture_next_compute;
+
+    // abort ggml_metal_graph_compute if callback returns true
+    ggml_abort_callback abort_callback;
+    void *              abort_callback_data;
 };
 
 // MSL code
@@ -878,8 +882,11 @@ static enum ggml_status ggml_metal_graph_compute(
         id<MTLCommandBuffer> command_buffer  = [ctx->queue commandBufferWithUnretainedReferences];
         command_buffer_builder[cb_idx] = command_buffer;
 
-        // enqueue the command buffers in order to specify their execution order
-        [command_buffer enqueue];
+        // always enqueue the first two command buffers
+        // enqueue all of the command buffers if we don't need to abort
+        if (cb_idx < 2 || ctx->abort_callback == NULL) {
+            [command_buffer enqueue];
+        }
     }
 
     const id<MTLCommandBuffer> *command_buffers = command_buffer_builder;
@@ -2827,7 +2834,9 @@ static enum ggml_status ggml_metal_graph_compute(
 
         [encoder endEncoding];
 
-        [command_buffer commit];
+        if (cb_idx < 2 || ctx->abort_callback == NULL) {
+            [command_buffer commit];
+        }
     });
 
     // Wait for completion and check status of each command buffer
@@ -2847,6 +2856,23 @@ static enum ggml_status ggml_metal_graph_compute(
 
             return GGML_STATUS_FAILED;
         }
+
+        id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? command_buffers[i + 1] : nil);
+        if (!next_buffer) {
+            continue;
+        }
+
+        bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued);
+        if (next_queued) {
+            continue;
+        }
+
+        if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) {
+            GGML_METAL_LOG_INFO("%s: command buffer %d aborted", __func__, i);
+            return GGML_STATUS_ABORTED;
+        }
+
+        [next_buffer commit];
     }
 
     if (should_capture) {
@@ -3242,6 +3268,15 @@ void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
     ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
 }
 
+void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data) {
+    GGML_ASSERT(ggml_backend_is_metal(backend));
+
+    struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
+
+    ctx->abort_callback = abort_callback;
+    ctx->abort_callback_data = user_data;
+}
+
 bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
     GGML_ASSERT(ggml_backend_is_metal(backend));