// abort ggml_metal_graph_compute if callback returns true
ggml_abort_callback abort_callback;
void * abort_callback_data;
+
+ // error state - set when a command buffer fails during synchronization
+ // once set, graph_compute will return GGML_STATUS_FAILED until the backend is recreated
+ bool has_error;
};
ggml_metal_t ggml_metal_init(ggml_metal_device_t dev) {
res->capture_started = false;
res->capture_scope = nil;
+ res->has_error = false;
+
res->gf = nil;
res->encode_async = nil;
for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
if (status == MTLCommandBufferStatusError) {
GGML_LOG_ERROR("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
}
- GGML_ABORT("fatal error");
+ ctx->has_error = true;
+ return;
}
}
}
if (status == MTLCommandBufferStatusError) {
GGML_LOG_ERROR("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
}
- GGML_ABORT("fatal error");
+
+ // release this and all remaining command buffers before returning
+ for (size_t j = i; j < ctx->cmd_bufs_ext.count; ++j) {
+ [ctx->cmd_bufs_ext[j] release];
+ }
+ [ctx->cmd_bufs_ext removeAllObjects];
+
+ ctx->has_error = true;
+ return;
}
[cmd_buf release];
}
enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph * gf) {
+ if (ctx->has_error) {
+ GGML_LOG_ERROR("%s: backend is in error state from a previous command buffer failure - recreate the backend to recover\n", __func__);
+ return GGML_STATUS_FAILED;
+ }
+
// number of nodes encoded by the main thread (empirically determined)
const int n_main = MAX(64, 0.1*gf->n_nodes);