]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
metal : add debug capture backend function (ggml/694)
authorJack Mousseau <redacted>
Mon, 29 Jan 2024 09:22:23 +0000 (01:22 -0800)
committerGeorgi Gerganov <redacted>
Tue, 30 Jan 2024 19:27:58 +0000 (21:27 +0200)
Co-authored-by: Georgi Gerganov <redacted>
ggml-metal.h
ggml-metal.m

index 8b0bfc5f10329babbcab0e8c7e2985fadb5af3f0..e8ceb1bd762f75b62949eab94b831a7f67b6d6ed 100644 (file)
@@ -58,6 +58,9 @@ GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(voi
 // ref: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
 GGML_API bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family);
 
+// capture all command buffers committed the next time `ggml_backend_graph_compute` is called
+GGML_API void ggml_backend_metal_capture_next_compute(ggml_backend_t backend);
+
 #ifdef __cplusplus
 }
 #endif
index a0efda0baa2d572af2dd27bda90172dfbd269594..e5fa14029ce50189d7125b6db416dcf9924eb7e1 100644 (file)
@@ -167,6 +167,8 @@ struct ggml_metal_context {
 
     bool support_simdgroup_reduction;
     bool support_simdgroup_mm;
+
+    bool should_capture_next_compute;
 };
 
 // MSL code
@@ -684,6 +686,20 @@ static bool ggml_metal_graph_compute(
     const int n_cb = ctx->n_cb;
     const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
 
+    const bool should_capture = ctx->should_capture_next_compute;
+    if (should_capture) {
+        ctx->should_capture_next_compute = false;
+
+        MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new];
+        descriptor.captureObject = ctx->queue;
+
+        NSError * error = nil;
+        if (![[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error]) {
+            GGML_METAL_LOG_ERROR("%s: error: unable to start capture '%s'\n", __func__, [[error localizedDescription] UTF8String]);
+            GGML_ASSERT(!"capture failed");
+        }
+    }
+
     id<MTLCommandBuffer> command_buffer_builder[n_cb];
     for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
         id<MTLCommandBuffer> command_buffer  = [ctx->queue commandBufferWithUnretainedReferences];
@@ -692,6 +708,7 @@ static bool ggml_metal_graph_compute(
         // enqueue the command buffers in order to specify their execution order
         [command_buffer enqueue];
     }
+
     const id<MTLCommandBuffer> *command_buffers = command_buffer_builder;
 
     dispatch_apply(n_cb, ctx->d_queue, ^(size_t iter) {
@@ -738,9 +755,9 @@ static bool ggml_metal_graph_compute(
                 GGML_ASSERT(!"unsupported op");
             }
 
-#ifndef GGML_METAL_NDEBUG
-            [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(dst) encoding:NSUTF8StringEncoding]];
-#endif
+            if (should_capture) {
+                [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(dst) encoding:NSUTF8StringEncoding]];
+            }
 
             const int64_t  ne00 = src0 ? src0->ne[0] : 0;
             const int64_t  ne01 = src0 ? src0->ne[1] : 0;
@@ -2190,9 +2207,9 @@ static bool ggml_metal_graph_compute(
                     }
             }
 
-#ifndef GGML_METAL_NDEBUG
-            [encoder popDebugGroup];
-#endif
+            if (should_capture) {
+                [encoder popDebugGroup];
+            }
         }
 
         [encoder endEncoding];
@@ -2214,6 +2231,10 @@ static bool ggml_metal_graph_compute(
         }
     }
 
+    if (should_capture) {
+        [[MTLCaptureManager sharedCaptureManager] stopCapture];
+    }
+
     return true;
 }
 
@@ -2575,6 +2596,13 @@ bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
     return [ctx->device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
 }
 
+void ggml_backend_metal_capture_next_compute(ggml_backend_t backend) {
+    GGML_ASSERT(ggml_backend_is_metal(backend));
+
+    struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
+    ctx->should_capture_next_compute = true;
+}
+
 GGML_CALL ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data); // silence warning
 
 GGML_CALL ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data) {