struct ggml_cgraph * gf;
// the callback given to the thread pool
- // TODO: ideally, this should be created once, utilizing the command buffer state above
- // for some reason, doing it like this leads to a crash
void (^encode_async)(size_t ith);
// n_cb command buffers + 1 used by the main thread
[ctx->kernels[i].pipeline release];
}
+ Block_release(ctx->encode_async);
+
[ctx->queue release];
[ctx->device release];
}
}
- // TODO: how to avoid this allocation? I tried initializing it in ggml_backend_metal_set_n_cb but it crashes.
- ctx->encode_async = ^(size_t iter) {
- const int cb_idx = iter;
- const int n_cb_l = ctx->n_cb;
-
- const int n_nodes_0 = ctx->n_nodes_0;
- const int n_nodes_1 = ctx->n_nodes_1;
-
- const int n_nodes_per_cb = ctx->n_nodes_per_cb;
-
- id<MTLCommandBuffer> command_buffer = ctx->command_buffers[cb_idx];
- id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoder];
-
- int node_start = 0;
- int node_end = n_nodes_0;
-
- if (cb_idx < n_cb_l) {
- node_start = n_nodes_0 + ( (cb_idx + 0) * n_nodes_per_cb);
- node_end = n_nodes_0 + (MIN((cb_idx == n_cb_l - 1) ? n_nodes_1 : (cb_idx + 1) * n_nodes_per_cb, n_nodes_1));
- }
-
- for (int idx = node_start; idx < node_end; ++idx) {
- if (should_capture) {
- [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(gf, idx)) encoding:NSUTF8StringEncoding]];
- }
-
- ggml_metal_encode_node(ctx, idx, encoder);
-
- if (should_capture) {
- [encoder popDebugGroup];
- }
- }
-
- [encoder endEncoding];
-
- if (cb_idx < 2 || ctx->abort_callback == NULL) {
- [command_buffer commit];
- }
- };
-
// the main thread commits the first few commands immediately
// command_buffer[n_cb]
{
// default buffer
static id<MTLDevice> g_backend_device = nil;
-static int g_backend_device_ref_count = 0; // TODO: make thread-safe
+static int g_backend_device_ref_count = 0;
static id<MTLDevice> ggml_backend_metal_get_device(void) {
if (g_backend_device == nil) {
}
}
- // TODO: setting encode_async here causes crash during the next ggml_metal_graph_compute call. why?
- //ctx->encode_async = ^(size_t iter) {
- // ...
- //};
+ if (ctx->encode_async) {
+ Block_release(ctx->encode_async);
+ }
+
+ ctx->encode_async = Block_copy(^(size_t iter) {
+ const int cb_idx = iter;
+ const int n_cb_l = ctx->n_cb;
+
+ const int n_nodes_0 = ctx->n_nodes_0;
+ const int n_nodes_1 = ctx->n_nodes_1;
+
+ const int n_nodes_per_cb = ctx->n_nodes_per_cb;
+
+ id<MTLCommandBuffer> command_buffer = ctx->command_buffers[cb_idx];
+ id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoder];
+
+ int node_start = 0;
+ int node_end = n_nodes_0;
+
+ if (cb_idx < n_cb_l) {
+ node_start = n_nodes_0 + ( (cb_idx + 0) * n_nodes_per_cb);
+ node_end = n_nodes_0 + (MIN((cb_idx == n_cb_l - 1) ? n_nodes_1 : (cb_idx + 1) * n_nodes_per_cb, n_nodes_1));
+ }
+
+ const bool should_capture = ctx->capture_next_compute;
+
+ for (int idx = node_start; idx < node_end; ++idx) {
+ if (should_capture) {
+ [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
+ }
+
+ ggml_metal_encode_node(ctx, idx, encoder);
+
+ if (should_capture) {
+ [encoder popDebugGroup];
+ }
+ }
+
+ [encoder endEncoding];
+
+ if (cb_idx < 2 || ctx->abort_callback == NULL) {
+ [command_buffer commit];
+ }
+ });
}
static struct ggml_backend_i ggml_backend_metal_i = {