]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ggml webgpu: actually add softmax, fix rms_norm offset (llama/16400)
authorReese Levine <redacted>
Sun, 5 Oct 2025 03:59:31 +0000 (20:59 -0700)
committerGeorgi Gerganov <redacted>
Sun, 12 Oct 2025 08:16:23 +0000 (11:16 +0300)
* implement soft_max

* Fix soft_max data race

* Temporary fix, wait on each submit

ggml/src/ggml-webgpu/ggml-webgpu.cpp
ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl
ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl

index de68c5689bba730be14aa6a83e4f070a40db5b36..e795ca3fd92fd7d08f15dc5320abfb33bc038cc5 100644 (file)
@@ -424,6 +424,7 @@ static void ggml_backend_webgpu_build_and_enqueue(webgpu_context &
         ctx->staged_param_bufs.push_back(params_bufs);
         if (ctx->staged_command_bufs.size() == WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) {
             ggml_backend_webgpu_submit_queue(ctx);
+            ggml_backend_webgpu_wait_on_submission(ctx);
         }
     }
 }
@@ -1060,6 +1061,9 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
         case GGML_OP_SCALE:
             ggml_webgpu_scale(ctx, src0, node);
             break;
+        case GGML_OP_SOFT_MAX:
+            ggml_webgpu_soft_max(ctx, src0, src1, src2, node);
+            break;
         default:
             return false;
     }
@@ -1806,6 +1810,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
         case GGML_OP_SCALE:
             supports_op = op->type == GGML_TYPE_F32;
             break;
+        case GGML_OP_SOFT_MAX:
+            supports_op = op->type == GGML_TYPE_F32;
+            break;
         default:
             break;
     }
@@ -1949,6 +1956,7 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
     ggml_webgpu_init_rope_pipeline(ctx);
     ggml_webgpu_init_glu_pipeline(ctx);
     ggml_webgpu_init_scale_pipeline(ctx);
+    ggml_webgpu_init_soft_max_pipeline(ctx);
 
 #ifdef GGML_WEBGPU_DEBUG
     // Initialize debug buffers
index 4f72bb1c851ec33d622ee1bcc81cc22236d5369a..712b921f1abb99b8f44ade6e69a626af80951420 100644 (file)
@@ -84,7 +84,7 @@ fn main(@builtin(workgroup_id) wid: vec3<u32>,
     let i2 = i / params.ne1;
     let i1 = i % params.ne1;
     let i_src_row = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1;
-    let i_dst_row = params.offset_src + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;
+    let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1;
 
     let elems = (params.ne0 + wg_size - 1) / wg_size;
 
index 64ab576c083542ea17e7cadf0cfbc4a5c84b85c9..c74dc4cc9238ab8abf1135bc8bdc28fe333645d7 100644 (file)
@@ -300,6 +300,7 @@ fn main(@builtin(workgroup_id) wid: vec3<u32>,
         workgroupBarrier();
     }
     let row_max = scratch[0];
+    workgroupBarrier();
 
     var sum = 0.0f;
     col = lid.x;