bool support_simdgroup_mm;
bool should_capture_next_compute;
+
+ // abort ggml_metal_graph_compute if callback returns true
+ ggml_abort_callback abort_callback;
+ void * abort_callback_data;
};
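For context, the callback type reused by these new fields is the one already declared in ggml.h (shown here for reference; returning true requests an abort):

    // from ggml.h
    typedef bool (*ggml_abort_callback)(void * data);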
// MSL code
id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
command_buffer_builder[cb_idx] = command_buffer;
- // enqueue the command buffers in order to specify their execution order
- [command_buffer enqueue];
+ // always enqueue the first two command buffers
+ // enqueue all of the command buffers if we don't need to abort
+ if (cb_idx < 2 || ctx->abort_callback == NULL) {
+ [command_buffer enqueue];
+ }
}
const id<MTLCommandBuffer> *command_buffers = command_buffer_builder;
[encoder endEncoding];
- [command_buffer commit];
+ if (cb_idx < 2 || ctx->abort_callback == NULL) {
+ [command_buffer commit];
+ }
});
// Wait for completion and check status of each command buffer
return GGML_STATUS_FAILED;
}
+
+ id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? command_buffers[i + 1] : nil);
+ if (!next_buffer) {
+ continue;
+ }
+
+ bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued);
+ if (next_queued) {
+ continue;
+ }
+
+ if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) {
+ GGML_METAL_LOG_INFO("%s: command buffer %d aborted", __func__, i);
+ return GGML_STATUS_ABORTED;
+ }
+
+ [next_buffer commit];
}
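Taken together, the hunks above implement a deferred-commit scheme: the first two command buffers are always enqueued and committed so the GPU pipeline stays fed, and each later buffer is only committed once the previous one has finished and the abort callback has been consulted, so an abort leaves the remaining buffers unsubmitted. Below is a toy, Metal-free model of that control flow; submit() and wait_done() are made-up stand-ins that collapse enqueue+commit into a single step:

    // toy sketch only: models the deferred-commit loop, no Metal involved
    #include <stdbool.h>
    #include <stdio.h>

    #define N_CB 4

    static bool submitted[N_CB];

    static void submit(int i)    { submitted[i] = true; printf("submit %d\n", i); }
    static void wait_done(int i) { printf("wait   %d\n", i); } // stands in for waitUntilCompleted

    static bool abort_cb(void * data) { return *(bool *) data; } // hypothetical user callback

    int main(void) {
        bool abort_requested = false; // flip to true to see the early return

        // mirror "cb_idx < 2": always keep two buffers in flight
        submit(0);
        submit(1);

        for (int i = 0; i < N_CB; ++i) {
            wait_done(i);

            if (i + 1 < N_CB && !submitted[i + 1]) {
                if (abort_cb(&abort_requested)) {
                    printf("aborted at %d\n", i);
                    return 1; // GGML_STATUS_ABORTED in the real code
                }
                submit(i + 1);
            }
        }

        return 0;
    }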
if (should_capture) {
ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
}
+void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data) {
+ GGML_ASSERT(ggml_backend_is_metal(backend));
+
+ struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
+
+ ctx->abort_callback = abort_callback;
+ ctx->abort_callback_data = user_data;
+}
+
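A minimal caller-side sketch of the new entry point, assuming a ggml revision where ggml_backend_graph_compute returns enum ggml_status (as the hunks above imply); the atomic flag and my_abort_cb are hypothetical:

    #include "ggml-backend.h"
    #include "ggml-metal.h"
    #include <stdatomic.h>
    #include <stdbool.h>

    static atomic_bool g_stop; // set to true from another thread, e.g. a UI "stop" button

    static bool my_abort_cb(void * data) {
        (void) data;
        return atomic_load(&g_stop); // true => abort between command buffers
    }

    // ggml_backend_t backend = ggml_backend_metal_init();
    // ggml_backend_metal_set_abort_callback(backend, my_abort_cb, NULL);
    //
    // enum ggml_status status = ggml_backend_graph_compute(backend, graph);
    // if (status == GGML_STATUS_ABORTED) {
    //     // the loop above returned before committing the remaining command buffers
    // }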
bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
GGML_ASSERT(ggml_backend_is_metal(backend));