]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Add concat op to webgpu. (#20068)
authorMasashi Yoshimura <redacted>
Wed, 4 Mar 2026 19:19:00 +0000 (04:19 +0900)
committerGitHub <redacted>
Wed, 4 Mar 2026 19:19:00 +0000 (11:19 -0800)
docs/ops.md
docs/ops/WebGPU.csv
ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
ggml/src/ggml-webgpu/ggml-webgpu.cpp
ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl [new file with mode: 0644]

index 296c0ba1d456b1d19ddec64a0bc125e0b582bae3..8213bc6abfb5b277bc5825f6649c0d4e0aa34642 100644 (file)
@@ -24,7 +24,7 @@ Legend:
 |                          ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ |
 |                             CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ |
 |                            CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
-|                           CONCAT | â\9d\8c | â\9c\85 | â\9c\85 | ð\9f\9f¡ | â\9c\85 | ð\9f\9f¡ | â\9c\85 | â\9c\85 | â\9d\8c | ❌ | ❌ |
+|                           CONCAT | â\9d\8c | â\9c\85 | â\9c\85 | ð\9f\9f¡ | â\9c\85 | ð\9f\9f¡ | â\9c\85 | â\9c\85 | â\9c\85 | ❌ | ❌ |
 |                             CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ | ❌ |
 |                          CONV_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ |
 |                       CONV_2D_DW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
index e2ed3e2cfad7eb15f9cc2cc3dc37c7bec50f83ef..9e081e7605fc85c0fc66fcb20921ad9a2d8939a1 100644 (file)
 "WebGPU: WebGPU","ROPE","type=f16,ne_a=[128,32,2,1],n_dims=128,mode=40,n_ctx=512,fs=1.424500,ef=0.746500,af=1.424500,ff=1,v=0,inplace=1","support","1","yes","WebGPU"
 "WebGPU: WebGPU","ROPE","type=f16,ne_a=[128,32,2,1],n_dims=128,mode=24,n_ctx=512,fs=1.424500,ef=0.746500,af=1.424500,ff=0,v=0,inplace=1","support","1","yes","WebGPU"
 "WebGPU: WebGPU","ROPE","type=f16,ne_a=[128,32,2,1],n_dims=128,mode=24,n_ctx=512,fs=1.424500,ef=0.746500,af=1.424500,ff=1,v=0,inplace=1","support","1","yes","WebGPU"
-"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=0","support","0","no","WebGPU"
-"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=0","support","0","no","WebGPU"
-"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=0","support","0","no","WebGPU"
-"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=0","support","0","no","WebGPU"
-"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=0","support","0","no","WebGPU"
-"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=0","support","0","no","WebGPU"
-"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=0","support","0","no","WebGPU"
-"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=0","support","0","no","WebGPU"
-"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=1","support","0","no","WebGPU"
-"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=1","support","0","no","WebGPU"
-"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=1","support","0","no","WebGPU"
-"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=1","support","0","no","WebGPU"
-"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=1","support","0","no","WebGPU"
-"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=1","support","0","no","WebGPU"
-"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=1","support","0","no","WebGPU"
-"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=1","support","0","no","WebGPU"
-"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=2","support","0","no","WebGPU"
-"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=2","support","0","no","WebGPU"
-"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=2","support","0","no","WebGPU"
-"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=2","support","0","no","WebGPU"
-"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=2","support","0","no","WebGPU"
-"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=2","support","0","no","WebGPU"
-"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=2","support","0","no","WebGPU"
-"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=2","support","0","no","WebGPU"
-"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=3","support","0","no","WebGPU"
-"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=3","support","0","no","WebGPU"
-"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=3","support","0","no","WebGPU"
-"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=3","support","0","no","WebGPU"
-"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=3","support","0","no","WebGPU"
-"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=3","support","0","no","WebGPU"
-"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=3","support","0","no","WebGPU"
-"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=3","support","0","no","WebGPU"
+"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=0","support","1","yes","WebGPU"
+"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=0","support","1","yes","WebGPU"
+"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=0","support","1","yes","WebGPU"
+"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=0","support","1","yes","WebGPU"
+"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=0","support","1","yes","WebGPU"
+"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=0","support","1","yes","WebGPU"
+"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=0","support","1","yes","WebGPU"
+"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=0","support","1","yes","WebGPU"
+"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=1","support","1","yes","WebGPU"
+"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=1","support","1","yes","WebGPU"
+"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=1","support","1","yes","WebGPU"
+"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=1","support","1","yes","WebGPU"
+"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=1","support","1","yes","WebGPU"
+"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=1","support","1","yes","WebGPU"
+"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=1","support","1","yes","WebGPU"
+"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=1","support","1","yes","WebGPU"
+"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=2","support","1","yes","WebGPU"
+"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=2","support","1","yes","WebGPU"
+"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=2","support","1","yes","WebGPU"
+"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=2","support","1","yes","WebGPU"
+"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=2","support","1","yes","WebGPU"
+"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=2","support","1","yes","WebGPU"
+"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=2","support","1","yes","WebGPU"
+"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=2","support","1","yes","WebGPU"
+"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=3","support","1","yes","WebGPU"
+"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=3","support","1","yes","WebGPU"
+"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=3","support","1","yes","WebGPU"
+"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=3","support","1","yes","WebGPU"
+"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=3","support","1","yes","WebGPU"
+"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=3","support","1","yes","WebGPU"
+"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=3","support","1","yes","WebGPU"
+"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=3","support","1","yes","WebGPU"
 "WebGPU: WebGPU","ARGSORT","type=f32,ne=[3,1,1,1],order=0","support","1","yes","WebGPU"
 "WebGPU: WebGPU","ARGSORT","type=f32,ne=[4,1,1,1],order=0","support","1","yes","WebGPU"
 "WebGPU: WebGPU","ARGSORT","type=f32,ne=[7,1,1,1],order=0","support","1","yes","WebGPU"
index 369475eaf50ca2916b769df25fccea38d8ee6710..17c5e0fb51f77fcb35d2854333fbfefa325282ca 100644 (file)
@@ -173,6 +173,22 @@ struct ggml_webgpu_scale_pipeline_key_hash {
     }
 };
 
+/** Concat **/
+
+struct ggml_webgpu_concat_pipeline_key {
+    int type;
+
+    bool operator==(const ggml_webgpu_concat_pipeline_key & other) const { return type == other.type; }
+};
+
+struct ggml_webgpu_concat_pipeline_key_hash {
+    size_t operator()(const ggml_webgpu_concat_pipeline_key & key) const {
+        size_t seed = 0;
+        ggml_webgpu_hash_combine(seed, key.type);
+        return seed;
+    }
+};
+
 /** Binary **/
 
 struct ggml_webgpu_binary_pipeline_key {
@@ -403,6 +419,8 @@ class ggml_webgpu_shader_lib {
         pad_pipelines;                                                 // circular/non-circular
     std::unordered_map<ggml_webgpu_binary_pipeline_key, webgpu_pipeline, ggml_webgpu_binary_pipeline_key_hash>
         binary_pipelines;                                              // type/op/inplace/overlap
+    std::unordered_map<ggml_webgpu_concat_pipeline_key, webgpu_pipeline, ggml_webgpu_concat_pipeline_key_hash>
+        concat_pipelines;                                              // type
     std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash>
         flash_attn_pipelines;
     std::unordered_map<ggml_webgpu_legacy_mul_mat_pipeline_key,
@@ -1096,6 +1114,43 @@ class ggml_webgpu_shader_lib {
         return binary_pipelines[key];
     }
 
+    webgpu_pipeline get_concat_pipeline(const ggml_webgpu_shader_lib_context & context) {
+        ggml_webgpu_concat_pipeline_key key = {
+            .type = context.dst->type,
+        };
+
+        auto it = concat_pipelines.find(key);
+        if (it != concat_pipelines.end()) {
+            return it->second;
+        }
+
+        std::vector<std::string> defines;
+        std::string variant = "concat";
+
+        switch (key.type) {
+            case GGML_TYPE_F32:
+                defines.push_back("TYPE_F32");
+                variant += "_f32";
+                break;
+            case GGML_TYPE_I32:
+                defines.push_back("TYPE_I32");
+                variant += "_i32";
+                break;
+            default:
+                GGML_ABORT("Unsupported type for concat shader");
+        }
+
+        defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
+
+        auto processed = preprocessor.preprocess(wgsl_concat, defines);
+        auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
+        decisions->wg_size = context.max_wg_size;
+        webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
+        pipeline.context = decisions;
+        concat_pipelines[key] = pipeline;
+        return concat_pipelines[key];
+    }
+
     webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context) {
         const bool has_mask  = context.src3 != nullptr;
         const bool has_sinks = context.src4 != nullptr;
index 19451618ec5c8d1910fae50b2fa8f1ed326f84e8..334919e589fa648e16792c4554ba4e7568de1c85 100644 (file)
@@ -1484,6 +1484,68 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
     return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
 }
 
+static webgpu_command ggml_webgpu_concat(webgpu_context & ctx,
+                                         ggml_tensor * src0,
+                                         ggml_tensor * src1,
+                                         ggml_tensor * dst) {
+    uint32_t ne = (uint32_t) ggml_nelements(dst);
+    uint32_t dim = (uint32_t) dst->op_params[0];
+
+    std::vector<uint32_t> params = {
+        ne,
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
+        (uint32_t) (src0->nb[0] / ggml_type_size(src0->type)),
+        (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
+        (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
+        (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
+        (uint32_t) (src1->nb[0] / ggml_type_size(src1->type)),
+        (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
+        (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
+        (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
+        (uint32_t) dst->ne[0],
+        (uint32_t) dst->ne[1],
+        (uint32_t) dst->ne[2],
+        (uint32_t) dst->ne[3],
+        dim,
+        (uint32_t)src0->ne[dim]
+    };
+
+    std::vector<wgpu::BindGroupEntry> entries = {
+        {
+            .binding = 0,
+            .buffer = ggml_webgpu_tensor_buf(src0),
+            .offset = ggml_webgpu_tensor_align_offset(ctx, src0),
+            .size = ggml_webgpu_tensor_binding_size(ctx, src0)
+        },
+        {
+            .binding = 1,
+            .buffer = ggml_webgpu_tensor_buf(src1),
+            .offset = ggml_webgpu_tensor_align_offset(ctx, src1),
+            .size = ggml_webgpu_tensor_binding_size(ctx, src1)
+        },
+        {
+            .binding = 2,
+            .buffer = ggml_webgpu_tensor_buf(dst),
+            .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
+            .size = ggml_webgpu_tensor_binding_size(ctx, dst)
+        }
+    };
+
+    ggml_webgpu_shader_lib_context shader_lib_ctx = {
+        .src0        = src0,
+        .src1        = src1,
+        .dst         = dst,
+        .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
+    };
+
+    webgpu_pipeline pipeline = ctx->shader_lib->get_concat_pipeline(shader_lib_ctx);
+    auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
+    uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
+    return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
+}
+
 static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
     int inplace = ggml_webgpu_tensor_equal(src, dst);
 
@@ -2068,6 +2130,8 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
         case GGML_OP_MUL:
         case GGML_OP_DIV:
             return ggml_webgpu_binary_op(ctx, src0, src1, node);
+        case GGML_OP_CONCAT:
+            return ggml_webgpu_concat(ctx, src0, src1, node);
         case GGML_OP_RMS_NORM:
             return ggml_webgpu_rms_norm(ctx, src0, node);
         case GGML_OP_ROPE:
@@ -2894,6 +2958,9 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
             supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type) &&
                           (src1->type == op->type);
             break;
+        case GGML_OP_CONCAT:
+            supports_op = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32);
+            break;
         case GGML_OP_CPY:
         case GGML_OP_CONT:
             supports_op = ((op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl
new file mode 100644 (file)
index 0000000..a22d245
--- /dev/null
@@ -0,0 +1,75 @@
+struct Params {
+    ne: u32,
+
+    offset_src0: u32,
+    offset_src1: u32,
+    offset_dst: u32,
+
+    stride_src0_0: u32,
+    stride_src0_1: u32,
+    stride_src0_2: u32,
+    stride_src0_3: u32,
+
+    stride_src1_0: u32,
+    stride_src1_1: u32,
+    stride_src1_2: u32,
+    stride_src1_3: u32,
+
+    ne0: u32,
+    ne1: u32,
+    ne2: u32,
+    ne3: u32,
+
+    dim: u32,
+    src0_nedim: u32
+};
+
+#ifdef TYPE_F32
+#define DataType f32
+#endif
+#ifdef TYPE_I32
+#define DataType i32
+#endif
+
+@group(0) @binding(0)
+var<storage, read_write> src0: array<DataType>;
+
+@group(0) @binding(1)
+var<storage, read_write> src1 : array<DataType>;
+
+@group(0) @binding(2)
+var<storage, read_write> dst: array<DataType>;
+
+@group(0) @binding(3)
+var<uniform> params: Params;
+
+@compute @workgroup_size(WG_SIZE)
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+
+    if (gid.x < params.ne) {
+        var i = gid.x;
+        let i3 = i / (params.ne2 * params.ne1 * params.ne0);
+        i = i % (params.ne2 * params.ne1 * params.ne0);
+        let i2 = i / (params.ne1 * params.ne0);
+        i = i % (params.ne1 * params.ne0);
+        let i1 = i / params.ne0;
+        let i0 = i % params.ne0;
+
+        var ni = array<u32, 4>(i0, i1, i2, i3);
+
+        if (ni[params.dim] < params.src0_nedim) {
+            let src_i = ni[0] * params.stride_src0_0 +
+                             ni[1] * params.stride_src0_1 +
+                             ni[2] * params.stride_src0_2 +
+                             ni[3] * params.stride_src0_3;
+            dst[params.offset_dst + gid.x] = src0[params.offset_src0 + src_i];
+        } else {
+            ni[params.dim] -= params.src0_nedim;
+            let src_i = ni[0] * params.stride_src1_0 +
+                             ni[1] * params.stride_src1_1 +
+                             ni[2] * params.stride_src1_2 +
+                             ni[3] * params.stride_src1_3;
+            dst[params.offset_dst + gid.x] = src1[params.offset_src1 + src_i];
+        }
+    }
+}