]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ggml : add error handling to graph_compute (#1714)
authorFinn Voorhees <redacted>
Wed, 3 Jan 2024 13:39:43 +0000 (08:39 -0500)
committerGitHub <redacted>
Wed, 3 Jan 2024 13:39:43 +0000 (15:39 +0200)
bindings/ruby/ext/ggml-backend-impl.h
bindings/ruby/ext/ggml-backend.c
bindings/ruby/ext/ggml-backend.h
ggml-backend-impl.h
ggml-backend.c
ggml-backend.h
ggml-cuda.cu
ggml-metal.h
ggml-metal.m
whisper.cpp

index 211e3d4247387b2b598caca2e79d6863ffdf4c35..31788cd6baad497835ed8def94a837c37386a726 100644 (file)
@@ -70,7 +70,7 @@ extern "C" {
         void                      (*graph_plan_compute)(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
 
         // compute graph without a plan
-        void (*graph_compute)(ggml_backend_t backend, struct ggml_cgraph * cgraph);
+        bool (*graph_compute)(ggml_backend_t backend, struct ggml_cgraph * cgraph);
 
         // check if the backend supports an operation
         bool (*supports_op)(ggml_backend_t backend, const struct ggml_tensor * op);
index f6e5fceed0f4df2acf14a44953eb0f51df4443b3..128e33ce63032e41093e3a7140fb30e20a8d6002 100644 (file)
@@ -156,8 +156,8 @@ void ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_
     backend->iface.graph_plan_compute(backend, plan);
 }
 
-void ggml_backend_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
-    backend->iface.graph_compute(backend, cgraph);
+bool ggml_backend_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
+    return backend->iface.graph_compute(backend, cgraph);
 }
 
 bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
index 966687320ac96d971e9d635429d4ea1018a7ce57..793a0a9d65aadf2463ac896dcfb75c38714d4a14 100644 (file)
@@ -52,7 +52,7 @@ extern "C" {
 
     GGML_API void ggml_backend_graph_plan_free   (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
     GGML_API void ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
-    GGML_API void ggml_backend_graph_compute     (ggml_backend_t backend, struct ggml_cgraph * cgraph);
+    GGML_API bool ggml_backend_graph_compute     (ggml_backend_t backend, struct ggml_cgraph * cgraph);
     GGML_API bool ggml_backend_supports_op       (ggml_backend_t backend, const struct ggml_tensor * op);
 
     // tensor copy between different backends
index 05859935a3c2fa23f1b767c6167a7634f1f8c326..ca21b474372a6267ad261ec64da6147372e1cc0d 100644 (file)
@@ -90,7 +90,7 @@ extern "C" {
         void                      (*graph_plan_compute)(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
 
         // compute graph without a plan
-        void (*graph_compute)(ggml_backend_t backend, struct ggml_cgraph * cgraph);
+        bool (*graph_compute)(ggml_backend_t backend, struct ggml_cgraph * cgraph);
 
         // check if the backend supports an operation
         bool (*supports_op)(ggml_backend_t backend, const struct ggml_tensor * op);
index 2c3752067515fbcf632ec305203152854fe7503e..53e741cb892f8c204c7ced3826cc510f7c48eb87 100644 (file)
@@ -195,11 +195,14 @@ void ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_
     ggml_backend_synchronize(backend);
 }
 
-void ggml_backend_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
-    backend->iface.graph_compute(backend, cgraph);
+bool ggml_backend_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
+    if (!backend->iface.graph_compute(backend, cgraph)) {
+        return false;
+    }
 
     // TODO: optional sync
     ggml_backend_synchronize(backend);
+    return true;
 }
 
 bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
@@ -597,7 +600,7 @@ static void ggml_backend_cpu_graph_plan_compute(ggml_backend_t backend, ggml_bac
     GGML_UNUSED(backend);
 }
 
-static void ggml_backend_cpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
+static bool ggml_backend_cpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
     struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;
 
     struct ggml_cplan cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads);
@@ -611,6 +614,7 @@ static void ggml_backend_cpu_graph_compute(ggml_backend_t backend, struct ggml_c
     cplan.work_data = cpu_ctx->work_data;
 
     ggml_graph_compute(cgraph, &cplan);
+    return true;
 }
 
 static bool ggml_backend_cpu_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
index a9d2fddd726a85e2d326e7e10dfc41e67bf21035..85ff67b0ea843029c17e3745fe3689f04cd39039 100644 (file)
@@ -58,7 +58,7 @@ extern "C" {
 
     GGML_API void ggml_backend_graph_plan_free   (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
     GGML_API void ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
-    GGML_API void ggml_backend_graph_compute     (ggml_backend_t backend, struct ggml_cgraph * cgraph);
+    GGML_API bool ggml_backend_graph_compute     (ggml_backend_t backend, struct ggml_cgraph * cgraph);
     GGML_API bool ggml_backend_supports_op       (ggml_backend_t backend, const struct ggml_tensor * op);
 
     // tensor copy between different backends
index 52d3cc6a6a67ca5c430da364ccf63e27919eb254..10c21615e6b71f57381d3559e98c29f3b0873a29 100644 (file)
@@ -9910,7 +9910,7 @@ static void ggml_backend_cuda_graph_plan_compute(ggml_backend_t backend, ggml_ba
     UNUSED(plan);
 }
 
-static void ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
+static bool ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
     ggml_backend_context_cuda * cuda_ctx = (ggml_backend_context_cuda *)backend->context;
 
     ggml_cuda_set_main_device(cuda_ctx->device);
@@ -9967,6 +9967,8 @@ static void ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph
     }
 
     UNUSED(backend);
+
+    return true;
 }
 
 static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
index b5e02b668a0f70790f4c76692f2241dd02951d1e..c4b7325da6187419b70ae55e5ff6e7e526eebfda 100644 (file)
@@ -87,7 +87,7 @@ int * ggml_metal_get_concur_list(struct ggml_metal_context * ctx);
 
 // same as ggml_graph_compute but uses Metal
 // creates gf->n_threads command buffers in parallel
-void ggml_metal_graph_compute(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);
+bool ggml_metal_graph_compute(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);
 
 //
 // backend API
index 7aa92c14c9cdcc326a40c72afbfb1ab58ad36367..55cc1a872b21e711366f5638a82d142255a1d39b 100644 (file)
@@ -977,7 +977,7 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) {
             return false;
     }
 }
-void ggml_metal_graph_compute(
+bool ggml_metal_graph_compute(
         struct ggml_metal_context * ctx,
                struct ggml_cgraph * gf) {
     @autoreleasepool {
@@ -2405,10 +2405,11 @@ void ggml_metal_graph_compute(
         MTLCommandBufferStatus status = (MTLCommandBufferStatus) [ctx->command_buffers[i] status];
         if (status != MTLCommandBufferStatusCompleted) {
             GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
-            GGML_ASSERT(false);
+            return false;
         }
     }
 
+    return true;
     }
 }
 
@@ -2688,10 +2689,10 @@ static ggml_backend_buffer_type_t ggml_backend_metal_get_default_buffer_type(ggm
     UNUSED(backend);
 }
 
-static void ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
+static bool ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
     struct ggml_metal_context * metal_ctx = (struct ggml_metal_context *)backend->context;
 
-    ggml_metal_graph_compute(metal_ctx, cgraph);
+    return ggml_metal_graph_compute(metal_ctx, cgraph);
 }
 
 static bool ggml_backend_metal_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
index 5c6441012bd18e23233c8beb152137b1c6ee5b85..4f216a984c9c145a8eff32401264b80ebcff2693 100644 (file)
@@ -152,7 +152,7 @@ static void whisper_log_callback_default(ggml_log_level level, const char * text
 // ggml helpers
 //
 
-static void ggml_graph_compute_helper(
+static bool ggml_graph_compute_helper(
           struct ggml_cgraph * graph,
         std::vector<uint8_t> & buf,
                          int   n_threads,
@@ -168,10 +168,10 @@ static void ggml_graph_compute_helper(
         plan.work_data = buf.data();
     }
 
-    ggml_graph_compute(graph, &plan);
+    return ggml_graph_compute(graph, &plan);
 }
 
-static void ggml_graph_compute_helper(
+static bool ggml_graph_compute_helper(
        struct ggml_backend * backend,
         struct ggml_cgraph * graph,
                        int   n_threads) {
@@ -183,7 +183,7 @@ static void ggml_graph_compute_helper(
         ggml_backend_metal_set_n_cb(backend, n_threads);
     }
 #endif
-    ggml_backend_graph_compute(backend, graph);
+    return ggml_backend_graph_compute(backend, graph);
 }
 
 // faster matrix multiplications for tensors that do not have dimension 0 divisible by "pad"
@@ -2103,7 +2103,9 @@ static bool whisper_encode_internal(
         ggml_allocr_alloc_graph(alloc, gf);
 
         if (!whisper_encode_external(wstate)) {
-            ggml_graph_compute_helper(wstate.backend, gf, n_threads);
+            if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
+                return false;
+            }
         }
     }
 
@@ -2117,7 +2119,9 @@ static bool whisper_encode_internal(
 
         ggml_allocr_alloc_graph(alloc, gf);
 
-        ggml_graph_compute_helper(wstate.backend, gf, n_threads);
+        if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
+            return false;
+        }
     }
 
     // cross
@@ -2130,7 +2134,9 @@ static bool whisper_encode_internal(
 
         ggml_allocr_alloc_graph(alloc, gf);
 
-        ggml_graph_compute_helper(wstate.backend, gf, n_threads);
+        if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
+            return false;
+        }
     }
 
     wstate.t_encode_us += ggml_time_us() - t_start_us;
@@ -2552,7 +2558,9 @@ static bool whisper_decode_internal(
 
         logits = gf->nodes[gf->n_nodes - 1];
 
-        ggml_graph_compute_helper(wstate.backend, gf, n_threads);
+        if (!ggml_graph_compute_helper(wstate.backend, gf, n_threads)) {
+            return false;
+        }
     }
 
     logits_out.resize(n_tokens*n_vocab);