uint32_t wg_size = 0;
};
+// Compile-time tiling choices for the SSM_CONV shader, stashed in
+// webgpu_pipeline::context so the dispatch code can size the grid.
+struct ggml_webgpu_ssm_conv_shader_decisions {
+ uint32_t block_size;    // invocations per workgroup along the channel (x) dimension
+ uint32_t tokens_per_wg; // tokens processed by each workgroup
+};
+
/** Argsort **/
struct ggml_webgpu_argsort_shader_lib_context {
uint32_t wg_size;
};
+/** Set **/
+
+// Cache key for GGML_OP_SET pipelines: destination element type plus whether
+// the op writes in place (dst aliases src0, so src0 is not bound).
+struct ggml_webgpu_set_pipeline_key {
+ ggml_type type;
+ bool inplace;
+
+ bool operator==(const ggml_webgpu_set_pipeline_key & other) const {
+ return type == other.type && inplace == other.inplace;
+ }
+};
+
+// Hash functor so ggml_webgpu_set_pipeline_key can be used in an unordered_map.
+struct ggml_webgpu_set_pipeline_key_hash {
+ size_t operator()(const ggml_webgpu_set_pipeline_key & key) const {
+ size_t seed = 0;
+ ggml_webgpu_hash_combine(seed, key.type);
+ ggml_webgpu_hash_combine(seed, key.inplace);
+ return seed;
+ }
+};
+
/** Get Rows **/
struct ggml_webgpu_get_rows_pipeline_key {
}
};
+/** Solve Tri **/
+// Cache key for SOLVE_TRI pipelines. n and k are baked into the shader as
+// constants (see get_solve_tri_pipeline), so each problem shape gets its own
+// specialized pipeline: n = src0->ne[0] (triangular system size),
+// k = src1->ne[0] (number of right-hand-side columns).
+struct ggml_webgpu_solve_tri_pipeline_key {
+ int type;
+ int n;
+ int k;
+
+ bool operator==(const ggml_webgpu_solve_tri_pipeline_key & other) const {
+ return type == other.type && n == other.n && k == other.k;
+ }
+};
+
+// Hash functor for ggml_webgpu_solve_tri_pipeline_key.
+struct ggml_webgpu_solve_tri_pipeline_key_hash {
+ size_t operator()(const ggml_webgpu_solve_tri_pipeline_key & key) const {
+ size_t seed = 0;
+ ggml_webgpu_hash_combine(seed, key.type);
+ ggml_webgpu_hash_combine(seed, key.n);
+ ggml_webgpu_hash_combine(seed, key.k);
+ return seed;
+ }
+};
+
+/** SSM Conv **/
+// Cache key for SSM_CONV pipelines. `vectorized` is set when the convolution
+// kernel width (src1->ne[0]) is exactly 4, enabling the vec4 shader variant.
+struct ggml_webgpu_ssm_conv_pipeline_key {
+ int type;
+ int vectorized; // 0/1 flag (stored as int for hashing)
+
+ bool operator==(const ggml_webgpu_ssm_conv_pipeline_key & other) const {
+ return type == other.type && vectorized == other.vectorized;
+ }
+};
+
+/** Gated Delta Net **/
+// Cache key for GATED_DELTA_NET pipelines. s_v is the head/state dimension
+// (src2->ne[0]) baked into the shader as S_V; kda selects the per-channel
+// gate variant (set when src3->ne[0] == src2->ne[0]).
+struct ggml_webgpu_gated_delta_net_pipeline_key {
+ int type;
+ int s_v;
+ int kda; // 0/1 flag (stored as int for hashing)
+
+ bool operator==(const ggml_webgpu_gated_delta_net_pipeline_key & other) const {
+ return type == other.type && s_v == other.s_v && kda == other.kda;
+ }
+};
+
+// Hash functor for ggml_webgpu_gated_delta_net_pipeline_key.
+struct ggml_webgpu_gated_delta_net_pipeline_key_hash {
+ size_t operator()(const ggml_webgpu_gated_delta_net_pipeline_key & key) const {
+ size_t seed = 0;
+ ggml_webgpu_hash_combine(seed, key.type);
+ ggml_webgpu_hash_combine(seed, key.s_v);
+ ggml_webgpu_hash_combine(seed, key.kda);
+ return seed;
+ }
+};
+
+// Hash functor for ggml_webgpu_ssm_conv_pipeline_key.
+struct ggml_webgpu_ssm_conv_pipeline_key_hash {
+ size_t operator()(const ggml_webgpu_ssm_conv_pipeline_key & key) const {
+ size_t seed = 0;
+ ggml_webgpu_hash_combine(seed, key.type);
+ ggml_webgpu_hash_combine(seed, key.vectorized);
+ return seed;
+ }
+};
+
/** Scale **/
struct ggml_webgpu_scale_pipeline_key {
unary_pipelines; // type/op/inplace
std::unordered_map<ggml_webgpu_scale_pipeline_key, webgpu_pipeline, ggml_webgpu_scale_pipeline_key_hash>
scale_pipelines; // inplace
+ std::unordered_map<ggml_webgpu_solve_tri_pipeline_key, webgpu_pipeline, ggml_webgpu_solve_tri_pipeline_key_hash>
+ solve_tri_pipelines; // type
+ std::unordered_map<ggml_webgpu_ssm_conv_pipeline_key, webgpu_pipeline, ggml_webgpu_ssm_conv_pipeline_key_hash>
+ ssm_conv_pipelines; // type/vectorized
+ std::unordered_map<ggml_webgpu_gated_delta_net_pipeline_key,
+ webgpu_pipeline,
+ ggml_webgpu_gated_delta_net_pipeline_key_hash>
+ gated_delta_net_pipelines; // type/S_v/kda
std::unordered_map<ggml_webgpu_pad_pipeline_key, webgpu_pipeline, ggml_webgpu_pad_pipeline_key_hash>
- pad_pipelines; // circular/non-circular
+ pad_pipelines; // circular/non-circular
std::unordered_map<ggml_webgpu_binary_pipeline_key, webgpu_pipeline, ggml_webgpu_binary_pipeline_key_hash>
- binary_pipelines; // type/op/inplace/overlap
+ binary_pipelines; // type/op/inplace/overlap
std::unordered_map<ggml_webgpu_concat_pipeline_key, webgpu_pipeline, ggml_webgpu_concat_pipeline_key_hash>
- concat_pipelines; // type
+ concat_pipelines; // type
std::unordered_map<ggml_webgpu_repeat_pipeline_key, webgpu_pipeline, ggml_webgpu_repeat_pipeline_key_hash>
- repeat_pipelines; // type
+ repeat_pipelines; // type
std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash>
flash_attn_pipelines;
std::unordered_map<ggml_webgpu_legacy_mul_mat_pipeline_key,
std::unordered_map<ggml_webgpu_set_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_set_rows_pipeline_key_hash>
set_rows_pipelines;
+ std::unordered_map<ggml_webgpu_set_pipeline_key, webgpu_pipeline, ggml_webgpu_set_pipeline_key_hash> set_pipelines;
public:
ggml_webgpu_shader_lib(wgpu::Device device) { this->device = device; }
switch (key.op) {
case GGML_OP_RMS_NORM:
- defines.push_back("OP_RMS_NORM");
+ defines.push_back("RMS_NORM");
variant = "rms_norm";
break;
case GGML_OP_L2_NORM:
- defines.push_back("OP_L2_NORM");
+ defines.push_back("L2_NORM");
variant = "l2_norm";
break;
default:
variant += "_inplace";
}
- defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
-
+ const uint32_t row_norm_wg_size = 128u;
+ uint32_t wg_size = std::min(context.max_wg_size, row_norm_wg_size);
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
auto processed = preprocessor.preprocess(wgsl_row_norm, defines);
row_norm_pipelines[key] = ggml_webgpu_create_pipeline(device, processed, variant);
return row_norm_pipelines[key];
return set_rows_pipelines[key];
}
+ // Build (or fetch from cache) the compute pipeline for GGML_OP_SET.
+ // Variants: set_f32 / set_i32, each optionally "_inplace" when dst aliases
+ // src0 (the INPLACE define drops the src0 binding in the shader).
+ webgpu_pipeline get_set_pipeline(const ggml_webgpu_shader_lib_context & context) {
+ ggml_webgpu_set_pipeline_key key = { .type = context.dst->type, .inplace = context.inplace };
+
+ auto it = set_pipelines.find(key);
+ if (it != set_pipelines.end()) {
+ return it->second;
+ }
+
+ std::vector<std::string> defines;
+ std::string variant = "set";
+
+ switch (key.type) {
+ case GGML_TYPE_F32:
+ defines.push_back("TYPE_F32");
+ variant += "_f32";
+ break;
+ case GGML_TYPE_I32:
+ defines.push_back("TYPE_I32");
+ variant += "_i32";
+ break;
+ default:
+ GGML_ABORT("Unsupported type for set shader");
+ }
+
+ if (key.inplace) {
+ defines.push_back("INPLACE");
+ variant += "_inplace";
+ }
+
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
+
+ auto processed = preprocessor.preprocess(wgsl_set, defines);
+ // Record the chosen workgroup size so dispatch can compute the grid size.
+ auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
+ decisions->wg_size = context.max_wg_size;
+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
+ pipeline.context = decisions;
+ set_pipelines[key] = pipeline;
+ return set_pipelines[key];
+ }
+
webgpu_pipeline get_cumsum_pipeline(const ggml_webgpu_shader_lib_context & context) {
auto it = cumsum_pipelines.find(1);
if (it != cumsum_pipelines.end()) {
switch (key.src_type) {
case GGML_TYPE_F32:
+ defines.push_back("FLOAT_PARALLEL");
if (key.vectorized) {
defines.push_back("F32_VEC");
defines.push_back("SRC_TYPE=vec4<f32>");
variant += "_f32";
break;
case GGML_TYPE_F16:
+ defines.push_back("FLOAT_PARALLEL");
defines.push_back("F16");
defines.push_back("SRC_TYPE=f16");
defines.push_back("DST_TYPE=f32");
variant += "_f16";
break;
case GGML_TYPE_I32:
+ defines.push_back("FLOAT_PARALLEL");
defines.push_back("I32");
defines.push_back("SRC_TYPE=i32");
defines.push_back("DST_TYPE=i32");
return scale_pipelines[key];
}
+ // Build (or fetch from cache) the compute pipeline for GGML_OP_SOLVE_TRI.
+ // The system size N, workgroup size, K tile width and batch depth are baked
+ // into the shader as defines, so pipelines are cached per (type, n, k).
+ webgpu_pipeline get_solve_tri_pipeline(const ggml_webgpu_shader_lib_context & context) {
+ ggml_webgpu_solve_tri_pipeline_key key = {
+ .type = context.dst->type,
+ .n = (int) context.src0->ne[0],
+ .k = (int) context.src1->ne[0],
+ };
+
+ auto it = solve_tri_pipelines.find(key);
+ if (it != solve_tri_pipelines.end()) {
+ return it->second;
+ }
+
+ std::vector<std::string> defines;
+ std::string variant = "solve_tri";
+
+ switch (key.type) {
+ case GGML_TYPE_F32:
+ variant += "_f32";
+ break;
+ default:
+ GGML_ABORT("Unsupported type for solve_tri shader");
+ }
+
+ // One invocation per RHS column, capped by the device limit; K_TILE tracks it.
+ const uint32_t wg_size = std::min((uint32_t) key.n, context.max_wg_size);
+ const uint32_t k_tile = wg_size;
+ // Workgroup memory per batched row: N matrix elements + wg_size staging
+ // floats. batch_n = how many rows fit in the workgroup-storage budget
+ // (wg_mem_limit_bytes = maxComputeWorkgroupStorageSize).
+ const uint32_t bytes_per_row = ((uint32_t) key.n + wg_size) * GGML_WEBGPU_F32_SIZE_BYTES;
+ const uint32_t batch_n = (uint32_t) (context.wg_mem_limit_bytes / bytes_per_row);
+
+ defines.push_back(std::string("N=") + std::to_string(key.n));
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
+ defines.push_back(std::string("K_TILE=") + std::to_string(k_tile));
+ defines.push_back(std::string("BATCH_N=") + std::to_string(batch_n));
+
+ auto processed = preprocessor.preprocess(wgsl_solve_tri, defines);
+ auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
+ decisions->wg_size = wg_size;
+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
+ pipeline.context = decisions;
+ solve_tri_pipelines[key] = pipeline;
+ return solve_tri_pipelines[key];
+ }
+
+ // Build (or fetch from cache) the compute pipeline for GGML_OP_SSM_CONV.
+ // The vec4 variant is selected when the kernel width (src1->ne[0]) is 4.
+ webgpu_pipeline get_ssm_conv_pipeline(const ggml_webgpu_shader_lib_context & context) {
+ ggml_webgpu_ssm_conv_pipeline_key key = {
+ .type = context.dst->type,
+ .vectorized = context.src1->ne[0] == 4,
+ };
+
+ auto it = ssm_conv_pipelines.find(key);
+ if (it != ssm_conv_pipelines.end()) {
+ return it->second;
+ }
+
+ std::vector<std::string> defines;
+ std::string variant = "ssm_conv";
+
+ switch (key.type) {
+ case GGML_TYPE_F32:
+ variant += "_f32";
+ break;
+ default:
+ GGML_ABORT("Unsupported type for ssm_conv shader");
+ }
+
+ if (key.vectorized) {
+ defines.push_back("VECTORIZED");
+ variant += "_vec4";
+ }
+
+ // Fixed tiling: 32 channels per workgroup, 8 tokens per workgroup.
+ constexpr uint32_t block_size = 32u;
+ constexpr uint32_t tokens_per_wg = 8u;
+
+ defines.push_back("BLOCK_SIZE=" + std::to_string(block_size) + "u");
+ defines.push_back("TOKENS_PER_WG=" + std::to_string(tokens_per_wg) + "u");
+
+ auto processed = preprocessor.preprocess(wgsl_ssm_conv, defines);
+ // Keep the tiling choices so dispatch can size the grid consistently.
+ auto decisions = std::make_shared<ggml_webgpu_ssm_conv_shader_decisions>();
+ decisions->block_size = block_size;
+ decisions->tokens_per_wg = tokens_per_wg;
+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
+ pipeline.context = decisions;
+ ssm_conv_pipelines[key] = pipeline;
+ return ssm_conv_pipelines[key];
+ }
+
+ // Build (or fetch from cache) the compute pipeline for GGML_OP_GATED_DELTA_NET.
+ // The workgroup size equals S_V (one invocation per state column), so the
+ // caller must ensure s_v <= maxComputeInvocationsPerWorkgroup (checked in
+ // supports_op). No decisions context is attached: dispatch derives the grid
+ // directly from tensor shapes.
+ webgpu_pipeline get_gated_delta_net_pipeline(const ggml_webgpu_shader_lib_context & context) {
+ ggml_webgpu_gated_delta_net_pipeline_key key = {
+ .type = context.dst->type,
+ .s_v = (int) context.src2->ne[0],
+ .kda = context.src3->ne[0] == context.src2->ne[0],
+ };
+
+ auto it = gated_delta_net_pipelines.find(key);
+ if (it != gated_delta_net_pipelines.end()) {
+ return it->second;
+ }
+
+ std::vector<std::string> defines;
+ std::string variant = "gated_delta_net";
+
+ switch (key.type) {
+ case GGML_TYPE_F32:
+ variant += "_f32";
+ break;
+ default:
+ GGML_ABORT("Unsupported type for gated_delta_net shader");
+ }
+
+ if (key.kda) {
+ defines.push_back("KDA");
+ variant += "_kda";
+ }
+
+ defines.push_back("S_V=" + std::to_string(key.s_v) + "u");
+ defines.push_back("WG_SIZE=" + std::to_string(key.s_v) + "u");
+
+ auto processed = preprocessor.preprocess(wgsl_gated_delta_net, defines);
+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
+ gated_delta_net_pipelines[key] = pipeline;
+ return gated_delta_net_pipelines[key];
+ }
+
webgpu_pipeline get_pad_pipeline(const ggml_webgpu_shader_lib_context & context) {
ggml_webgpu_pad_pipeline_key key = { .circular = ggml_get_op_params_i32(context.dst, 8) != 0 };
params, entries, wg_x);
}
+// Encode a GGML_OP_SET dispatch: write src1 into a strided view of dst
+// (view strides/offset come from dst->op_params). In the in-place case dst
+// aliases src0, so only the src1 elements need writing; otherwise the whole
+// dst is produced and src0 is bound as an extra read source.
+static webgpu_command ggml_webgpu_set(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) {
+ const bool inplace = ggml_webgpu_tensor_equal(src0, dst);
+
+ ggml_webgpu_shader_lib_context shader_lib_ctx = {
+ .src0 = src0,
+ .src1 = src1,
+ .dst = dst,
+ .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
+ .inplace = inplace,
+ };
+
+ webgpu_pipeline pipeline = ctx->shader_lib->get_set_pipeline(shader_lib_ctx);
+
+ auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
+
+ // Thread count: only the view region (src1) when in place, else all of dst.
+ const uint32_t ne = inplace ? (uint32_t) ggml_nelements(src1) : (uint32_t) ggml_nelements(dst);
+ const uint32_t dst_type_size = (uint32_t) ggml_type_size(dst->type);
+
+ // SET op_params convention: [0..2] = view strides nb1..nb3 (bytes),
+ // [3] = view byte offset; all converted to element counts here.
+ std::vector<uint32_t> params = {
+ ne,
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
+ (uint32_t) (((const int32_t *) dst->op_params)[3] / dst_type_size),
+
+ (uint32_t) (src1->nb[0] / ggml_type_size(src1->type)),
+ (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
+ (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
+ (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
+
+ 1u, // view stride along dim 0 is one element
+ (uint32_t) (((const int32_t *) dst->op_params)[0] / dst_type_size),
+ (uint32_t) (((const int32_t *) dst->op_params)[1] / dst_type_size),
+ (uint32_t) (((const int32_t *) dst->op_params)[2] / dst_type_size),
+
+ (uint32_t) src1->ne[0],
+ (uint32_t) src1->ne[1],
+ (uint32_t) src1->ne[2],
+ (uint32_t) src1->ne[3],
+ };
+
+ // Binding order must mirror the shader: [src0,] src1, dst, params —
+ // src0 is skipped for the INPLACE variant.
+ std::vector<wgpu::BindGroupEntry> entries;
+ uint32_t binding_index = 0;
+ if (!inplace) {
+ entries.push_back({ .binding = 0,
+ .buffer = ggml_webgpu_tensor_buf(src0),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src0),
+ .size = ggml_webgpu_tensor_binding_size(ctx, src0) });
+ binding_index++;
+ }
+ entries.push_back({ .binding = binding_index,
+ .buffer = ggml_webgpu_tensor_buf(src1),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src1),
+ .size = ggml_webgpu_tensor_binding_size(ctx, src1) });
+ entries.push_back({ .binding = binding_index + 1,
+ .buffer = ggml_webgpu_tensor_buf(dst),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
+
+ uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
+}
+
static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
ggml_webgpu_shader_lib_context shader_lib_ctx = {
.src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
}
+// Encode a GGML_OP_SOLVE_TRI dispatch. Grid: x covers the RHS columns
+// (src1->ne[0]) in workgroup-sized chunks, y covers the dst batch planes
+// (ne[2] * ne[3]). Shape constants (N, K_TILE, BATCH_N) were baked into the
+// shader by get_solve_tri_pipeline.
+static webgpu_command ggml_webgpu_solve_tri(webgpu_context & ctx,
+ ggml_tensor * src0,
+ ggml_tensor * src1,
+ ggml_tensor * dst) {
+ ggml_webgpu_shader_lib_context shader_lib_ctx = {
+ .src0 = src0,
+ .src1 = src1,
+ .dst = dst,
+ .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
+ .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize,
+ };
+
+ webgpu_pipeline pipeline = ctx->shader_lib->get_solve_tri_pipeline(shader_lib_ctx);
+
+ auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
+
+ // Offsets and strides are passed in element units (bytes / type size).
+ std::vector<uint32_t> params = {
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
+
+ (uint32_t) (src0->nb[0] / ggml_type_size(src0->type)),
+ (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
+ (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
+ (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
+
+ (uint32_t) (src1->nb[0] / ggml_type_size(src1->type)),
+ (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
+ (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)),
+ (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)),
+
+ (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)),
+ (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
+ (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
+ (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)),
+
+ (uint32_t) src1->ne[0],
+ (uint32_t) dst->ne[2],
+ (uint32_t) dst->ne[3],
+ };
+
+ std::vector<wgpu::BindGroupEntry> entries = {
+ { .binding = 0,
+ .buffer = ggml_webgpu_tensor_buf(src0),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src0),
+ .size = ggml_webgpu_tensor_binding_size(ctx, src0) },
+ { .binding = 1,
+ .buffer = ggml_webgpu_tensor_buf(src1),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src1),
+ .size = ggml_webgpu_tensor_binding_size(ctx, src1) },
+ { .binding = 2,
+ .buffer = ggml_webgpu_tensor_buf(dst),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) }
+ };
+
+ const uint32_t wg_x = CEIL_DIV((uint32_t) src1->ne[0], decisions->wg_size);
+ const uint32_t wg_y = (uint32_t) (dst->ne[2] * dst->ne[3]);
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y);
+}
+
+// Encode a GGML_OP_SSM_CONV dispatch. Grid: x tiles the channels
+// (src0->ne[1]) by block_size; y packs (token tile, sequence) as
+// token_tiles * dst->ne[2] — token_tiles is also passed in params so the
+// shader can split workgroup_id.y back apart (assumed; shader not in view).
+static webgpu_command ggml_webgpu_ssm_conv(webgpu_context & ctx,
+ ggml_tensor * src0,
+ ggml_tensor * src1,
+ ggml_tensor * dst) {
+ ggml_webgpu_shader_lib_context shader_lib_ctx = {
+ .src0 = src0,
+ .src1 = src1,
+ .dst = dst,
+ .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
+ };
+
+ webgpu_pipeline pipeline = ctx->shader_lib->get_ssm_conv_pipeline(shader_lib_ctx);
+ auto * decisions = static_cast<ggml_webgpu_ssm_conv_shader_decisions *>(pipeline.context.get());
+
+ const uint32_t token_tiles = CEIL_DIV((uint32_t) dst->ne[1], decisions->tokens_per_wg);
+
+ // Offsets and strides are passed in element units (bytes / type size).
+ std::vector<uint32_t> params = {
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
+
+ (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
+ (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
+ (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)),
+
+ (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)),
+ (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)),
+ (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)),
+
+ (uint32_t) src1->ne[0],
+ (uint32_t) src0->ne[1],
+ (uint32_t) dst->ne[1],
+ (uint32_t) dst->ne[2],
+ token_tiles,
+ };
+
+ std::vector<wgpu::BindGroupEntry> entries = {
+ { .binding = 0,
+ .buffer = ggml_webgpu_tensor_buf(src0),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src0),
+ .size = ggml_webgpu_tensor_binding_size(ctx, src0) },
+ { .binding = 1,
+ .buffer = ggml_webgpu_tensor_buf(src1),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src1),
+ .size = ggml_webgpu_tensor_binding_size(ctx, src1) },
+ { .binding = 2,
+ .buffer = ggml_webgpu_tensor_buf(dst),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) }
+ };
+
+ const uint32_t wg_x = CEIL_DIV((uint32_t) src0->ne[1], decisions->block_size);
+ const uint32_t wg_y = token_tiles * (uint32_t) dst->ne[2];
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y);
+}
+
+// Encode a GGML_OP_GATED_DELTA_NET dispatch. Grid: one workgroup per
+// (head, sequence) pair — (h, n_seqs) — with the workgroup size fixed to
+// S_V inside the shader (one invocation per state column).
+// Inputs: src0=q, src1=k, src2=v, src3=g, src4=beta, src5=initial state;
+// dst holds the attention output followed by the final state at s_off.
+static webgpu_command ggml_webgpu_gated_delta_net(webgpu_context & ctx,
+ ggml_tensor * src0,
+ ggml_tensor * src1,
+ ggml_tensor * src2,
+ ggml_tensor * src3,
+ ggml_tensor * src4,
+ ggml_tensor * src5,
+ ggml_tensor * dst) {
+ ggml_webgpu_shader_lib_context shader_lib_ctx = {
+ .src0 = src0,
+ .src1 = src1,
+ .src2 = src2,
+ .src3 = src3,
+ .src4 = src4,
+ .dst = dst,
+ .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
+ };
+
+ webgpu_pipeline pipeline = ctx->shader_lib->get_gated_delta_net_pipeline(shader_lib_ctx);
+
+ const uint32_t s_v = (uint32_t) src2->ne[0];
+ const uint32_t h = (uint32_t) src2->ne[1];
+ const uint32_t n_tokens = (uint32_t) src2->ne[2];
+ const uint32_t n_seqs = (uint32_t) src2->ne[3];
+ // 1/sqrt(S_v) attention scale, bit-copied into the u32 uniform slot
+ // (the shader reinterprets it as f32).
+ const float scale = 1.0f / sqrtf((float) s_v);
+ uint32_t scale_u32;
+ memcpy(&scale_u32, &scale, sizeof(scale_u32));
+
+ std::vector<uint32_t> params = {
+ h,
+ n_tokens,
+ n_seqs,
+ s_v * h * n_tokens * n_seqs, // s_off: element offset of the state region in dst
+
+ (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)),
+ (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)),
+ (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)),
+
+ (uint32_t) (src2->nb[1] / ggml_type_size(src2->type)),
+ (uint32_t) (src2->nb[2] / ggml_type_size(src2->type)),
+ (uint32_t) (src2->nb[3] / ggml_type_size(src2->type)),
+
+ (uint32_t) (src4->nb[1] / ggml_type_size(src4->type)),
+ (uint32_t) (src4->nb[2] / ggml_type_size(src4->type)),
+ (uint32_t) (src4->nb[3] / ggml_type_size(src4->type)),
+
+ (uint32_t) src0->ne[1],
+ (uint32_t) (src2->ne[3] / src0->ne[3]), // rq3: v-to-q broadcast ratio along dim 3
+ scale_u32,
+ };
+
+ std::vector<wgpu::BindGroupEntry> entries = {
+ { .binding = 0,
+ .buffer = ggml_webgpu_tensor_buf(src0),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src0),
+ .size = ggml_webgpu_tensor_binding_size(ctx, src0) },
+ { .binding = 1,
+ .buffer = ggml_webgpu_tensor_buf(src1),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src1),
+ .size = ggml_webgpu_tensor_binding_size(ctx, src1) },
+ { .binding = 2,
+ .buffer = ggml_webgpu_tensor_buf(src2),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src2),
+ .size = ggml_webgpu_tensor_binding_size(ctx, src2) },
+ { .binding = 3,
+ .buffer = ggml_webgpu_tensor_buf(src3),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src3),
+ .size = ggml_webgpu_tensor_binding_size(ctx, src3) },
+ { .binding = 4,
+ .buffer = ggml_webgpu_tensor_buf(src4),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src4),
+ .size = ggml_webgpu_tensor_binding_size(ctx, src4) },
+ { .binding = 5,
+ .buffer = ggml_webgpu_tensor_buf(src5),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, src5),
+ .size = ggml_webgpu_tensor_binding_size(ctx, src5) },
+ { .binding = 6,
+ .buffer = ggml_webgpu_tensor_buf(dst),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) }
+ };
+
+ return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, h, n_seqs);
+}
+
static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
ggml_tensor * src,
ggml_tensor * idx,
ggml_tensor * src,
ggml_tensor * idx,
ggml_tensor * dst) {
+ const bool float_parallel = src->type == GGML_TYPE_F32 || src->type == GGML_TYPE_F16 || src->type == GGML_TYPE_I32;
+
ggml_webgpu_shader_lib_context shader_lib_ctx = {
.src0 = src,
.src1 = nullptr,
.size = ggml_webgpu_tensor_binding_size(ctx, dst) }
};
- uint32_t wg_x = CEIL_DIV(dst->ne[1] * dst->ne[2] * dst->ne[3], decisions->wg_size);
+ uint32_t blocks_per_row = (uint32_t) (dst->ne[0] / (src->type == GGML_TYPE_F32 && dst->ne[0] % 4 == 0 ? 4 : 1));
+ uint32_t total_rows = (uint32_t) (dst->ne[1] * dst->ne[2] * dst->ne[3]);
+ uint32_t total_threads = float_parallel ? blocks_per_row * total_rows : total_rows;
+ uint32_t wg_x = CEIL_DIV(total_threads, decisions->wg_size);
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
}
ggml_webgpu_shader_lib_context shader_lib_ctx = {
.src0 = src,
.dst = dst,
- .max_wg_size = WEBGPU_ROW_SPLIT_WG_SIZE,
+ .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
.inplace = inplace,
};
case GGML_OP_CPY:
case GGML_OP_CONT:
return ggml_webgpu_cpy(ctx, src0, node);
+ case GGML_OP_SET:
+ return ggml_webgpu_set(ctx, src0, src1, node);
case GGML_OP_SET_ROWS:
return ggml_webgpu_set_rows(ctx, src0, src1, node);
case GGML_OP_GET_ROWS:
case GGML_OP_DIAG:
case GGML_OP_TRI:
return ggml_webgpu_unary_op(ctx, src0, node);
+ case GGML_OP_SOLVE_TRI:
+ return ggml_webgpu_solve_tri(ctx, src0, src1, node);
+ case GGML_OP_SSM_CONV:
+ return ggml_webgpu_ssm_conv(ctx, src0, src1, node);
+ case GGML_OP_GATED_DELTA_NET:
+ return ggml_webgpu_gated_delta_net(ctx, src0, src1, src2, node->src[3], node->src[4], node->src[5], node);
case GGML_OP_PAD:
return ggml_webgpu_pad(ctx, src0, node);
case GGML_OP_ARGMAX:
/* .is_host = */ NULL, // defaults to false
},
/* .device = */
- dev,
+ dev,
/* .context = */ NULL
};
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) ||
(op->type == GGML_TYPE_I32 && src0->type == GGML_TYPE_F32);
break;
+ case GGML_OP_SET:
+ supports_op = src0->type == src1->type && src0->type == op->type &&
+ (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_I32);
+ break;
case GGML_OP_SET_ROWS:
supports_op = ((op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32) && src0->type == GGML_TYPE_F32 &&
(src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32));
}
}
break;
+ case GGML_OP_TRI:
+ supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
+ break;
+ case GGML_OP_DIAG:
+ supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
+ break;
+ case GGML_OP_SOLVE_TRI:
+ supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32;
+ break;
+ case GGML_OP_SSM_CONV:
+ supports_op = op->type == GGML_TYPE_F32;
+ break;
+ case GGML_OP_GATED_DELTA_NET:
+ {
+ const uint32_t s_v = (uint32_t) src2->ne[0];
+ supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 &&
+ src2->type == GGML_TYPE_F32 && op->src[3]->type == GGML_TYPE_F32 &&
+ op->src[4]->type == GGML_TYPE_F32 && op->src[5]->type == GGML_TYPE_F32 &&
+ s_v <= ctx->webgpu_global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup;
+ }
+ break;
case GGML_OP_CLAMP:
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
break;
case GGML_OP_COS:
supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
break;
- case GGML_OP_DIAG:
- supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
- break;
- case GGML_OP_TRI:
- supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type);
- break;
case GGML_OP_PAD:
supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
break;
--- /dev/null
+// GATED_DELTA_NET shader. One workgroup handles one (head, sequence) pair
+// (workgroup_id.x = head, .y = sequence); each invocation owns one state
+// column `col` in [0, S_V). The recurrent state column is kept in a
+// per-invocation register array across all tokens and written back into dst
+// (at params.s_off) after the last token. KDA variant: per-channel decay
+// gate g[S_V]; default variant: one scalar gate per (token, head).
+@group(0) @binding(0)
+var<storage, read_write> src_q: array<f32>;
+
+@group(0) @binding(1)
+var<storage, read_write> src_k: array<f32>;
+
+@group(0) @binding(2)
+var<storage, read_write> src_v: array<f32>;
+
+@group(0) @binding(3)
+var<storage, read_write> src_g: array<f32>;
+
+@group(0) @binding(4)
+var<storage, read_write> src_beta: array<f32>;
+
+@group(0) @binding(5)
+var<storage, read_write> src_state: array<f32>;
+
+@group(0) @binding(6)
+var<storage, read_write> dst: array<f32>;
+
+// Uniform layout mirrors the params vector built in ggml_webgpu_gated_delta_net.
+struct Params {
+ h: u32,        // number of heads
+ n_tokens: u32, // tokens per sequence
+ n_seqs: u32,   // number of sequences
+ s_off: u32,    // element offset of the final-state region in dst
+
+ sq1: u32, // q strides (elements)
+ sq2: u32,
+ sq3: u32,
+
+ sv1: u32, // v strides (elements)
+ sv2: u32,
+ sv3: u32,
+
+ sb1: u32, // g/beta strides (elements)
+ sb2: u32,
+ sb3: u32,
+
+ neq1: u32, // q head count (for head broadcast via modulo)
+ rq3: u32,  // v-to-q broadcast ratio along dim 3
+ scale: f32, // 1/sqrt(S_v) output scale
+};
+
+@group(0) @binding(7)
+var<uniform> params: Params;
+
+// Per-token tiles shared by the whole workgroup.
+var<workgroup> sh_k: array<f32, S_V>;
+var<workgroup> sh_q: array<f32, S_V>;
+var<workgroup> sh_g: array<f32, S_V>;
+
+@compute @workgroup_size(WG_SIZE)
+fn main(
+ @builtin(workgroup_id) workgroup_id: vec3<u32>,
+ @builtin(local_invocation_id) local_id: vec3<u32>
+) {
+ let head_id = workgroup_id.x;
+ let seq_id = workgroup_id.y;
+ let col = local_id.x;
+
+ let iq1 = head_id % params.neq1;
+ let iq3 = seq_id / params.rq3;
+
+ let state_size = S_V * S_V;
+ let state_base = (seq_id * params.h + head_id) * state_size;
+
+ // Load this invocation's state column into registers.
+ var state: array<f32, S_V>;
+ for (var i = 0u; i < S_V; i++) {
+ state[i] = src_state[state_base + col * S_V + i];
+ }
+
+ var attn_off = (seq_id * params.n_tokens * params.h + head_id) * S_V;
+
+ // Sequential recurrence over tokens.
+ for (var t = 0u; t < params.n_tokens; t++) {
+ let q_off = iq3 * params.sq3 + t * params.sq2 + iq1 * params.sq1;
+ let k_off = q_off;
+ let v_off = seq_id * params.sv3 + t * params.sv2 + head_id * params.sv1;
+ let gb_off = seq_id * params.sb3 + t * params.sb2 + head_id * params.sb1;
+
+ // Each invocation stages one element of this token's q/k (and g for KDA).
+ sh_q[col] = src_q[q_off + col];
+ sh_k[col] = src_k[k_off + col];
+
+#ifdef KDA
+ let g_base = gb_off * S_V;
+ sh_g[col] = exp(src_g[g_base + col]);
+#endif
+
+ // Make the shared tiles visible to every invocation before use.
+ workgroupBarrier();
+
+ let v_val = src_v[v_off + col];
+ let beta_val = src_beta[gb_off];
+
+ var kv_col = 0.0;
+ var delta_col = 0.0;
+ var attn_col = 0.0;
+
+#ifdef KDA
+ // Per-channel gate: decay each state entry by its own exp(g[i]).
+ for (var i = 0u; i < S_V; i++) {
+ kv_col += (sh_g[i] * state[i]) * sh_k[i];
+ }
+
+ delta_col = (v_val - kv_col) * beta_val;
+
+ for (var i = 0u; i < S_V; i++) {
+ state[i] = sh_g[i] * state[i] + sh_k[i] * delta_col;
+ attn_col += state[i] * sh_q[i];
+ }
+#else
+ // Scalar gate per (token, head).
+ let g_val = exp(src_g[gb_off]);
+
+ for (var i = 0u; i < S_V; i++) {
+ kv_col += state[i] * sh_k[i];
+ }
+
+ delta_col = (v_val - g_val * kv_col) * beta_val;
+
+ for (var i = 0u; i < S_V; i++) {
+ state[i] = g_val * state[i] + sh_k[i] * delta_col;
+ attn_col += state[i] * sh_q[i];
+ }
+#endif
+
+ dst[attn_off + col] = attn_col * params.scale;
+ attn_off += S_V * params.h;
+
+ // Ensure all reads of the shared tiles complete before the next token
+ // overwrites them.
+ workgroupBarrier();
+ }
+
+ // Persist the final state column after the attention output region.
+ for (var i = 0u; i < S_V; i++) {
+ dst[params.s_off + state_base + col * S_V + i] = state[i];
+ }
+}
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+#ifdef FLOAT_PARALLEL
+ // Parallel path (f32/f16/i32): one invocation per BLOCK_SIZE-element block
+ // of a destination row, instead of one invocation per whole row.
+ let blocks_per_row = params.ne0 / BLOCK_SIZE;
+ let row_count = params.n_rows * params.ne2 * params.ne3;
+
+ if (gid.x >= blocks_per_row * row_count) {
+ return;
+ }
+
+ // Decode flat thread index into (block within row, dst row coordinates).
+ let block_idx = gid.x % blocks_per_row;
+ var row_idx = gid.x / blocks_per_row;
+ let i_dst3 = row_idx / (params.ne2 * params.n_rows);
+
+ row_idx = row_idx % (params.ne2 * params.n_rows);
+ let i_dst2 = row_idx / params.n_rows;
+ let i_dst1 = row_idx % params.n_rows;
+
+ // Index-tensor coordinates broadcast over the dst batch dims.
+ let i_idx2 = i_dst3 % params.idx2;
+ let i_idx1 = i_dst2 % params.idx1;
+ let i_idx0 = i_dst1;
+
+ let i_idx = params.offset_idx + i_idx0 * params.stride_idx0 + i_idx1 * params.stride_idx1 + i_idx2 * params.stride_idx2;
+
+ let idx_val = u32(idx[i_idx]);
+
+ let i_src_row = params.offset_src + idx_val * params.stride_src1 + i_dst2 * params.stride_src2 + i_dst3 * params.stride_src3;
+ let i_dst_row = params.offset_dst + i_dst1 * params.stride_dst1 + i_dst2 * params.stride_dst2 + i_dst3 * params.stride_dst3;
+
+ // Copy just this invocation's block of the row.
+ copy_elements(i_src_row, i_dst_row, block_idx);
+#else
 if (gid.x >= params.n_rows * params.ne2 * params.ne3) {
 return;
 }
 for (var i: u32 = 0; i < params.ne0/BLOCK_SIZE; i++) {
 copy_elements(i_src_row, i_dst_row, i);
 }
+#endif
}
-
}
sum = scratch[0];
-#ifdef OP_RMS_NORM
+#ifdef RMS_NORM
let scale = 1.0/sqrt(sum/f32(params.ne0) + params.eps);
-#elif OP_L2_NORM
+#elif defined(L2_NORM)
let scale = 1.0/max(sqrt(sum), params.eps);
#endif
+
col = lid.x;
for (var j: u32 = 0; j < elems; j++) {
if (col >= params.ne0) {
--- /dev/null
+// SET operator: dst is src0 with the elements of src1 written into a view
+// that starts at offset_view inside dst and is laid out by the stride_dst1*
+// strides. TYPE_I32 selects an i32 element type, otherwise f32.
+#ifdef TYPE_I32
+#define TYPE i32
+#else
+#define TYPE f32
+#endif
+
+// In-place variant: the src0 binding is omitted (dst already holds the base
+// data) and the remaining bindings shift down one slot.
+#ifndef INPLACE
+@group(0) @binding(0)
+var<storage, read_write> src0: array<TYPE>;
+#define SRC1_BINDING 1
+#else
+#define SRC1_BINDING 0
+#endif
+
+#define DST_BINDING SRC1_BINDING + 1
+#define PARAMS_BINDING SRC1_BINDING + 2
+
+@group(0) @binding(SRC1_BINDING)
+var<storage, read_write> src1: array<TYPE>;
+
+@group(0) @binding(DST_BINDING)
+var<storage, read_write> dst: array<TYPE>;
+
+struct Params {
+ // Number of elements to process (dst elements; src1 elements when
+ // INPLACE). Also reused in main as an out-of-range sentinel.
+ ne: u32,
+ // Element offsets within the bound buffers; offset_view is where the set
+ // view begins inside dst.
+ offset_src0: u32,
+ offset_src1: u32,
+ offset_view: u32,
+
+ // Element strides of src1.
+ stride_src10: u32,
+ stride_src11: u32,
+ stride_src12: u32,
+ stride_src13: u32,
+
+ // Element strides of the dst view.
+ stride_dst10: u32,
+ stride_dst11: u32,
+ stride_dst12: u32,
+ stride_dst13: u32,
+
+ // Logical extents of src1 (the data written into the view).
+ src1_ne0: u32,
+ src1_ne1: u32,
+ src1_ne2: u32,
+ src1_ne3: u32,
+};
+
+@group(0) @binding(PARAMS_BINDING)
+var<uniform> params: Params;
+
+// Unravel a flat src1 element index into (i0, i1, i2, i3) coordinates,
+// i0 being the fastest-varying dimension.
+fn decode_src1_coords(idx: u32) -> vec4<u32> {
+ let ne0 = params.src1_ne0;
+ let ne01 = ne0 * params.src1_ne1;
+ let ne012 = ne01 * params.src1_ne2;
+ let i3 = idx / ne012;
+ let i2 = (idx % ne012) / ne01;
+ let i1 = (idx % ne01) / ne0;
+ let i0 = idx % ne0;
+ return vec4<u32>(i0, i1, i2, i3);
+}
+
+// Unravel an element offset relative to the view start into view
+// coordinates, using the dst view strides as mixed radices.
+// NOTE(review): i0 is taken as the remainder after extracting i1, i.e. this
+// assumes stride_dst10 == 1 (contiguous innermost dim) — confirm host side.
+fn decode_view_coords(rel: u32) -> vec4<u32> {
+ let i3 = rel / params.stride_dst13;
+ let rem3 = rel % params.stride_dst13;
+ let i2 = rem3 / params.stride_dst12;
+ let rem2 = rem3 % params.stride_dst12;
+ let i1 = rem2 / params.stride_dst11;
+ let i0 = rem2 % params.stride_dst11;
+ return vec4<u32>(i0, i1, i2, i3);
+}
+
+// Re-linearize view coordinates into an element offset relative to the view
+// start, using the dst view strides.
+fn view_rel_from_coords(coords: vec4<u32>) -> u32 {
+ let strides = vec4<u32>(params.stride_dst10, params.stride_dst11, params.stride_dst12, params.stride_dst13);
+ return dot(coords, strides);
+}
+
+// Linearize coordinates into an element offset inside src1 via its strides.
+fn src1_idx_from_coords(coords: vec4<u32>) -> u32 {
+ let strides = vec4<u32>(params.stride_src10, params.stride_src11, params.stride_src12, params.stride_src13);
+ return dot(coords, strides);
+}
+
+// An offset lies inside the set view exactly when re-linearizing its
+// decoded coordinates reproduces the offset.
+fn in_set_view(rel: u32, coords: vec4<u32>) -> bool {
+ return rel == view_rel_from_coords(coords);
+}
+
+@compute @workgroup_size(WG_SIZE)
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+ if (gid.x >= params.ne) {
+ return;
+ }
+
+#ifdef INPLACE
+ // In-place: dst already holds the base data, so only elements inside the
+ // view are written. gid.x enumerates src1 elements directly.
+ let coords = decode_src1_coords(gid.x);
+
+ let src1_idx = params.offset_src1 + src1_idx_from_coords(coords);
+ let dst_idx = params.offset_view + view_rel_from_coords(coords);
+
+ dst[dst_idx] = src1[src1_idx];
+#else
+ // Out-of-place: gid.x enumerates dst elements. Elements before the view
+ // start get rel = params.ne, an out-of-range sentinel intended to fail
+ // the view test below and fall through to the src0 copy.
+ let rel = select(params.ne, gid.x - params.offset_view, gid.x >= params.offset_view);
+ let coords = decode_view_coords(rel);
+
+ // Inside the view (within its dim-3 extent AND exactly on the view's
+ // stride lattice): take the element from src1; otherwise copy src0.
+ if (rel < params.stride_dst13 * params.src1_ne3 && in_set_view(rel, coords)) {
+ dst[gid.x] = src1[params.offset_src1 + src1_idx_from_coords(coords)];
+ } else {
+ dst[gid.x] = src0[params.offset_src0 + gid.x];
+ }
+#endif
+}
--- /dev/null
+// SOLVE_TRI: batched triangular solve computing X in A * X = B by forward
+// substitution. Only the lower triangle of the N x N matrix A (src0) is
+// read; B (src1) and X (dst) have params.k columns per batch element.
+// N, K_TILE, BATCH_N and WG_SIZE are constants substituted when the
+// pipeline is built (see the solve_tri pipeline key on the host).
+@group(0) @binding(0)
+var<storage, read_write> src0: array<f32>;
+
+@group(0) @binding(1)
+var<storage, read_write> src1: array<f32>;
+
+@group(0) @binding(2)
+var<storage, read_write> dst: array<f32>;
+
+struct Params {
+ // Element offsets of each tensor within its buffer.
+ offset_src0: u32,
+ offset_src1: u32,
+ offset_dst: u32,
+
+ // Element strides of A (src0): dim 0, dim 1, batch dims 2/3.
+ stride_src00: u32,
+ stride_src01: u32,
+ stride_src02: u32,
+ stride_src03: u32,
+
+ // Element strides of B (src1).
+ stride_src10: u32,
+ stride_src11: u32,
+ stride_src12: u32,
+ stride_src13: u32,
+
+ // Element strides of X (dst).
+ stride_dst0: u32,
+ stride_dst1: u32,
+ stride_dst2: u32,
+ stride_dst3: u32,
+
+ k: u32, // number of right-hand-side columns
+ ne2: u32, // batch extent in dim 2
+ ne3: u32, // batch extent in dim 3
+};
+
+@group(0) @binding(3)
+var<uniform> params: Params;
+
+// Workgroup tiles: a BATCH_N x N slab of A rows and the matching
+// BATCH_N x K_TILE slab of B columns handled by this workgroup.
+var<workgroup> shA: array<f32, BATCH_N * N>;
+var<workgroup> shB: array<f32, BATCH_N * K_TILE>;
+
+// Flat element index of A[row][col] in batch (i2, i3).
+fn src0_idx(row: u32, col: u32, i2: u32, i3: u32) -> u32 {
+ var idx = params.offset_src0;
+ idx += col * params.stride_src00;
+ idx += row * params.stride_src01;
+ idx += i2 * params.stride_src02;
+ idx += i3 * params.stride_src03;
+ return idx;
+}
+
+// Flat element index of B[row][col] in batch (i2, i3).
+fn src1_idx(row: u32, col: u32, i2: u32, i3: u32) -> u32 {
+ var idx = params.offset_src1;
+ idx += col * params.stride_src10;
+ idx += row * params.stride_src11;
+ idx += i2 * params.stride_src12;
+ idx += i3 * params.stride_src13;
+ return idx;
+}
+
+// Flat element index of X[row][col] in batch (i2, i3).
+fn dst_idx(row: u32, col: u32, i2: u32, i3: u32) -> u32 {
+ var idx = params.offset_dst;
+ idx += col * params.stride_dst0;
+ idx += row * params.stride_dst1;
+ idx += i2 * params.stride_dst2;
+ idx += i3 * params.stride_dst3;
+ return idx;
+}
+
+@compute @workgroup_size(WG_SIZE)
+fn main(
+ @builtin(workgroup_id) workgroup_id: vec3<u32>,
+ @builtin(local_invocation_id) local_id: vec3<u32>
+) {
+ // workgroup_id.y selects the batch element, workgroup_id.x the column tile.
+ let batch = workgroup_id.y;
+ let col = workgroup_id.x * WG_SIZE + local_id.x;
+ let i3 = batch / params.ne2;
+ let i2 = batch % params.ne2;
+ // Only the first K_TILE lanes solve columns; all WG_SIZE lanes still help
+ // load tiles and must reach both barriers below.
+ // NOTE(review): columns advance by WG_SIZE per workgroup while only K_TILE
+ // lanes compute — presumably K_TILE == WG_SIZE or the dispatch accounts
+ // for it; confirm against the host-side decisions struct.
+ let active_lane = local_id.x < K_TILE;
+ let active_col = active_lane && col < params.k;
+
+ // Per-thread solution column: X[r] is filled as row r is solved and then
+ // reused when substituting into later rows.
+ var X: array<f32, N>;
+
+ // March over the rows of A in slabs of BATCH_N rows.
+ for (var row_base = 0u; row_base < N; row_base += BATCH_N) {
+ let cur_n = min(BATCH_N, N - row_base);
+
+ // Cooperative load of rows [row_base, row_base + cur_n) of A...
+ for (var i = local_id.x; i < cur_n * N; i += WG_SIZE) {
+ let tile_row = i / N;
+ let tile_col = i % N;
+ shA[i] = src0[src0_idx(row_base + tile_row, tile_col, i2, i3)];
+ }
+
+ // ...and the matching rows of B for this workgroup's columns,
+ // zero-padded past the last valid column.
+ for (var i = local_id.x; i < cur_n * K_TILE; i += WG_SIZE) {
+ let tile_row = i / K_TILE;
+ let tile_col = i % K_TILE;
+ let global_col = workgroup_id.x * WG_SIZE + tile_col;
+ let sh_idx = tile_row * K_TILE + tile_col;
+
+ if (global_col < params.k) {
+ shB[sh_idx] = src1[src1_idx(row_base + tile_row, global_col, i2, i3)];
+ } else {
+ shB[sh_idx] = 0.0;
+ }
+ }
+
+ workgroupBarrier();
+
+ if (active_col) {
+ // Forward substitution:
+ // x_r = (b_r - sum_{t<r} A[r][t] * x_t) / A[r][r].
+ for (var row_offset = 0u; row_offset < cur_n; row_offset++) {
+ let r = row_base + row_offset;
+ var b = shB[row_offset * K_TILE + local_id.x];
+ let a_row = row_offset * N;
+
+ for (var t = 0u; t < r; t++) {
+ b -= shA[a_row + t] * X[t];
+ }
+
+ let x = b / shA[a_row + r];
+ X[r] = x;
+ dst[dst_idx(r, col, i2, i3)] = x;
+ }
+ }
+
+ // Keep shA/shB stable until every lane is done before the next slab
+ // overwrites them.
+ workgroupBarrier();
+ }
+}
--- /dev/null
+// SSM_CONV: per-row (depthwise) 1D convolution. For each sequence i3, token
+// i2 and row i1, dst = dot product of an nc-tap window of src0 with row i1
+// of the kernel src1 (src1 is indexed only by the row, see src1_base).
+@group(0) @binding(0)
+var<storage, read_write> src0: array<f32>;
+
+@group(0) @binding(1)
+var<storage, read_write> src1: array<f32>;
+
+@group(0) @binding(2)
+var<storage, read_write> dst: array<f32>;
+
+struct Params {
+ // Element offsets of each tensor within its buffer.
+ offset_src0: u32,
+ offset_src1: u32,
+ offset_dst: u32,
+
+ // Element strides (dim 0 of src0/src1 is stride-1 and has no field).
+ stride_src01: u32,
+ stride_src02: u32,
+ stride_src11: u32,
+
+ stride_dst0: u32,
+ stride_dst1: u32,
+ stride_dst2: u32,
+
+ nc: u32, // convolution kernel width (taps)
+ nr: u32, // number of rows (channels)
+ n_t: u32, // number of tokens
+ n_s: u32, // number of sequences
+ token_tiles: u32, // TOKENS_PER_WG-sized tiles covering n_t
+};
+
+@group(0) @binding(3)
+var<uniform> params: Params;
+
+@compute @workgroup_size(BLOCK_SIZE, TOKENS_PER_WG)
+fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
+ // gid.x -> row (channel); gid.y packs (sequence, token tile, local token).
+ let i1 = gid.x;
+ let tile_y = gid.y / TOKENS_PER_WG;
+ let local_token = gid.y % TOKENS_PER_WG;
+ let i3 = tile_y / params.token_tiles;
+ let token_tile = tile_y % params.token_tiles;
+ let i2 = token_tile * TOKENS_PER_WG + local_token;
+
+ if (i1 >= params.nr || i2 >= params.n_t || i3 >= params.n_s) {
+ return;
+ }
+
+ // The token index i2 advances along src0's innermost dimension (added
+ // unscaled, i.e. stride 1), so the nc-tap window starts at src0_base.
+ let src0_base = params.offset_src0 + i3 * params.stride_src02 + i2 + i1 * params.stride_src01;
+ let src1_base = params.offset_src1 + i1 * params.stride_src11;
+
+ var sum = 0.0;
+
+#ifdef VECTORIZED
+ // Fully unrolled dot product over exactly 4 taps — presumably the host
+ // selects this variant only when nc == 4 (see the ssm_conv pipeline key).
+ sum =
+ src0[src0_base + 0u] * src1[src1_base + 0u] +
+ src0[src0_base + 1u] * src1[src1_base + 1u] +
+ src0[src0_base + 2u] * src1[src1_base + 2u] +
+ src0[src0_base + 3u] * src1[src1_base + 3u];
+#else
+ // Generic path: accumulate over all nc taps.
+ for (var i0 = 0u; i0 < params.nc; i0++) {
+ sum += src0[src0_base + i0] * src1[src1_base + i0];
+ }
+#endif
+
+ let dst_idx = params.offset_dst + i3 * params.stride_dst2 + i2 * params.stride_dst1 + i1 * params.stride_dst0;
+ dst[dst_idx] = sum;
+}