From: Reese Levine Date: Thu, 19 Mar 2026 15:45:28 +0000 (-0700) Subject: ggml webgpu: ops support for qwen3.5 (SET, TRI_SOLVE, SSM_CONV, GATED_DELTA_NET)... X-Git-Tag: v0.9.9~37 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=26c9a1f5e143d8f179c1b9d1d122d66de4857e63;p=pkg%2Fggml%2Fsources%2Fggml ggml webgpu: ops support for qwen3.5 (SET, TRI_SOLVE, SSM_CONV, GATED_DELTA_NET) + GET_ROWS optimization (llama/20687) * Implement l2_norm, set, tri * Add DIAG/SOLVE_TRI * Add SSM_CONV * Better get_rows and gated_delta_net to support qwen3.5 * Clean up, update ops.md * Fix binding_index type for wasm * Fix read write annotations * cleanups --- diff --git a/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 9d16abf2..59861ac1 100644 --- a/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -95,6 +95,11 @@ struct ggml_webgpu_generic_shader_decisions { uint32_t wg_size = 0; }; +struct ggml_webgpu_ssm_conv_shader_decisions { + uint32_t block_size; + uint32_t tokens_per_wg; +}; + /** Argsort **/ struct ggml_webgpu_argsort_shader_lib_context { @@ -131,6 +136,26 @@ struct ggml_webgpu_set_rows_shader_decisions { uint32_t wg_size; }; +/** Set **/ + +struct ggml_webgpu_set_pipeline_key { + ggml_type type; + bool inplace; + + bool operator==(const ggml_webgpu_set_pipeline_key & other) const { + return type == other.type && inplace == other.inplace; + } +}; + +struct ggml_webgpu_set_pipeline_key_hash { + size_t operator()(const ggml_webgpu_set_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.type); + ggml_webgpu_hash_combine(seed, key.inplace); + return seed; + } +}; + /** Get Rows **/ struct ggml_webgpu_get_rows_pipeline_key { @@ -186,6 +211,67 @@ struct ggml_webgpu_pad_pipeline_key_hash { } }; +/** Solve Tri **/ +struct ggml_webgpu_solve_tri_pipeline_key { + int type; + int n; + int k; + + bool operator==(const ggml_webgpu_solve_tri_pipeline_key & other) const { + return type == other.type && n == other.n && k == other.k; + } +}; + +struct ggml_webgpu_solve_tri_pipeline_key_hash { + size_t operator()(const ggml_webgpu_solve_tri_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.type); + ggml_webgpu_hash_combine(seed, key.n); + ggml_webgpu_hash_combine(seed, key.k); + return seed; + } +}; + +/** SSM Conv **/ +struct ggml_webgpu_ssm_conv_pipeline_key { + int type; + int vectorized; + + bool operator==(const ggml_webgpu_ssm_conv_pipeline_key & other) const { + return type == other.type && vectorized == other.vectorized; + } +}; + +/** Gated Delta Net **/ +struct ggml_webgpu_gated_delta_net_pipeline_key { + int type; + int s_v; + int kda; + + bool operator==(const ggml_webgpu_gated_delta_net_pipeline_key & other) const { + return type == other.type && s_v == other.s_v && kda == other.kda; + } +}; + +struct ggml_webgpu_gated_delta_net_pipeline_key_hash { + size_t operator()(const ggml_webgpu_gated_delta_net_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.type); + ggml_webgpu_hash_combine(seed, key.s_v); + ggml_webgpu_hash_combine(seed, key.kda); + return seed; + } +}; + +struct ggml_webgpu_ssm_conv_pipeline_key_hash { + size_t operator()(const ggml_webgpu_ssm_conv_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.type); + ggml_webgpu_hash_combine(seed, key.vectorized); + return seed; + } +}; + /** Scale **/ struct ggml_webgpu_scale_pipeline_key { @@ -466,14 +552,22 @@ class ggml_webgpu_shader_lib { unary_pipelines; // type/op/inplace std::unordered_map scale_pipelines; // inplace + std::unordered_map + solve_tri_pipelines; // type + std::unordered_map + ssm_conv_pipelines; // type/vectorized + std::unordered_map + gated_delta_net_pipelines; // type/S_v/kda std::unordered_map - pad_pipelines; // circular/non-circular + pad_pipelines; // circular/non-circular std::unordered_map - binary_pipelines; // type/op/inplace/overlap + binary_pipelines; // type/op/inplace/overlap std::unordered_map - concat_pipelines; // type + concat_pipelines; // type std::unordered_map - repeat_pipelines; // type + repeat_pipelines; // type std::unordered_map flash_attn_pipelines; std::unordered_map set_rows_pipelines; + std::unordered_map set_pipelines; public: ggml_webgpu_shader_lib(wgpu::Device device) { this->device = device; } @@ -519,11 +614,11 @@ class ggml_webgpu_shader_lib { switch (key.op) { case GGML_OP_RMS_NORM: - defines.push_back("OP_RMS_NORM"); + defines.push_back("RMS_NORM"); variant = "rms_norm"; break; case GGML_OP_L2_NORM: - defines.push_back("OP_L2_NORM"); + defines.push_back("L2_NORM"); variant = "l2_norm"; break; default: @@ -535,8 +630,9 @@ class ggml_webgpu_shader_lib { variant += "_inplace"; } - defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); - + const uint32_t row_norm_wg_size = 128u; + uint32_t wg_size = std::min(context.max_wg_size, row_norm_wg_size); + defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); auto processed = preprocessor.preprocess(wgsl_row_norm, defines); row_norm_pipelines[key] = ggml_webgpu_create_pipeline(device, processed, variant); return row_norm_pipelines[key]; @@ -609,6 +705,46 @@ class ggml_webgpu_shader_lib { return set_rows_pipelines[key]; } + webgpu_pipeline get_set_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_set_pipeline_key key = { .type = context.dst->type, .inplace = context.inplace }; + + auto it = set_pipelines.find(key); + if (it != set_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "set"; + + switch (key.type) { + case GGML_TYPE_F32: + defines.push_back("TYPE_F32"); + variant += "_f32"; + break; + case GGML_TYPE_I32: + defines.push_back("TYPE_I32"); + variant += "_i32"; + break; + default: + GGML_ABORT("Unsupported type for set shader"); + } + + if (key.inplace) { + defines.push_back("INPLACE"); + variant += "_inplace"; + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_set, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + set_pipelines[key] = pipeline; + return set_pipelines[key]; + } + webgpu_pipeline get_cumsum_pipeline(const ggml_webgpu_shader_lib_context & context) { auto it = cumsum_pipelines.find(1); if (it != cumsum_pipelines.end()) { @@ -695,6 +831,7 @@ class ggml_webgpu_shader_lib { switch (key.src_type) { case GGML_TYPE_F32: + defines.push_back("FLOAT_PARALLEL"); if (key.vectorized) { defines.push_back("F32_VEC"); defines.push_back("SRC_TYPE=vec4"); @@ -709,6 +846,7 @@ class ggml_webgpu_shader_lib { variant += "_f32"; break; case GGML_TYPE_F16: + defines.push_back("FLOAT_PARALLEL"); defines.push_back("F16"); defines.push_back("SRC_TYPE=f16"); defines.push_back("DST_TYPE=f32"); @@ -716,6 +854,7 @@ class ggml_webgpu_shader_lib { variant += "_f16"; break; case GGML_TYPE_I32: + defines.push_back("FLOAT_PARALLEL"); defines.push_back("I32"); defines.push_back("SRC_TYPE=i32"); defines.push_back("DST_TYPE=i32"); @@ -794,6 +933,128 @@ class ggml_webgpu_shader_lib { return scale_pipelines[key]; } + webgpu_pipeline get_solve_tri_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_solve_tri_pipeline_key key = { + .type = context.dst->type, + .n = (int) context.src0->ne[0], + .k = (int) context.src1->ne[0], + }; + + auto it = solve_tri_pipelines.find(key); + if (it != solve_tri_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "solve_tri"; + + switch (key.type) { + case GGML_TYPE_F32: + variant += "_f32"; + break; + default: + GGML_ABORT("Unsupported type for solve_tri shader"); + } + + const uint32_t wg_size = std::min((uint32_t) key.n, context.max_wg_size); + const uint32_t k_tile = wg_size; + const uint32_t bytes_per_row = ((uint32_t) key.n + wg_size) * GGML_WEBGPU_F32_SIZE_BYTES; + const uint32_t batch_n = (uint32_t) (context.wg_mem_limit_bytes / bytes_per_row); + + defines.push_back(std::string("N=") + std::to_string(key.n)); + defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); + defines.push_back(std::string("K_TILE=") + std::to_string(k_tile)); + defines.push_back(std::string("BATCH_N=") + std::to_string(batch_n)); + + auto processed = preprocessor.preprocess(wgsl_solve_tri, defines); + auto decisions = std::make_shared(); + decisions->wg_size = wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + solve_tri_pipelines[key] = pipeline; + return solve_tri_pipelines[key]; + } + + webgpu_pipeline get_ssm_conv_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_ssm_conv_pipeline_key key = { + .type = context.dst->type, + .vectorized = context.src1->ne[0] == 4, + }; + + auto it = ssm_conv_pipelines.find(key); + if (it != ssm_conv_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "ssm_conv"; + + switch (key.type) { + case GGML_TYPE_F32: + variant += "_f32"; + break; + default: + GGML_ABORT("Unsupported type for ssm_conv shader"); + } + + if (key.vectorized) { + defines.push_back("VECTORIZED"); + variant += "_vec4"; + } + + constexpr uint32_t block_size = 32u; + constexpr uint32_t tokens_per_wg = 8u; + + defines.push_back("BLOCK_SIZE=" + std::to_string(block_size) + "u"); + defines.push_back("TOKENS_PER_WG=" + std::to_string(tokens_per_wg) + "u"); + + auto processed = preprocessor.preprocess(wgsl_ssm_conv, defines); + auto decisions = std::make_shared(); + decisions->block_size = block_size; + decisions->tokens_per_wg = tokens_per_wg; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + ssm_conv_pipelines[key] = pipeline; + return ssm_conv_pipelines[key]; + } + + webgpu_pipeline get_gated_delta_net_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_gated_delta_net_pipeline_key key = { + .type = context.dst->type, + .s_v = (int) context.src2->ne[0], + .kda = context.src3->ne[0] == context.src2->ne[0], + }; + + auto it = gated_delta_net_pipelines.find(key); + if (it != gated_delta_net_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "gated_delta_net"; + + switch (key.type) { + case GGML_TYPE_F32: + variant += "_f32"; + break; + default: + GGML_ABORT("Unsupported type for gated_delta_net shader"); + } + + if (key.kda) { + defines.push_back("KDA"); + variant += "_kda"; + } + + defines.push_back("S_V=" + std::to_string(key.s_v) + "u"); + defines.push_back("WG_SIZE=" + std::to_string(key.s_v) + "u"); + + auto processed = preprocessor.preprocess(wgsl_gated_delta_net, defines); + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + gated_delta_net_pipelines[key] = pipeline; + return gated_delta_net_pipelines[key]; + } + webgpu_pipeline get_pad_pipeline(const ggml_webgpu_shader_lib_context & context) { ggml_webgpu_pad_pipeline_key key = { .circular = ggml_get_op_params_i32(context.dst, 8) != 0 }; diff --git a/src/ggml-webgpu/ggml-webgpu.cpp b/src/ggml-webgpu/ggml-webgpu.cpp index f7973df6..5e16f84d 100644 --- a/src/ggml-webgpu/ggml-webgpu.cpp +++ b/src/ggml-webgpu/ggml-webgpu.cpp @@ -880,6 +880,68 @@ static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, g params, entries, wg_x); } +static webgpu_command ggml_webgpu_set(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) { + const bool inplace = ggml_webgpu_tensor_equal(src0, dst); + + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src0, + .src1 = src1, + .dst = dst, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, + .inplace = inplace, + }; + + webgpu_pipeline pipeline = ctx->shader_lib->get_set_pipeline(shader_lib_ctx); + + auto * decisions = static_cast(pipeline.context.get()); + + const uint32_t ne = inplace ? (uint32_t) ggml_nelements(src1) : (uint32_t) ggml_nelements(dst); + const uint32_t dst_type_size = (uint32_t) ggml_type_size(dst->type); + + std::vector params = { + ne, + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + (uint32_t) (((const int32_t *) dst->op_params)[3] / dst_type_size), + + (uint32_t) (src1->nb[0] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), + + 1u, + (uint32_t) (((const int32_t *) dst->op_params)[0] / dst_type_size), + (uint32_t) (((const int32_t *) dst->op_params)[1] / dst_type_size), + (uint32_t) (((const int32_t *) dst->op_params)[2] / dst_type_size), + + (uint32_t) src1->ne[0], + (uint32_t) src1->ne[1], + (uint32_t) src1->ne[2], + (uint32_t) src1->ne[3], + }; + + std::vector entries; + uint32_t binding_index = 0; + if (!inplace) { + entries.push_back({ .binding = 0, + .buffer = ggml_webgpu_tensor_buf(src0), + .offset = ggml_webgpu_tensor_align_offset(ctx, src0), + .size = ggml_webgpu_tensor_binding_size(ctx, src0) }); + binding_index++; + } + entries.push_back({ .binding = binding_index, + .buffer = ggml_webgpu_tensor_buf(src1), + .offset = ggml_webgpu_tensor_align_offset(ctx, src1), + .size = ggml_webgpu_tensor_binding_size(ctx, src1) }); + entries.push_back({ .binding = binding_index + 1, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + + uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); +} + static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { ggml_webgpu_shader_lib_context shader_lib_ctx = { .src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup @@ -935,6 +997,208 @@ static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, g return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } +static webgpu_command ggml_webgpu_solve_tri(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src0, + .src1 = src1, + .dst = dst, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, + .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize, + }; + + webgpu_pipeline pipeline = ctx->shader_lib->get_solve_tri_pipeline(shader_lib_ctx); + + auto * decisions = static_cast(pipeline.context.get()); + + std::vector params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + + (uint32_t) (src0->nb[0] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), + + (uint32_t) (src1->nb[0] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), + + (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), + + (uint32_t) src1->ne[0], + (uint32_t) dst->ne[2], + (uint32_t) dst->ne[3], + }; + + std::vector entries = { + { .binding = 0, + .buffer = ggml_webgpu_tensor_buf(src0), + .offset = ggml_webgpu_tensor_align_offset(ctx, src0), + .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, + { .binding = 1, + .buffer = ggml_webgpu_tensor_buf(src1), + .offset = ggml_webgpu_tensor_align_offset(ctx, src1), + .size = ggml_webgpu_tensor_binding_size(ctx, src1) }, + { .binding = 2, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) } + }; + + const uint32_t wg_x = CEIL_DIV((uint32_t) src1->ne[0], decisions->wg_size); + const uint32_t wg_y = (uint32_t) (dst->ne[2] * dst->ne[3]); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y); +} + +static webgpu_command ggml_webgpu_ssm_conv(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src0, + .src1 = src1, + .dst = dst, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, + }; + + webgpu_pipeline pipeline = ctx->shader_lib->get_ssm_conv_pipeline(shader_lib_ctx); + auto * decisions = static_cast(pipeline.context.get()); + + const uint32_t token_tiles = CEIL_DIV((uint32_t) dst->ne[1], decisions->tokens_per_wg); + + std::vector params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), + + (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), + + (uint32_t) src1->ne[0], + (uint32_t) src0->ne[1], + (uint32_t) dst->ne[1], + (uint32_t) dst->ne[2], + token_tiles, + }; + + std::vector entries = { + { .binding = 0, + .buffer = ggml_webgpu_tensor_buf(src0), + .offset = ggml_webgpu_tensor_align_offset(ctx, src0), + .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, + { .binding = 1, + .buffer = ggml_webgpu_tensor_buf(src1), + .offset = ggml_webgpu_tensor_align_offset(ctx, src1), + .size = ggml_webgpu_tensor_binding_size(ctx, src1) }, + { .binding = 2, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) } + }; + + const uint32_t wg_x = CEIL_DIV((uint32_t) src0->ne[1], decisions->block_size); + const uint32_t wg_y = token_tiles * (uint32_t) dst->ne[2]; + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y); +} + +static webgpu_command ggml_webgpu_gated_delta_net(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * src2, + ggml_tensor * src3, + ggml_tensor * src4, + ggml_tensor * src5, + ggml_tensor * dst) { + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src0, + .src1 = src1, + .src2 = src2, + .src3 = src3, + .src4 = src4, + .dst = dst, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, + }; + + webgpu_pipeline pipeline = ctx->shader_lib->get_gated_delta_net_pipeline(shader_lib_ctx); + + const uint32_t s_v = (uint32_t) src2->ne[0]; + const uint32_t h = (uint32_t) src2->ne[1]; + const uint32_t n_tokens = (uint32_t) src2->ne[2]; + const uint32_t n_seqs = (uint32_t) src2->ne[3]; + const float scale = 1.0f / sqrtf((float) s_v); + uint32_t scale_u32; + memcpy(&scale_u32, &scale, sizeof(scale_u32)); + + std::vector params = { + h, + n_tokens, + n_seqs, + s_v * h * n_tokens * n_seqs, + + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), + + (uint32_t) (src2->nb[1] / ggml_type_size(src2->type)), + (uint32_t) (src2->nb[2] / ggml_type_size(src2->type)), + (uint32_t) (src2->nb[3] / ggml_type_size(src2->type)), + + (uint32_t) (src4->nb[1] / ggml_type_size(src4->type)), + (uint32_t) (src4->nb[2] / ggml_type_size(src4->type)), + (uint32_t) (src4->nb[3] / ggml_type_size(src4->type)), + + (uint32_t) src0->ne[1], + (uint32_t) (src2->ne[3] / src0->ne[3]), + scale_u32, + }; + + std::vector entries = { + { .binding = 0, + .buffer = ggml_webgpu_tensor_buf(src0), + .offset = ggml_webgpu_tensor_align_offset(ctx, src0), + .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, + { .binding = 1, + .buffer = ggml_webgpu_tensor_buf(src1), + .offset = ggml_webgpu_tensor_align_offset(ctx, src1), + .size = ggml_webgpu_tensor_binding_size(ctx, src1) }, + { .binding = 2, + .buffer = ggml_webgpu_tensor_buf(src2), + .offset = ggml_webgpu_tensor_align_offset(ctx, src2), + .size = ggml_webgpu_tensor_binding_size(ctx, src2) }, + { .binding = 3, + .buffer = ggml_webgpu_tensor_buf(src3), + .offset = ggml_webgpu_tensor_align_offset(ctx, src3), + .size = ggml_webgpu_tensor_binding_size(ctx, src3) }, + { .binding = 4, + .buffer = ggml_webgpu_tensor_buf(src4), + .offset = ggml_webgpu_tensor_align_offset(ctx, src4), + .size = ggml_webgpu_tensor_binding_size(ctx, src4) }, + { .binding = 5, + .buffer = ggml_webgpu_tensor_buf(src5), + .offset = ggml_webgpu_tensor_align_offset(ctx, src5), + .size = ggml_webgpu_tensor_binding_size(ctx, src5) }, + { .binding = 6, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) } + }; + + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, h, n_seqs); +} + static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * idx, @@ -1016,6 +1280,8 @@ static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * idx, ggml_tensor * dst) { + const bool float_parallel = src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16 || src->type == GGML_TYPE_I32; + ggml_webgpu_shader_lib_context shader_lib_ctx = { .src0 = src, .src1 = nullptr, @@ -1060,7 +1326,10 @@ static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, dst) } }; - uint32_t wg_x = CEIL_DIV(dst->ne[1] * dst->ne[2] * dst->ne[3], decisions->wg_size); + uint32_t blocks_per_row = (uint32_t) (dst->ne[0] / (src->type == GGML_TYPE_F32 && dst->ne[0] % 4 == 0 ? 4 : 1)); + uint32_t total_rows = (uint32_t) (dst->ne[1] * dst->ne[2] * dst->ne[3]); + uint32_t total_threads = float_parallel ? blocks_per_row * total_rows : total_rows; + uint32_t wg_x = CEIL_DIV(total_threads, decisions->wg_size); return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } @@ -1632,7 +1901,7 @@ static webgpu_command ggml_webgpu_row_norm(webgpu_context & ctx, ggml_tensor * s ggml_webgpu_shader_lib_context shader_lib_ctx = { .src0 = src, .dst = dst, - .max_wg_size = WEBGPU_ROW_SPLIT_WG_SIZE, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, .inplace = inplace, }; @@ -2176,6 +2445,8 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, case GGML_OP_CPY: case GGML_OP_CONT: return ggml_webgpu_cpy(ctx, src0, node); + case GGML_OP_SET: + return ggml_webgpu_set(ctx, src0, src1, node); case GGML_OP_SET_ROWS: return ggml_webgpu_set_rows(ctx, src0, src1, node); case GGML_OP_GET_ROWS: @@ -2219,6 +2490,12 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, case GGML_OP_DIAG: case GGML_OP_TRI: return ggml_webgpu_unary_op(ctx, src0, node); + case GGML_OP_SOLVE_TRI: + return ggml_webgpu_solve_tri(ctx, src0, src1, node); + case GGML_OP_SSM_CONV: + return ggml_webgpu_ssm_conv(ctx, src0, src1, node); + case GGML_OP_GATED_DELTA_NET: + return ggml_webgpu_gated_delta_net(ctx, src0, src1, src2, node->src[3], node->src[4], node->src[5], node); case GGML_OP_PAD: return ggml_webgpu_pad(ctx, src0, node); case GGML_OP_ARGMAX: @@ -2957,7 +3234,7 @@ static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggm /* .is_host = */ NULL, // defaults to false }, /* .device = */ - dev, + dev, /* .context = */ NULL }; @@ -3040,6 +3317,10 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) || (op->type == GGML_TYPE_I32 && src0->type == GGML_TYPE_F32); break; + case GGML_OP_SET: + supports_op = src0->type == src1->type && src0->type == op->type && + (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_I32); + break; case GGML_OP_SET_ROWS: supports_op = ((op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32) && src0->type == GGML_TYPE_F32 && (src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32)); @@ -3180,6 +3461,27 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const } } break; + case GGML_OP_TRI: + supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32; + break; + case GGML_OP_DIAG: + supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32; + break; + case GGML_OP_SOLVE_TRI: + supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32; + break; + case GGML_OP_SSM_CONV: + supports_op = op->type == GGML_TYPE_F32; + break; + case GGML_OP_GATED_DELTA_NET: + { + const uint32_t s_v = (uint32_t) src2->ne[0]; + supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && + src2->type == GGML_TYPE_F32 && op->src[3]->type == GGML_TYPE_F32 && + op->src[4]->type == GGML_TYPE_F32 && op->src[5]->type == GGML_TYPE_F32 && + s_v <= ctx->webgpu_global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup; + } + break; case GGML_OP_CLAMP: supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type); break; @@ -3201,12 +3503,6 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const case GGML_OP_COS: supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type); break; - case GGML_OP_DIAG: - supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type); - break; - case GGML_OP_TRI: - supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type); - break; case GGML_OP_PAD: supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32; break; diff --git a/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl b/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl new file mode 100644 index 00000000..f9d98fda --- /dev/null +++ b/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl @@ -0,0 +1,132 @@ +@group(0) @binding(0) +var src_q: array; + +@group(0) @binding(1) +var src_k: array; + +@group(0) @binding(2) +var src_v: array; + +@group(0) @binding(3) +var src_g: array; + +@group(0) @binding(4) +var src_beta: array; + +@group(0) @binding(5) +var src_state: array; + +@group(0) @binding(6) +var dst: array; + +struct Params { + h: u32, + n_tokens: u32, + n_seqs: u32, + s_off: u32, + + sq1: u32, + sq2: u32, + sq3: u32, + + sv1: u32, + sv2: u32, + sv3: u32, + + sb1: u32, + sb2: u32, + sb3: u32, + + neq1: u32, + rq3: u32, + scale: f32, +}; + +@group(0) @binding(7) +var params: Params; + +var sh_k: array; +var sh_q: array; +var sh_g: array; + +@compute @workgroup_size(WG_SIZE) +fn main( + @builtin(workgroup_id) workgroup_id: vec3, + @builtin(local_invocation_id) local_id: vec3 +) { + let head_id = workgroup_id.x; + let seq_id = workgroup_id.y; + let col = local_id.x; + + let iq1 = head_id % params.neq1; + let iq3 = seq_id / params.rq3; + + let state_size = S_V * S_V; + let state_base = (seq_id * params.h + head_id) * state_size; + + var state: array; + for (var i = 0u; i < S_V; i++) { + state[i] = src_state[state_base + col * S_V + i]; + } + + var attn_off = (seq_id * params.n_tokens * params.h + head_id) * S_V; + + for (var t = 0u; t < params.n_tokens; t++) { + let q_off = iq3 * params.sq3 + t * params.sq2 + iq1 * params.sq1; + let k_off = q_off; + let v_off = seq_id * params.sv3 + t * params.sv2 + head_id * params.sv1; + let gb_off = seq_id * params.sb3 + t * params.sb2 + head_id * params.sb1; + + sh_q[col] = src_q[q_off + col]; + sh_k[col] = src_k[k_off + col]; + +#ifdef KDA + let g_base = gb_off * S_V; + sh_g[col] = exp(src_g[g_base + col]); +#endif + + workgroupBarrier(); + + let v_val = src_v[v_off + col]; + let beta_val = src_beta[gb_off]; + + var kv_col = 0.0; + var delta_col = 0.0; + var attn_col = 0.0; + +#ifdef KDA + for (var i = 0u; i < S_V; i++) { + kv_col += (sh_g[i] * state[i]) * sh_k[i]; + } + + delta_col = (v_val - kv_col) * beta_val; + + for (var i = 0u; i < S_V; i++) { + state[i] = sh_g[i] * state[i] + sh_k[i] * delta_col; + attn_col += state[i] * sh_q[i]; + } +#else + let g_val = exp(src_g[gb_off]); + + for (var i = 0u; i < S_V; i++) { + kv_col += state[i] * sh_k[i]; + } + + delta_col = (v_val - g_val * kv_col) * beta_val; + + for (var i = 0u; i < S_V; i++) { + state[i] = g_val * state[i] + sh_k[i] * delta_col; + attn_col += state[i] * sh_q[i]; + } +#endif + + dst[attn_off + col] = attn_col * params.scale; + attn_off += S_V * params.h; + + workgroupBarrier(); + } + + for (var i = 0u; i < S_V; i++) { + dst[params.s_off + state_base + col * S_V + i] = state[i]; + } +} diff --git a/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl b/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl index b10800e3..d9eb6a35 100644 --- a/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl +++ b/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl @@ -640,6 +640,35 @@ var params: Params; @compute @workgroup_size(WG_SIZE) fn main(@builtin(global_invocation_id) gid: vec3) { +#ifdef FLOAT_PARALLEL + let blocks_per_row = params.ne0 / BLOCK_SIZE; + let row_count = params.n_rows * params.ne2 * params.ne3; + + if (gid.x >= blocks_per_row * row_count) { + return; + } + + let block_idx = gid.x % blocks_per_row; + var row_idx = gid.x / blocks_per_row; + let i_dst3 = row_idx / (params.ne2 * params.n_rows); + + row_idx = row_idx % (params.ne2 * params.n_rows); + let i_dst2 = row_idx / params.n_rows; + let i_dst1 = row_idx % params.n_rows; + + let i_idx2 = i_dst3 % params.idx2; + let i_idx1 = i_dst2 % params.idx1; + let i_idx0 = i_dst1; + + let i_idx = params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2; + + let idx_val = u32(idx[i_idx]); + + let i_src_row = params.offset_src + idx_val * params.stride_src1 + i_dst2 * params.stride_src2 + i_dst3 * params.stride_src3; + let i_dst_row = params.offset_dst + i_dst1 * params.stride_dst1 + i_dst2 * params.stride_dst2 + i_dst3 * params.stride_dst3; + + copy_elements(i_src_row, i_dst_row, block_idx); +#else if (gid.x >= params.n_rows * params.ne2 * params.ne3) { return; } @@ -664,5 +693,5 @@ fn main(@builtin(global_invocation_id) gid: vec3) { for (var i: u32 = 0; i < params.ne0/BLOCK_SIZE; i++) { copy_elements(i_src_row, i_dst_row, i); } +#endif } - diff --git a/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl b/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl index 77779449..bd8d32bd 100644 --- a/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl +++ b/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl @@ -81,11 +81,12 @@ fn main(@builtin(workgroup_id) wid: vec3, } sum = scratch[0]; -#ifdef OP_RMS_NORM +#ifdef RMS_NORM let scale = 1.0/sqrt(sum/f32(params.ne0) + params.eps); -#elif OP_L2_NORM +#elif defined(L2_NORM) let scale = 1.0/max(sqrt(sum), params.eps); #endif + col = lid.x; for (var j: u32 = 0; j < elems; j++) { if (col >= params.ne0) { diff --git a/src/ggml-webgpu/wgsl-shaders/set.wgsl b/src/ggml-webgpu/wgsl-shaders/set.wgsl new file mode 100644 index 00000000..0a7ae9bd --- /dev/null +++ b/src/ggml-webgpu/wgsl-shaders/set.wgsl @@ -0,0 +1,109 @@ +#ifdef TYPE_I32 +#define TYPE i32 +#else +#define TYPE f32 +#endif + +#ifndef INPLACE +@group(0) @binding(0) +var src0: array; +#define SRC1_BINDING 1 +#else +#define SRC1_BINDING 0 +#endif + +#define DST_BINDING SRC1_BINDING + 1 +#define PARAMS_BINDING SRC1_BINDING + 2 + +@group(0) @binding(SRC1_BINDING) +var src1: array; + +@group(0) @binding(DST_BINDING) +var dst: array; + +struct Params { + ne: u32, + offset_src0: u32, + offset_src1: u32, + offset_view: u32, + + stride_src10: u32, + stride_src11: u32, + stride_src12: u32, + stride_src13: u32, + + stride_dst10: u32, + stride_dst11: u32, + stride_dst12: u32, + stride_dst13: u32, + + src1_ne0: u32, + src1_ne1: u32, + src1_ne2: u32, + src1_ne3: u32, +}; + +@group(0) @binding(PARAMS_BINDING) +var params: Params; + +fn decode_src1_coords(idx: u32) -> vec4 { + var i = idx; + let plane = params.src1_ne2 * params.src1_ne1 * params.src1_ne0; + let i3 = i / plane; + i = i % plane; + let row = params.src1_ne1 * params.src1_ne0; + let i2 = i / row; + i = i % row; + let i1 = i / params.src1_ne0; + let i0 = i % params.src1_ne0; + return vec4(i0, i1, i2, i3); +} + +fn decode_view_coords(rel: u32) -> vec4 { + let i3 = rel / params.stride_dst13; + let rem3 = rel % params.stride_dst13; + let i2 = rem3 / params.stride_dst12; + let rem2 = rem3 % params.stride_dst12; + let i1 = rem2 / params.stride_dst11; + let i0 = rem2 % params.stride_dst11; + return vec4(i0, i1, i2, i3); +} + +fn view_rel_from_coords(coords: vec4) -> u32 { + return coords.x * params.stride_dst10 + coords.y * params.stride_dst11 + + coords.z * params.stride_dst12 + coords.w * params.stride_dst13; +} + +fn src1_idx_from_coords(coords: vec4) -> u32 { + return coords.x * params.stride_src10 + coords.y * params.stride_src11 + + coords.z * params.stride_src12 + coords.w * params.stride_src13; +} + +fn in_set_view(rel: u32, coords: vec4) -> bool { + return view_rel_from_coords(coords) == rel; +} + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(global_invocation_id) gid: vec3) { + if (gid.x >= params.ne) { + return; + } + +#ifdef INPLACE + let coords = decode_src1_coords(gid.x); + + let src1_idx = params.offset_src1 + src1_idx_from_coords(coords); + let dst_idx = params.offset_view + view_rel_from_coords(coords); + + dst[dst_idx] = src1[src1_idx]; +#else + let rel = select(params.ne, gid.x - params.offset_view, gid.x >= params.offset_view); + let coords = decode_view_coords(rel); + + if (rel < params.stride_dst13 * params.src1_ne3 && in_set_view(rel, coords)) { + dst[gid.x] = src1[params.offset_src1 + src1_idx_from_coords(coords)]; + } else { + dst[gid.x] = src0[params.offset_src0 + gid.x]; + } +#endif +} diff --git a/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl b/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl new file mode 100644 index 00000000..9d5d902c --- /dev/null +++ b/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl @@ -0,0 +1,121 @@ +@group(0) @binding(0) +var src0: array; + +@group(0) @binding(1) +var src1: array; + +@group(0) @binding(2) +var dst: array; + +struct Params { + offset_src0: u32, + offset_src1: u32, + offset_dst: u32, + + stride_src00: u32, + stride_src01: u32, + stride_src02: u32, + stride_src03: u32, + + stride_src10: u32, + stride_src11: u32, + stride_src12: u32, + stride_src13: u32, + + stride_dst0: u32, + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + k: u32, + ne2: u32, + ne3: u32, +}; + +@group(0) @binding(3) +var params: Params; + +var shA: array; +var shB: array; + +fn src0_idx(row: u32, col: u32, i2: u32, i3: u32) -> u32 { + return params.offset_src0 + + col * params.stride_src00 + + row * params.stride_src01 + + i2 * params.stride_src02 + + i3 * params.stride_src03; +} + +fn src1_idx(row: u32, col: u32, i2: u32, i3: u32) -> u32 { + return params.offset_src1 + + col * params.stride_src10 + + row * params.stride_src11 + + i2 * params.stride_src12 + + i3 * params.stride_src13; +} + +fn dst_idx(row: u32, col: u32, i2: u32, i3: u32) -> u32 { + return params.offset_dst + + col * params.stride_dst0 + + row * params.stride_dst1 + + i2 * params.stride_dst2 + + i3 * params.stride_dst3; +} + +@compute @workgroup_size(WG_SIZE) +fn main( + @builtin(workgroup_id) workgroup_id: vec3, + @builtin(local_invocation_id) local_id: vec3 +) { + let batch = workgroup_id.y; + let col = workgroup_id.x * WG_SIZE + local_id.x; + let i3 = batch / params.ne2; + let i2 = batch % params.ne2; + let active_lane = local_id.x < K_TILE; + let active_col = active_lane && col < params.k; + + var X: array; + + for (var row_base = 0u; row_base < N; row_base += BATCH_N) { + let cur_n = min(BATCH_N, N - row_base); + + for (var i = local_id.x; i < cur_n * N; i += WG_SIZE) { + let tile_row = i / N; + let tile_col = i % N; + shA[i] = src0[src0_idx(row_base + tile_row, tile_col, i2, i3)]; + } + + for (var i = local_id.x; i < cur_n * K_TILE; i += WG_SIZE) { + let tile_row = i / K_TILE; + let tile_col = i % K_TILE; + let global_col = workgroup_id.x * WG_SIZE + tile_col; + let sh_idx = tile_row * K_TILE + tile_col; + + if (global_col < params.k) { + shB[sh_idx] = src1[src1_idx(row_base + tile_row, global_col, i2, i3)]; + } else { + shB[sh_idx] = 0.0; + } + } + + workgroupBarrier(); + + if (active_col) { + for (var row_offset = 0u; row_offset < cur_n; row_offset++) { + let r = row_base + row_offset; + var b = shB[row_offset * K_TILE + local_id.x]; + let a_row = row_offset * N; + + for (var t = 0u; t < r; t++) { + b -= shA[a_row + t] * X[t]; + } + + let x = b / shA[a_row + r]; + X[r] = x; + dst[dst_idx(r, col, i2, i3)] = x; + } + } + + workgroupBarrier(); + } +} diff --git a/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl b/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl new file mode 100644 index 00000000..11511305 --- /dev/null +++ b/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl @@ -0,0 +1,65 @@ +@group(0) @binding(0) +var src0: array; + +@group(0) @binding(1) +var src1: array; + +@group(0) @binding(2) +var dst: array; + +struct Params { + offset_src0: u32, + offset_src1: u32, + offset_dst: u32, + + stride_src01: u32, + stride_src02: u32, + stride_src11: u32, + + stride_dst0: u32, + stride_dst1: u32, + stride_dst2: u32, + + nc: u32, + nr: u32, + n_t: u32, + n_s: u32, + token_tiles: u32, +}; + +@group(0) @binding(3) +var params: Params; + +@compute @workgroup_size(BLOCK_SIZE, TOKENS_PER_WG) +fn main(@builtin(global_invocation_id) gid: vec3) { + let i1 = gid.x; + let tile_y = gid.y / TOKENS_PER_WG; + let local_token = gid.y % TOKENS_PER_WG; + let i3 = tile_y / params.token_tiles; + let token_tile = tile_y % params.token_tiles; + let i2 = token_tile * TOKENS_PER_WG + local_token; + + if (i1 >= params.nr || i2 >= params.n_t || i3 >= params.n_s) { + return; + } + + let src0_base = params.offset_src0 + i3 * params.stride_src02 + i2 + i1 * params.stride_src01; + let src1_base = params.offset_src1 + i1 * params.stride_src11; + + var sum = 0.0; + +#ifdef VECTORIZED + sum = + src0[src0_base + 0u] * src1[src1_base + 0u] + + src0[src0_base + 1u] * src1[src1_base + 1u] + + src0[src0_base + 2u] * src1[src1_base + 2u] + + src0[src0_base + 3u] * src1[src1_base + 3u]; +#else + for (var i0 = 0u; i0 < params.nc; i0++) { + sum += src0[src0_base + i0] * src1[src1_base + i0]; + } +#endif + + let dst_idx = params.offset_dst + i3 * params.stride_dst2 + i2 * params.stride_dst1 + i1 * params.stride_dst0; + dst[dst_idx] = sum; +}