]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ggml-webgpu: Support non-contiguous `src0` and overlapping `src0/src1` in binary...
authorMasashi Yoshimura <redacted>
Mon, 2 Mar 2026 15:59:53 +0000 (00:59 +0900)
committerGeorgi Gerganov <redacted>
Mon, 16 Mar 2026 11:10:15 +0000 (13:10 +0200)
* ggml-webgpu: Add binary op support for overlapping and non-contiguous.

* Add newline to binary.wgsl

* Append the test of binary op for src overlapping  to test_bin_bcast.

* Remove unnecessary newline.

ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
ggml/src/ggml-webgpu/ggml-webgpu.cpp
ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl

index 0d5a818dacb598e5e29947f59b8f0a9fe1c3929b..369475eaf50ca2916b769df25fccea38d8ee6710 100644 (file)
@@ -68,6 +68,7 @@ struct ggml_webgpu_shader_lib_context {
     size_t   wg_mem_limit_bytes       = 0;
     bool     inplace                  = false;
     bool     overlap                  = false;
+    bool     src_overlap              = false;
     bool     supports_subgroup_matrix = false;
     uint32_t sg_mat_m                 = 0;
     uint32_t sg_mat_n                 = 0;
@@ -179,9 +180,10 @@ struct ggml_webgpu_binary_pipeline_key {
     int  op;
     bool inplace;
     bool overlap;
+    bool src_overlap;
 
     bool operator==(const ggml_webgpu_binary_pipeline_key & other) const {
-        return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap;
+        return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap && src_overlap == other.src_overlap;
     }
 };
 
@@ -192,6 +194,7 @@ struct ggml_webgpu_binary_pipeline_key_hash {
         ggml_webgpu_hash_combine(seed, key.op);
         ggml_webgpu_hash_combine(seed, key.inplace);
         ggml_webgpu_hash_combine(seed, key.overlap);
+        ggml_webgpu_hash_combine(seed, key.src_overlap);
         return seed;
     }
 };
@@ -1044,6 +1047,7 @@ class ggml_webgpu_shader_lib {
             .op      = context.dst->op,
             .inplace = context.inplace,
             .overlap = context.overlap,
+            .src_overlap = context.src_overlap,
         };
 
         auto it = binary_pipelines.find(key);
@@ -1076,6 +1080,9 @@ class ggml_webgpu_shader_lib {
         } else if (key.overlap) {
             defines.push_back("OVERLAP");
             variant += "_overlap";
+        } else if (key.src_overlap) {
+            defines.push_back("SRC_OVERLAP");
+            variant += "_src_overlap";
         }
 
         defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
index 1c00d3cb2b16c9cff0da95c169c622d53f8b8a84..4dc56e1dc586e7c5e8f549d96af3a55518f0c7ae 100644 (file)
@@ -788,6 +788,7 @@ static bool ggml_webgpu_tensor_overlap(ggml_tensor * a, ggml_tensor * b) {
 struct binary_overlap_flags {
     bool inplace;  // src0 == dst
     bool overlap;  // src1 == dst
+    bool src_overlap;
 };
 
 static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0,
@@ -796,6 +797,7 @@ static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0
     binary_overlap_flags flags = {};
     flags.inplace              = ggml_webgpu_tensor_equal(src0, dst);
     flags.overlap              = ggml_webgpu_tensor_overlap(src1, dst);
+    flags.src_overlap = ggml_webgpu_tensor_overlap(src0, src1);
 
     return flags;
 }
@@ -1353,6 +1355,7 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
         .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
         .inplace     = flags.inplace,
         .overlap     = flags.overlap,
+        .src_overlap = flags.src_overlap,
     };
 
     webgpu_pipeline pipeline = ctx->shader_lib->get_binary_pipeline(shader_lib_ctx);
@@ -1361,11 +1364,28 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
 
     uint32_t ne = (uint32_t) ggml_nelements(dst);
 
+    size_t src0_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, src0);
+    size_t src1_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, src1);
+
+    uint32_t offset_merged_src0 = 0;
+    uint32_t offset_merged_src1 = 0;
+    if (flags.src_overlap) {
+        size_t min_off = std::min(src0_webgpu_tensor_align_offset, src1_webgpu_tensor_align_offset);
+        offset_merged_src0 = (uint32_t) ((src0_webgpu_tensor_align_offset - min_off) / ggml_type_size(src0->type));
+        offset_merged_src1 = (uint32_t) ((src1_webgpu_tensor_align_offset - min_off) / ggml_type_size(src0->type));
+    }
+
     std::vector<uint32_t> params = {
         ne,
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
-        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst)  / ggml_type_size(dst->type)),
+        offset_merged_src0,
+        offset_merged_src1,
+        (uint32_t) (src0->nb[0] / ggml_type_size(src0->type)),
+        (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
+        (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
+        (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
         (uint32_t) (src1->nb[0] / ggml_type_size(src1->type)),
         (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
         (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
@@ -1381,25 +1401,43 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
 
     std::vector<wgpu::BindGroupEntry> entries;
 
-    entries.push_back({
-        .binding = 0,
-        .buffer  = ggml_webgpu_tensor_buf(src0),
-        .offset  = ggml_webgpu_tensor_align_offset(ctx, src0),
-        .size    = ggml_webgpu_tensor_binding_size(ctx, src0),
-    });
-
-    entries.push_back({
-        .binding = 1,
-        .buffer  = ggml_webgpu_tensor_buf(src1),
-        .offset  = ggml_webgpu_tensor_align_offset(ctx, src1),
-        .size    = ggml_webgpu_tensor_binding_size(ctx, src1),
-    });
-
-    if (!flags.inplace && !flags.overlap) {
-        entries.push_back({ .binding = 2,
-                            .buffer  = ggml_webgpu_tensor_buf(dst),
-                            .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
-                            .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });
+    if (flags.src_overlap) {
+        size_t merged_offset = std::min(src0_webgpu_tensor_align_offset, src1_webgpu_tensor_align_offset);
+        size_t merged_end    = std::max(src0_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, src0),
+                                        src1_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, src1));
+        entries.push_back({
+            .binding = 0,
+            .buffer  = ggml_webgpu_tensor_buf(src0),
+            .offset  = merged_offset,
+            .size    = merged_end - merged_offset,
+        });
+        entries.push_back({
+            .binding = 1,
+            .buffer  = ggml_webgpu_tensor_buf(dst),
+            .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
+            .size    = ggml_webgpu_tensor_binding_size(ctx, dst),
+        });
+    } else {
+        entries.push_back({
+            .binding = 0,
+            .buffer  = ggml_webgpu_tensor_buf(src0),
+            .offset  = src0_webgpu_tensor_align_offset,
+            .size    = ggml_webgpu_tensor_binding_size(ctx, src0),
+        });
+        entries.push_back({
+            .binding = 1,
+            .buffer  = ggml_webgpu_tensor_buf(src1),
+            .offset  = src1_webgpu_tensor_align_offset,
+            .size    = ggml_webgpu_tensor_binding_size(ctx, src1),
+        });
+        if (!flags.inplace && !flags.overlap) {
+            entries.push_back({
+                .binding = 2,
+                .buffer  = ggml_webgpu_tensor_buf(dst),
+                .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
+                .size    = ggml_webgpu_tensor_binding_size(ctx, dst),
+            });
+        }
     }
 
     uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
@@ -2816,10 +2854,8 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
         case GGML_OP_SUB:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
-            // TODO: support non-contiguous tensors, e.g. for MOE_EXPERT_REDUCE
-            // see https://github.com/ggml-org/llama.cpp/pull/16857
             supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type) &&
-                          (src1->type == op->type) && ggml_is_contiguous(src0) && ggml_is_contiguous(src1);
+                          (src1->type == op->type);
             break;
         case GGML_OP_CPY:
         case GGML_OP_CONT:
index 55dd66408a3e34671494b6759a8954b33fb44ac2..a748dc1b86c8644c606fc2bf2add88a151f935d7 100644 (file)
@@ -7,6 +7,13 @@ struct Params {
     offset_src0: u32,
     offset_src1: u32,
     offset_dst: u32,
+    offset_merged_src0: u32,
+    offset_merged_src1: u32,
+
+    stride_src0_0: u32,
+    stride_src0_1: u32,
+    stride_src0_2: u32,
+    stride_src0_3: u32,
 
     stride_src1_0: u32,
     stride_src1_1: u32,
@@ -23,6 +30,21 @@ struct Params {
     b_ne3: u32,
 };
 
+fn src0_index(_i: u32) -> u32 {
+    var i = _i;
+    let a_i3 = i / (params.a_ne2 * params.a_ne1 * params.a_ne0);
+    i = i % (params.a_ne2 * params.a_ne1 * params.a_ne0);
+    let a_i2 = i / (params.a_ne1 * params.a_ne0);
+    i = i % (params.a_ne1 * params.a_ne0);
+    let a_i1 = i / params.a_ne0;
+    let a_i0 = i % params.a_ne0;
+
+    return a_i0 * params.stride_src0_0 +
+           a_i1 * params.stride_src0_1 +
+           a_i2 * params.stride_src0_2 +
+           a_i3 * params.stride_src0_3;
+}
+
 fn src1_index(_i: u32) -> u32 {
     var i = _i;
     let a_i3 = i / (params.a_ne2 * params.a_ne1 * params.a_ne0);
@@ -53,17 +75,22 @@ fn src1_index(_i: u32) -> u32 {
 #define DataType f16
 #endif
 
+#ifdef SRC_OVERLAP
 @group(0) @binding(0)
-var<storage, read_write> src0: array<DataType>;
+var<storage, read_write> merged_src: array<DataType>;
 
 @group(0) @binding(1)
-var<storage, read_write> src1 : array<DataType>;
+var<storage, read_write> dst: array<DataType>;
 
-#ifdef INPLACE
 @group(0) @binding(2)
 var<uniform> params: Params;
+#else
+@group(0) @binding(0)
+var<storage, read_write> src0: array<DataType>;
 
-#elif defined(OVERLAP)
+@group(0) @binding(1)
+var<storage, read_write> src1 : array<DataType>;
+#if defined(INPLACE) || defined(OVERLAP)
 @group(0) @binding(2)
 var<uniform> params: Params;
 
@@ -74,6 +101,7 @@ var<storage, read_write> dst: array<DataType>;
 @group(0) @binding(3)
 var<uniform> params: Params;
 #endif
+#endif
 
 fn op(a: DataType, b: DataType) -> DataType {
 #ifdef OP_ADD
@@ -87,13 +115,17 @@ fn op(a: DataType, b: DataType) -> DataType {
 #endif
 }
 
-fn update(dst_i: u32, src0_i: u32, src1_i: u32){
+fn update(dst_i: u32, src0_i: u32, src1_i: u32) {
+#ifdef SRC_OVERLAP
+    let result = op(merged_src[src0_i], merged_src[src1_i]);
+#else
     let result = op(src0[src0_i], src1[src1_i]);
+#endif
 
 #ifdef INPLACE
-    src0[dst_i] = result;
+    src0[src0_i] = result;
 #elif defined(OVERLAP)
-    src1[dst_i] = result;
+    src1[src1_i] = result;
 #else
     dst[dst_i] = result;
 #endif
@@ -102,6 +134,8 @@ fn update(dst_i: u32, src0_i: u32, src1_i: u32){
 @compute @workgroup_size(WG_SIZE)
 fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
     if (gid.x < params.ne) {
-        update(params.offset_dst + gid.x, params.offset_src0 + gid.x, params.offset_src1 + src1_index(gid.x));
+        let src0_i = params.offset_src0 + params.offset_merged_src0 + src0_index(gid.x);
+        let src1_i = params.offset_src1 + params.offset_merged_src1 + src1_index(gid.x);
+        update(params.offset_dst + gid.x, src0_i, src1_i);
     }
 }