]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml-webgpu: JIT compile binary operators and handle binding overlaps (llama/19310)
authorAbhijit Ramesh <redacted>
Fri, 6 Feb 2026 18:33:30 +0000 (10:33 -0800)
committerGeorgi Gerganov <redacted>
Sat, 7 Feb 2026 08:37:38 +0000 (10:37 +0200)
* ggml webgpu: port binary operators to use pre-wgsl

* Add binary.wgsl: unified shader with conditionals for all 4 ops

* Add gen_binary_shaders.cpp: build tool for using pre_wgsl preprocessor

* Remove bin_op.tmpl.wgsl and binary.wgsl (Python template)

* Update CMake to generate binary operator shaders at build time

* ggml-webgpu: migrate binary ops to JIT compilation with overlap handling

* port binary operators from AOT to pre-wgsl JIT compilation

* add src1=dst overlap handling for binary ops

* use compile-time workgroup size defines instead of runtime overrides

* ggml-webgpu: complete overlap handling for binary ops

* add support for inplace & overlap case in binding setup

* restructure conditional logic to handle all overlap cases

* ensure all buffer bindings are correctly assigned for edge cases

* ggml-webgpu: remove unused binary overlap cases

Remove src0==src1 binary overlap case that never occurs in practice.

* keep INPLACE (src0==dst), OVERLAP (src1==dst), DEFAULT

* remove unused src0==src1 and all-same variant

* refactor wgsl to eliminate duplication

src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
src/ggml-webgpu/ggml-webgpu.cpp
src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl [deleted file]
src/ggml-webgpu/wgsl-shaders/binary.wgsl [new file with mode: 0644]
src/ggml-webgpu/wgsl-shaders/binary_head.tmpl [deleted file]

index 84d88e81d45a4a13d52fc56f6331275d0bbe89e1..6997f6bdd31d97b414b105d31850c210d0975d78 100644 (file)
@@ -465,4 +465,73 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_unary_shader(
     return result;
 }
 
+/** Binary **/
+
+struct ggml_webgpu_binary_pipeline_key {
+    int  type;
+    int  op;
+    bool inplace;
+    bool overlap;
+
+    bool operator==(const ggml_webgpu_binary_pipeline_key & other) const {
+        return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap;
+    }
+};
+
+struct ggml_webgpu_binary_pipeline_key_hash {
+    size_t operator()(const ggml_webgpu_binary_pipeline_key & key) const {
+        size_t seed = 0;
+        ggml_webgpu_hash_combine(seed, key.type);
+        ggml_webgpu_hash_combine(seed, key.op);
+        ggml_webgpu_hash_combine(seed, key.inplace);
+        ggml_webgpu_hash_combine(seed, key.overlap);
+        return seed;
+    }
+};
+
+struct ggml_webgpu_binary_shader_lib_context {
+    ggml_webgpu_binary_pipeline_key key;
+    uint32_t                        max_wg_size;
+};
+
+inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_binary_shader(
+    pre_wgsl::Preprocessor &                      preprocessor,
+    const char *                                  shader_src,
+    const ggml_webgpu_binary_shader_lib_context & context) {
+    std::vector<std::string> defines;
+    std::string              op_name = ggml_op_name((ggml_op) context.key.op);
+    std::string              variant = op_name;
+
+    defines.push_back(std::string("OP_") + op_name);
+
+    switch (context.key.type) {
+        case GGML_TYPE_F32:
+            defines.push_back("TYPE_F32");
+            variant += "_f32";
+            break;
+        case GGML_TYPE_F16:
+            defines.push_back("TYPE_F16");
+            variant += "_f16";
+            break;
+        default:
+            GGML_ABORT("Unsupported type for binary shader");
+    }
+
+    if (context.key.inplace) {
+        defines.push_back("INPLACE");
+        variant += "_inplace";
+    } else if (context.key.overlap) {
+        defines.push_back("OVERLAP");
+        variant += "_overlap";
+    }
+
+    defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
+    ggml_webgpu_processed_shader result;
+    result.wgsl                                      = preprocessor.preprocess(shader_src, defines);
+    result.variant                                   = variant;
+    ggml_webgpu_generic_shader_decisions * decisions = new ggml_webgpu_generic_shader_decisions();
+    decisions->wg_size                               = context.max_wg_size;
+    result.decisions                                 = decisions;
+    return result;
+}
 #endif  // GGML_WEBGPU_SHADER_LIB_HPP
index 4ef50e365efdefd4cde190e3bda27205ab6b404b..f7ceca11212a0056bb4891ff0c231b3c2c71381b 100644 (file)
@@ -348,13 +348,12 @@ struct webgpu_context_struct {
 
     std::unordered_map<ggml_webgpu_set_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_set_rows_pipeline_key_hash>
                                                   set_rows_pipelines;
-    std::map<int, std::map<int, webgpu_pipeline>> get_rows_pipelines;                 // src_type, vectorized
+    std::map<int, std::map<int, webgpu_pipeline>> get_rows_pipelines;  // src_type, vectorized
 
-    std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines;                      // src_type, dst_type
-    std::map<int, std::map<int, webgpu_pipeline>> add_pipelines;                      // type, inplace
-    std::map<int, std::map<int, webgpu_pipeline>> sub_pipelines;                      // type, inplace
-    std::map<int, std::map<int, webgpu_pipeline>> mul_pipelines;                      // type, inplace
-    std::map<int, std::map<int, webgpu_pipeline>> div_pipelines;                      // type, inplace
+    std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines;       // src_type, dst_type
+
+    std::unordered_map<ggml_webgpu_binary_pipeline_key, webgpu_pipeline, ggml_webgpu_binary_pipeline_key_hash>
+        binary_pipelines;
 
     std::map<int, webgpu_pipeline>                               rms_norm_pipelines;  // inplace
     std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> rope_pipelines;      // type, ff, inplace
@@ -823,6 +822,28 @@ static bool ggml_webgpu_tensor_equal(ggml_tensor * a, ggml_tensor * b) {
            (ggml_webgpu_tensor_offset(a) == ggml_webgpu_tensor_offset(b));
 }
 
+// Used to determine if two tensors share the same buffer and their byte ranges overlap,
+static bool ggml_webgpu_tensor_overlap(ggml_tensor * a, ggml_tensor * b) {
+    return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) &&
+           ggml_webgpu_tensor_offset(a) < (ggml_webgpu_tensor_offset(b) + ggml_nbytes(b)) &&
+           ggml_webgpu_tensor_offset(b) < (ggml_webgpu_tensor_offset(a) + ggml_nbytes(a));
+}
+
+struct binary_overlap_flags {
+    bool inplace;  // src0 == dst
+    bool overlap;  // src1 == dst
+};
+
+static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0,
+                                                              ggml_tensor * src1,
+                                                              ggml_tensor * dst) {
+    binary_overlap_flags flags = {};
+    flags.inplace              = ggml_webgpu_tensor_equal(src0, dst);
+    flags.overlap              = ggml_webgpu_tensor_overlap(src1, dst);
+
+    return flags;
+}
+
 static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
     uint32_t ne = (uint32_t) ggml_nelements(dst);
 
@@ -1375,14 +1396,42 @@ static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * s
     return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
 }
 
-static webgpu_command ggml_webgpu_binary_op(webgpu_context &  ctx,
-                                            ggml_tensor *     src0,
-                                            ggml_tensor *     src1,
-                                            ggml_tensor *     dst,
-                                            webgpu_pipeline & pipeline,
-                                            bool              inplace) {
+static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
+                                            ggml_tensor *    src0,
+                                            ggml_tensor *    src1,
+                                            ggml_tensor *    dst) {
+    binary_overlap_flags flags = ggml_webgpu_detect_binary_overlap(src0, src1, dst);
+
+    ggml_webgpu_binary_pipeline_key pipeline_key = {
+        .type    = dst->type,
+        .op      = dst->op,
+        .inplace = flags.inplace,
+        .overlap = flags.overlap,
+    };
+    ggml_webgpu_binary_shader_lib_context shader_lib_ctx = {
+        .key = pipeline_key, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
+    };
+
+    webgpu_pipeline pipeline;
+    auto            it = ctx->binary_pipelines.find(pipeline_key);
+    if (it != ctx->binary_pipelines.end()) {
+        pipeline = it->second;
+    } else {
+        ggml_webgpu_processed_shader processed =
+            ggml_webgpu_preprocess_binary_shader(ctx->p, wgsl_binary, shader_lib_ctx);
+        pipeline =
+            ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
+        pipeline.context = processed.decisions;
+        ctx->binary_pipelines.emplace(pipeline_key, pipeline);
+    }
+
+    ggml_webgpu_generic_shader_decisions decisions =
+        *static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context);
+
+    uint32_t ne = (uint32_t) ggml_nelements(dst);
+
     std::vector<uint32_t> params = {
-        (uint32_t) ggml_nelements(dst),
+        ne,
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
         (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
@@ -1399,24 +1448,30 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context &  ctx,
         (uint32_t) src1->ne[3],
     };
 
-    std::vector<wgpu::BindGroupEntry> entries = {
-        { .binding = 0,
-         .buffer  = ggml_webgpu_tensor_buf(src0),
-         .offset  = ggml_webgpu_tensor_align_offset(ctx, src0),
-         .size    = ggml_webgpu_tensor_binding_size(ctx, src0) },
-        { .binding = 1,
-         .buffer  = ggml_webgpu_tensor_buf(src1),
-         .offset  = ggml_webgpu_tensor_align_offset(ctx, src1),
-         .size    = ggml_webgpu_tensor_binding_size(ctx, src1) }
-    };
-    if (!inplace) {
+    std::vector<wgpu::BindGroupEntry> entries;
+
+    entries.push_back({
+        .binding = 0,
+        .buffer  = ggml_webgpu_tensor_buf(src0),
+        .offset  = ggml_webgpu_tensor_align_offset(ctx, src0),
+        .size    = ggml_webgpu_tensor_binding_size(ctx, src0),
+    });
+
+    entries.push_back({
+        .binding = 1,
+        .buffer  = ggml_webgpu_tensor_buf(src1),
+        .offset  = ggml_webgpu_tensor_align_offset(ctx, src1),
+        .size    = ggml_webgpu_tensor_binding_size(ctx, src1),
+    });
+
+    if (!flags.inplace && !flags.overlap) {
         entries.push_back({ .binding = 2,
                             .buffer  = ggml_webgpu_tensor_buf(dst),
                             .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
                             .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });
     }
 
-    uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE);
+    uint32_t wg_x = CEIL_DIV(ne, decisions.wg_size);
     return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
 }
 
@@ -2038,25 +2093,10 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
             return std::nullopt;
 #endif
         case GGML_OP_ADD:
-            {
-                int inplace = ggml_webgpu_tensor_equal(src0, node);
-                return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_pipelines[node->type][inplace], inplace);
-            }
         case GGML_OP_SUB:
-            {
-                int inplace = ggml_webgpu_tensor_equal(src0, node);
-                return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->sub_pipelines[node->type][inplace], inplace);
-            }
         case GGML_OP_MUL:
-            {
-                int inplace = ggml_webgpu_tensor_equal(src0, node);
-                return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_pipelines[node->type][inplace], inplace);
-            }
         case GGML_OP_DIV:
-            {
-                int inplace = ggml_webgpu_tensor_equal(src0, node);
-                return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->div_pipelines[node->type][inplace], inplace);
-            }
+            return ggml_webgpu_binary_op(ctx, src0, src1, node);
         case GGML_OP_RMS_NORM:
             return ggml_webgpu_rms_norm(ctx, src0, node);
         case GGML_OP_ROPE:
@@ -2665,58 +2705,6 @@ static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
         ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f16, "cpy_f16_f16", constants);
 }
 
-static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) {
-    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
-
-    webgpu_ctx->add_pipelines[GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_add_f32, "add_f32", constants);
-    webgpu_ctx->add_pipelines[GGML_TYPE_F16][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_add_f16, "add_f16", constants);
-    webgpu_ctx->add_pipelines[GGML_TYPE_F32][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_add_f32_inplace, "add_f32_inplace", constants);
-    webgpu_ctx->add_pipelines[GGML_TYPE_F16][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_add_f16_inplace, "add_f16_inplace", constants);
-}
-
-static void ggml_webgpu_init_sub_pipeline(webgpu_context & webgpu_ctx) {
-    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
-
-    webgpu_ctx->sub_pipelines[GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_sub_f32, "sub_f32", constants);
-    webgpu_ctx->sub_pipelines[GGML_TYPE_F16][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_sub_f16, "sub_f16", constants);
-    webgpu_ctx->sub_pipelines[GGML_TYPE_F32][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_sub_f32_inplace, "sub_f32_inplace", constants);
-    webgpu_ctx->sub_pipelines[GGML_TYPE_F16][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_sub_f16_inplace, "sub_f16_inplace", constants);
-}
-
-static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) {
-    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
-
-    webgpu_ctx->mul_pipelines[GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_f32, "mul_f32", constants);
-    webgpu_ctx->mul_pipelines[GGML_TYPE_F16][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_f16, "mul_f16", constants);
-    webgpu_ctx->mul_pipelines[GGML_TYPE_F32][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_f32_inplace, "mul_f32_inplace", constants);
-    webgpu_ctx->mul_pipelines[GGML_TYPE_F16][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_f16_inplace, "mul_f16_inplace", constants);
-}
-
-static void ggml_webgpu_init_div_pipeline(webgpu_context & webgpu_ctx) {
-    std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
-
-    webgpu_ctx->div_pipelines[GGML_TYPE_F32][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_div_f32, "div_f32", constants);
-    webgpu_ctx->div_pipelines[GGML_TYPE_F16][0] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_div_f16, "div_f16", constants);
-    webgpu_ctx->div_pipelines[GGML_TYPE_F32][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_div_f32_inplace, "div_f32_inplace", constants);
-    webgpu_ctx->div_pipelines[GGML_TYPE_F16][1] =
-        ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_div_f16_inplace, "div_f16_inplace", constants);
-}
-
 static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) {
     std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE);
 
@@ -3018,10 +3006,6 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) {
     ggml_webgpu_init_mul_mat_pipeline(webgpu_ctx);
     ggml_webgpu_init_get_rows_pipeline(webgpu_ctx);
     ggml_webgpu_init_cpy_pipeline(webgpu_ctx);
-    ggml_webgpu_init_add_pipeline(webgpu_ctx);
-    ggml_webgpu_init_sub_pipeline(webgpu_ctx);
-    ggml_webgpu_init_mul_pipeline(webgpu_ctx);
-    ggml_webgpu_init_div_pipeline(webgpu_ctx);
     ggml_webgpu_init_rms_norm_pipeline(webgpu_ctx);
     ggml_webgpu_init_rope_pipeline(webgpu_ctx);
     ggml_webgpu_init_glu_pipeline(webgpu_ctx);
diff --git a/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl b/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl
deleted file mode 100644 (file)
index 1ce4d83..0000000
+++ /dev/null
@@ -1,188 +0,0 @@
-#define(VARIANTS)
-
-[
-  {
-    "SHADER_NAME": "add_f32",
-    "REPLS": {
-      "TYPE" : "f32",
-      "OP": "+"
-    },
-    "DECLS": ["NOT_INPLACE"]
-  },
-  {
-    "SHADER_NAME": "add_f16",
-    "REPLS": {
-      "TYPE" : "f16",
-      "OP": "+"
-    },
-    "DECLS": ["NOT_INPLACE"]
-  },
-  {
-    "SHADER_NAME": "add_f32_inplace",
-    "REPLS": {
-      "TYPE" : "f32",
-      "OP": "+"
-    },
-    "DECLS": ["INPLACE"]
-  },
-  {
-    "SHADER_NAME": "add_f16_inplace",
-    "REPLS": {
-      "TYPE" : "f16",
-      "OP": "+"
-    },
-    "DECLS": ["INPLACE"]
-  },
-  {
-    "SHADER_NAME": "mul_f32",
-    "REPLS": {
-      "TYPE" : "f32",
-      "OP": "*"
-    },
-    "DECLS": ["NOT_INPLACE"]
-  },
-  {
-    "SHADER_NAME": "mul_f16",
-    "REPLS": {
-      "TYPE" : "f16",
-      "OP": "*"
-    },
-    "DECLS": ["NOT_INPLACE"]
-  },
-  {
-    "SHADER_NAME": "mul_f32_inplace",
-    "REPLS": {
-      "TYPE" : "f32",
-      "OP": "*"
-    },
-    "DECLS": ["INPLACE"]
-  },
-  {
-    "SHADER_NAME": "mul_f16_inplace",
-    "REPLS": {
-      "TYPE" : "f16",
-      "OP": "*"
-    },
-    "DECLS": ["INPLACE"]
-  },
-  {
-    "SHADER_NAME": "sub_f32",
-    "REPLS": {
-      "TYPE" : "f32",
-      "OP": "-"
-    },
-    "DECLS": ["NOT_INPLACE"]
-  },
-  {
-    "SHADER_NAME": "sub_f16",
-    "REPLS": {
-      "TYPE" : "f16",
-      "OP": "-"
-    },
-    "DECLS": ["NOT_INPLACE"]
-  },
-  {
-    "SHADER_NAME": "sub_f32_inplace",
-    "REPLS": {
-      "TYPE" : "f32",
-      "OP": "-"
-    },
-    "DECLS": ["INPLACE"]
-  },
-  {
-    "SHADER_NAME": "sub_f16_inplace",
-    "REPLS": {
-      "TYPE" : "f16",
-      "OP": "-"
-    },
-    "DECLS": ["INPLACE"]
-  },
-  {
-    "SHADER_NAME": "div_f32",
-    "REPLS": {
-      "TYPE" : "f32",
-      "OP": "/"
-    },
-    "DECLS": ["NOT_INPLACE"]
-  },
-  {
-    "SHADER_NAME": "div_f16",
-    "REPLS": {
-      "TYPE" : "f16",
-      "OP": "/"
-    },
-    "DECLS": ["NOT_INPLACE"]
-  },
-  {
-    "SHADER_NAME": "div_f32_inplace",
-    "REPLS": {
-      "TYPE" : "f32",
-      "OP": "/"
-    },
-    "DECLS": ["INPLACE"]
-  },
-  {
-    "SHADER_NAME": "div_f16_inplace",
-    "REPLS": {
-      "TYPE" : "f16",
-      "OP": "/"
-    },
-    "DECLS": ["INPLACE"]
-  }
-]
-
-#end(VARIANTS)
-
-#define(DECLS)
-
-#decl(NOT_INPLACE)
-
-fn update(dst_i: u32, src0_i: u32, src1_i: u32) {
-    dst[dst_i] = src0[src0_i] {{OP}} src1[src1_i];
-}
-
-@group(0) @binding(2)
-var<storage, read_write> dst: array<{{TYPE}}>;
-
-@group(0) @binding(3)
-var<uniform> params: Params;
-
-#enddecl(NOT_INPLACE)
-
-#decl(INPLACE)
-
-fn update(dst_i: u32, src0_i: u32, src1_i: u32) {
-    src0[dst_i] = src0[src0_i] {{OP}} src1[src1_i];
-}
-
-@group(0) @binding(2)
-var<uniform> params: Params;
-
-#enddecl(INPLACE)
-
-#end(DECLS)
-
-
-#define(SHADER)
-
-enable f16;
-
-#include "binary_head.tmpl"
-
-@group(0) @binding(0)
-var<storage, read_write> src0: array<{{TYPE}}>;
-
-@group(0) @binding(1)
-var<storage, read_write> src1: array<{{TYPE}}>;
-
-DECLS
-
-override wg_size: u32;
-@compute @workgroup_size(wg_size)
-fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
-    if (gid.x < params.ne) {
-        update(params.offset_dst + gid.x, params.offset_src0 + gid.x, params.offset_src1 + src1_index(gid.x));
-    }
-}
-
-#end(SHADER)
diff --git a/src/ggml-webgpu/wgsl-shaders/binary.wgsl b/src/ggml-webgpu/wgsl-shaders/binary.wgsl
new file mode 100644 (file)
index 0000000..55dd664
--- /dev/null
@@ -0,0 +1,107 @@
+enable f16;
+
+struct Params {
+    ne: u32,
+
+    // offsets in elements
+    offset_src0: u32,
+    offset_src1: u32,
+    offset_dst: u32,
+
+    stride_src1_0: u32,
+    stride_src1_1: u32,
+    stride_src1_2: u32,
+    stride_src1_3: u32,
+
+    a_ne0: u32,
+    a_ne1: u32,
+    a_ne2: u32,
+
+    b_ne0: u32,
+    b_ne1: u32,
+    b_ne2: u32,
+    b_ne3: u32,
+};
+
+fn src1_index(_i: u32) -> u32 {
+    var i = _i;
+    let a_i3 = i / (params.a_ne2 * params.a_ne1 * params.a_ne0);
+    i = i % (params.a_ne2 * params.a_ne1 * params.a_ne0);
+    let a_i2 = i / (params.a_ne1 * params.a_ne0);
+    i = i % (params.a_ne1 * params.a_ne0);
+    let a_i1 = i / params.a_ne0;
+    let a_i0 = i % params.a_ne0;
+
+    // handle repetition of b
+    // index loops back to the beginning and repeats after elements are exhausted = modulo
+    let b_i0 = a_i0 % params.b_ne0;
+    let b_i1 = a_i1 % params.b_ne1;
+    let b_i2 = a_i2 % params.b_ne2;
+    let b_i3 = a_i3 % params.b_ne3;
+
+    // compute index for position in b's flat array
+    return b_i0 * params.stride_src1_0 +
+           b_i1 * params.stride_src1_1 +
+           b_i2 * params.stride_src1_2 +
+           b_i3 * params.stride_src1_3;
+}
+
+#ifdef TYPE_F32
+#define DataType f32
+#endif
+#ifdef TYPE_F16
+#define DataType f16
+#endif
+
+@group(0) @binding(0)
+var<storage, read_write> src0: array<DataType>;
+
+@group(0) @binding(1)
+var<storage, read_write> src1 : array<DataType>;
+
+#ifdef INPLACE
+@group(0) @binding(2)
+var<uniform> params: Params;
+
+#elif defined(OVERLAP)
+@group(0) @binding(2)
+var<uniform> params: Params;
+
+#else
+@group(0) @binding(2)
+var<storage, read_write> dst: array<DataType>;
+
+@group(0) @binding(3)
+var<uniform> params: Params;
+#endif
+
+fn op(a: DataType, b: DataType) -> DataType {
+#ifdef OP_ADD
+    return a + b;
+#elif defined(OP_SUB)
+    return a - b;
+#elif defined(OP_MUL)
+    return a * b;
+#elif defined(OP_DIV)
+    return a / b;
+#endif
+}
+
+fn update(dst_i: u32, src0_i: u32, src1_i: u32){
+    let result = op(src0[src0_i], src1[src1_i]);
+
+#ifdef INPLACE
+    src0[dst_i] = result;
+#elif defined(OVERLAP)
+    src1[dst_i] = result;
+#else
+    dst[dst_i] = result;
+#endif
+}
+
+@compute @workgroup_size(WG_SIZE)
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+    if (gid.x < params.ne) {
+        update(params.offset_dst + gid.x, params.offset_src0 + gid.x, params.offset_src1 + src1_index(gid.x));
+    }
+}
diff --git a/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl b/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl
deleted file mode 100644 (file)
index 4b254f4..0000000
+++ /dev/null
@@ -1,45 +0,0 @@
-struct Params {
-    ne: u32,
-
-    // offsets in elements
-    offset_src0: u32,
-    offset_src1: u32,
-    offset_dst: u32,
-
-    stride_src1_0: u32,
-    stride_src1_1: u32,
-    stride_src1_2: u32,
-    stride_src1_3: u32,
-
-    a_ne0: u32,
-    a_ne1: u32,
-    a_ne2: u32,
-
-    b_ne0: u32,
-    b_ne1: u32,
-    b_ne2: u32,
-    b_ne3: u32,
-};
-
-fn src1_index(_i: u32) -> u32 {
-    var i = _i;
-    let a_i3 = i / (params.a_ne2 * params.a_ne1 * params.a_ne0);
-    i = i % (params.a_ne2 * params.a_ne1 * params.a_ne0);
-    let a_i2 = i / (params.a_ne1 * params.a_ne0);
-    i = i % (params.a_ne1 * params.a_ne0);
-    let a_i1 = i / params.a_ne0;
-    let a_i0 = i % params.a_ne0;
-
-    // handle repetition of b
-    // index loops back to the beginning and repeats after elements are exhausted = modulo
-    let b_i0 = a_i0 % params.b_ne0;
-    let b_i1 = a_i1 % params.b_ne1;
-    let b_i2 = a_i2 % params.b_ne2;
-    let b_i3 = a_i3 % params.b_ne3;
-
-    // compute index for position in b's flat array
-    return b_i0 * params.stride_src1_0 +
-           b_i1 * params.stride_src1_1 +
-           b_i2 * params.stride_src1_2 +
-           b_i3 * params.stride_src1_3;
-}