]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
Fix wait logic for inflight jobs (llama/20096)
authorNikhil Jain <redacted>
Wed, 4 Mar 2026 19:54:55 +0000 (11:54 -0800)
committerGeorgi Gerganov <redacted>
Mon, 16 Mar 2026 11:10:15 +0000 (13:10 +0200)
* Enable tmate debugging for investigating thread safety issue

* Refactor wait and submit to operate on vector<wgpu::FutureWaitInfo>, and fix wait to delete only the future that is completed.

* Cleanup

* Remove clear change and run clang-format

* Cleanup

ggml/src/ggml-webgpu/ggml-webgpu.cpp

index 334919e589fa648e16792c4554ba4e7568de1c85..b2ef2d59010588d7fd4d433bceaec7e2ad63748e 100644 (file)
@@ -123,11 +123,6 @@ struct webgpu_pool_bufs {
     wgpu::Buffer dev_buf;
 };
 
-// The futures to wait on for a single queue submission
-struct webgpu_submission_futures {
-    std::vector<wgpu::FutureWaitInfo> futures;
-};
-
 // Holds a pool of parameter buffers for WebGPU operations
 struct webgpu_buf_pool {
     std::vector<webgpu_pool_bufs> free;
@@ -463,26 +458,60 @@ static void ggml_webgpu_create_buffer(wgpu::Device &    device,
 /** End WebGPU object initializations */
 
 /** WebGPU Actions */
+static void erase_completed(std::vector<wgpu::FutureWaitInfo> & futures) {
+    futures.erase(std::remove_if(futures.begin(), futures.end(),
+                                 [](const wgpu::FutureWaitInfo & info) { return info.completed; }),
+                  futures.end());
+}
 
 // Wait for the queue to finish processing all submitted work
-static void ggml_backend_webgpu_wait(webgpu_global_context &                  ctx,
-                                     std::vector<webgpu_submission_futures> & futures,
-                                     bool                                     block = true) {
+static void ggml_backend_webgpu_wait(webgpu_global_context &             ctx,
+                                     std::vector<wgpu::FutureWaitInfo> & futures,
+                                     bool                                block = true) {
     // If we have too many in-flight submissions, wait on the oldest one first.
+    if (futures.empty()) {
+        return;
+    }
     uint64_t timeout_ms = block ? UINT64_MAX : 0;
     while (futures.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD) {
-        ctx->instance.WaitAny(futures[0].futures.size(), futures[0].futures.data(), UINT64_MAX);
-        futures.erase(futures.begin());
+        auto waitStatus = ctx->instance.WaitAny(1, &futures[0], UINT64_MAX);
+        if (waitStatus == wgpu::WaitStatus::Error) {
+            GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n");
+        }
+        if (futures[0].completed) {
+            futures.erase(futures.begin());
+        }
+    }
+
+    if (futures.empty()) {
+        return;
     }
-    size_t i = 0;
-    while (i < futures.size()) {
-        auto waitStatus = ctx->instance.WaitAny(futures[i].futures.size(), futures[i].futures.data(), timeout_ms);
+
+    if (block) {
+        while (!futures.empty()) {
+            auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms);
+            switch (waitStatus) {
+                case wgpu::WaitStatus::Success:
+                    // WaitAny doesn't tell us which future completed, so we must check all futures to see which finished.
+                    erase_completed(futures);
+                    break;
+                case wgpu::WaitStatus::Error:
+                    GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n");
+                    break;
+                default:
+                    GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an unknown status\n");
+                    break;
+            }
+        }
+    } else {
+        // Poll once and return
+        auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms);
         switch (waitStatus) {
             case wgpu::WaitStatus::Success:
-                futures.erase(futures.begin() + i);
+                // WaitAny doesn't tell us which future completed, so we must check all futures to see which finished.
+                erase_completed(futures);
                 break;
             case wgpu::WaitStatus::TimedOut:
-                i++;
                 break;
             case wgpu::WaitStatus::Error:
                 GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n");
@@ -525,10 +554,11 @@ static void ggml_backend_webgpu_debug(webgpu_global_context & ctx) {
 }
 #endif
 
-static webgpu_submission_futures ggml_backend_webgpu_submit(webgpu_global_context       ctx,
-                                                            std::vector<webgpu_command> commands,
-                                                            webgpu_buf_pool &           param_buf_pool,
-                                                            webgpu_buf_pool * set_rows_error_buf_pool = nullptr) {
+static std::vector<wgpu::FutureWaitInfo> ggml_backend_webgpu_submit(
+    webgpu_global_context       ctx,
+    std::vector<webgpu_command> commands,
+    webgpu_buf_pool &           param_buf_pool,
+    webgpu_buf_pool *           set_rows_error_buf_pool = nullptr) {
     std::vector<wgpu::CommandBuffer> command_buffers;
     std::vector<webgpu_pool_bufs>    params_bufs;
     std::vector<webgpu_pool_bufs>    set_rows_error_bufs;
@@ -600,7 +630,7 @@ static webgpu_submission_futures ggml_backend_webgpu_submit(webgpu_global_contex
         futures.push_back({ f });
     }
 #endif
-    return { futures };
+    return futures;
 }
 
 static webgpu_command ggml_backend_webgpu_build_multi(
@@ -727,8 +757,7 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx,
 
     webgpu_command command =
         ggml_backend_webgpu_build(ctx, ctx->memset_buf_pool, ctx->memset_pipelines[0], params, entries, wg_x);
-    std::vector<webgpu_submission_futures> futures = { ggml_backend_webgpu_submit(ctx, { command },
-                                                                                  ctx->memset_buf_pool) };
+    auto futures = ggml_backend_webgpu_submit(ctx, { command }, ctx->memset_buf_pool);
     ggml_backend_webgpu_wait(ctx, futures);
 }
 
@@ -836,7 +865,7 @@ static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0
     binary_overlap_flags flags = {};
     flags.inplace              = ggml_webgpu_tensor_equal(src0, dst);
     flags.overlap              = ggml_webgpu_tensor_overlap(src1, dst);
-    flags.src_overlap = ggml_webgpu_tensor_overlap(src0, src1);
+    flags.src_overlap          = ggml_webgpu_tensor_overlap(src0, src1);
 
     return flags;
 }
@@ -1153,8 +1182,8 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
     };
 
     // Calculate workgroup dimensions
-    uint32_t       wg_x = 1;
-    uint32_t       wg_y = 1;
+    uint32_t       wg_x           = 1;
+    uint32_t       wg_y           = 1;
     const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension;
 
     if (use_fast && is_vec) {
@@ -1410,7 +1439,7 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
     uint32_t offset_merged_src0 = 0;
     uint32_t offset_merged_src1 = 0;
     if (flags.src_overlap) {
-        size_t min_off = std::min(src0_webgpu_tensor_align_offset, src1_webgpu_tensor_align_offset);
+        size_t min_off     = std::min(src0_webgpu_tensor_align_offset, src1_webgpu_tensor_align_offset);
         offset_merged_src0 = (uint32_t) ((src0_webgpu_tensor_align_offset - min_off) / ggml_type_size(src0->type));
         offset_merged_src1 = (uint32_t) ((src1_webgpu_tensor_align_offset - min_off) / ggml_type_size(src0->type));
     }
@@ -1419,7 +1448,7 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
         ne,
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
-        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst)  / ggml_type_size(dst->type)),
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
         offset_merged_src0,
         offset_merged_src1,
         (uint32_t) (src0->nb[0] / ggml_type_size(src0->type)),
@@ -2185,9 +2214,9 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
 
     WEBGPU_CPU_PROFILE_TOTAL_START(graph_compute);
 
-    std::vector<webgpu_command>            commands;
-    std::vector<webgpu_submission_futures> futures;
-    uint32_t                               num_batched_kernels = 0;
+    std::vector<webgpu_command>       commands;
+    std::vector<wgpu::FutureWaitInfo> futures;
+    uint32_t                          num_batched_kernels = 0;
     for (int i = 0; i < cgraph->n_nodes; i++) {
         if (auto cmd = ggml_webgpu_encode_node(ctx, cgraph->nodes[i])) {
             commands.push_back(*cmd);
@@ -2195,9 +2224,10 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
         }
 
         if (num_batched_kernels >= WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) {
-            num_batched_kernels = 0;
-            futures.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool,
-                                                         &ctx->set_rows_error_buf_pool));
+            num_batched_kernels                               = 0;
+            std::vector<wgpu::FutureWaitInfo> compute_futures = ggml_backend_webgpu_submit(
+                ctx->global_ctx, commands, ctx->param_buf_pool, &ctx->set_rows_error_buf_pool);
+            futures.insert(futures.end(), compute_futures.begin(), compute_futures.end());
             // Process events and check for completed submissions
             ctx->global_ctx->instance.ProcessEvents();
             ggml_backend_webgpu_wait(ctx->global_ctx, futures, false);
@@ -2205,9 +2235,9 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
         }
     }
     if (!commands.empty()) {
-        webgpu_submission_futures new_futures =
+        auto new_futures =
             ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool, &ctx->set_rows_error_buf_pool);
-        futures.push_back(new_futures);
+        futures.insert(futures.end(), new_futures.begin(), new_futures.end());
     }
 
     ggml_backend_webgpu_wait(ctx->global_ctx, futures);