uint32_t wg_size = 0;
};
+struct ggml_webgpu_processed_shader {
+ std::string wgsl;
+ std::string variant;
+ std::shared_ptr<void> decisions;
+};
+
struct ggml_webgpu_ssm_conv_shader_decisions {
uint32_t block_size;
uint32_t tokens_per_wg;
bool has_mask;
bool has_sinks;
bool uses_logit_softcap;
+ bool use_vec;
bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const {
return kv_type == other.kv_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v &&
kv_direct == other.kv_direct && has_mask == other.has_mask && has_sinks == other.has_sinks &&
- uses_logit_softcap == other.uses_logit_softcap;
+ uses_logit_softcap == other.uses_logit_softcap && use_vec == other.use_vec;
}
};
ggml_webgpu_hash_combine(seed, key.has_mask);
ggml_webgpu_hash_combine(seed, key.has_sinks);
ggml_webgpu_hash_combine(seed, key.uses_logit_softcap);
+ ggml_webgpu_hash_combine(seed, key.use_vec);
return seed;
}
};
uint32_t wg_size = 0;
};
+inline uint32_t ggml_webgpu_flash_attn_pick_vec_ne(const ggml_webgpu_flash_attn_pipeline_key & key) {
+ // Keep conservative defaults unless this is the f16 vec-split shape family.
+ if (key.kv_type != GGML_TYPE_F16 || key.head_dim_qk != key.head_dim_v) {
+ return 1u;
+ }
+
+ // Head-dim specializations used by the tuned vec f16 path.
+ switch (key.head_dim_qk) {
+ case 64: return 2u;
+ case 96: return 4u;
+ case 128: return 1u;
+ case 192: return 2u;
+ case 576: return 2u;
+ default: return 1u;
+ }
+}
+
+struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key {
+ uint32_t head_dim_v;
+ uint32_t wg_size;
+};
+
+struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key_hash {
+ size_t operator()(const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & key) const {
+ size_t seed = 0;
+ ggml_webgpu_hash_combine(seed, key.head_dim_v);
+ ggml_webgpu_hash_combine(seed, key.wg_size);
+ return seed;
+ }
+};
+
+inline bool operator==(const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & lhs,
+ const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & rhs) {
+ return lhs.head_dim_v == rhs.head_dim_v && lhs.wg_size == rhs.wg_size;
+}
+
+struct ggml_webgpu_flash_attn_vec_reduce_shader_lib_context {
+ ggml_webgpu_flash_attn_vec_reduce_pipeline_key key;
+ uint32_t max_wg_size;
+};
+
+inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_vec_reduce_shader(
+ pre_wgsl::Preprocessor & preprocessor,
+ const char * shader_src,
+ const ggml_webgpu_flash_attn_vec_reduce_shader_lib_context & context) {
+ std::vector<std::string> defines;
+ std::string variant = "flash_attn_vec_reduce";
+
+ defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.key.head_dim_v));
+ variant += std::string("_hsv") + std::to_string(context.key.head_dim_v);
+
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
+ variant += std::string("_wg") + std::to_string(context.max_wg_size);
+
+ ggml_webgpu_processed_shader result;
+ result.wgsl = preprocessor.preprocess(shader_src, defines);
+ result.variant = variant;
+ return result;
+}
+
+struct ggml_webgpu_flash_attn_blk_pipeline_key {
+ uint32_t q_tile;
+ uint32_t kv_tile;
+
+ bool operator==(const ggml_webgpu_flash_attn_blk_pipeline_key & other) const {
+ return q_tile == other.q_tile && kv_tile == other.kv_tile;
+ }
+};
+
+struct ggml_webgpu_flash_attn_blk_pipeline_key_hash {
+ size_t operator()(const ggml_webgpu_flash_attn_blk_pipeline_key & key) const {
+ size_t seed = 0;
+ ggml_webgpu_hash_combine(seed, key.q_tile);
+ ggml_webgpu_hash_combine(seed, key.kv_tile);
+ return seed;
+ }
+};
+
+struct ggml_webgpu_flash_attn_blk_shader_lib_context {
+ ggml_webgpu_flash_attn_blk_pipeline_key key;
+ uint32_t max_wg_size;
+};
+
+inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_blk_shader(
+ pre_wgsl::Preprocessor & preprocessor,
+ const char * shader_src,
+ const ggml_webgpu_flash_attn_blk_shader_lib_context & context) {
+ std::vector<std::string> defines;
+ std::string variant = "flash_attn_vec_blk";
+
+ defines.push_back(std::string("Q_TILE=") + std::to_string(context.key.q_tile));
+ variant += std::string("_qt") + std::to_string(context.key.q_tile);
+
+ defines.push_back(std::string("KV_TILE=") + std::to_string(context.key.kv_tile));
+ variant += std::string("_kvt") + std::to_string(context.key.kv_tile);
+
+ uint32_t wg_size = 1;
+ while ((wg_size << 1) <= context.max_wg_size) {
+ wg_size <<= 1;
+ }
+ defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
+ variant += std::string("_wg") + std::to_string(wg_size);
+
+ ggml_webgpu_processed_shader result;
+ result.wgsl = preprocessor.preprocess(shader_src, defines);
+ result.variant = variant;
+ return result;
+}
+
// This is exposed because it's necessary in supports_op
inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile,
uint32_t kv_tile,
repeat_pipelines; // type
std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash>
flash_attn_pipelines;
+ std::unordered_map<ggml_webgpu_flash_attn_vec_reduce_pipeline_key,
+ webgpu_pipeline,
+ ggml_webgpu_flash_attn_vec_reduce_pipeline_key_hash>
+ flash_attn_vec_reduce_pipelines;
+ std::unordered_map<ggml_webgpu_flash_attn_blk_pipeline_key,
+ webgpu_pipeline,
+ ggml_webgpu_flash_attn_blk_pipeline_key_hash>
+ flash_attn_blk_pipelines;
std::unordered_map<ggml_webgpu_legacy_mul_mat_pipeline_key,
webgpu_pipeline,
ggml_webgpu_legacy_mul_mat_pipeline_key_hash>
return repeat_pipelines[key];
}
- webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context) {
- const bool has_mask = context.src3 != nullptr;
- const bool has_sinks = context.src4 != nullptr;
-
- bool kv_direct = (context.src1->type == GGML_TYPE_F16) && (context.src0->ne[0] % context.sg_mat_k == 0) &&
- (context.src1->ne[1] % context.sg_mat_n == 0);
-
- ggml_webgpu_flash_attn_pipeline_key key = {
- .kv_type = context.src1->type,
- .head_dim_qk = (uint32_t) context.src0->ne[0],
- .head_dim_v = (uint32_t) context.src2->ne[0],
- .kv_direct = kv_direct,
- .has_mask = has_mask,
- .has_sinks = has_sinks,
- .uses_logit_softcap = (*(float *) &context.dst->op_params[2]) != 0.0f,
- };
-
- auto it = flash_attn_pipelines.find(key);
+ webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_flash_attn_shader_lib_context & context) {
+ auto it = flash_attn_pipelines.find(context.key);
if (it != flash_attn_pipelines.end()) {
return it->second;
}
std::vector<std::string> defines;
std::string variant = "flash_attn";
- switch (key.kv_type) {
+ switch (context.key.kv_type) {
case GGML_TYPE_F32:
defines.push_back("KV_F32");
break;
default:
GGML_ABORT("Unsupported KV type for flash attention shader");
}
- variant += std::string("_") + ggml_type_name(key.kv_type);
+ variant += std::string("_") + ggml_type_name(context.key.kv_type);
- if (key.has_mask) {
+ if (context.key.has_mask) {
defines.push_back("MASK");
variant += "_mask";
}
- if (key.has_sinks) {
+ if (context.key.has_sinks) {
defines.push_back("SINKS");
variant += "_sinks";
}
- if (key.uses_logit_softcap) {
+ if (context.key.uses_logit_softcap) {
defines.push_back("LOGIT_SOFTCAP");
variant += "_lgsc";
}
- if (key.kv_direct) {
+ if (context.key.kv_direct) {
defines.push_back("KV_DIRECT");
variant += "_kvdirect";
}
+ if (context.key.has_mask && context.key.use_vec) {
+ defines.push_back("BLK");
+ variant += "_blk";
+ }
- defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(key.head_dim_qk));
- variant += std::string("_hsqk") + std::to_string(key.head_dim_qk);
+ defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(context.key.head_dim_qk));
+ variant += std::string("_hsqk") + std::to_string(context.key.head_dim_qk);
- defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v));
- variant += std::string("_hsv") + std::to_string(key.head_dim_v);
+ defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.key.head_dim_v));
+ variant += std::string("_hsv") + std::to_string(context.key.head_dim_v);
defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m));
defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n));
defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k));
- uint32_t q_tile = context.sg_mat_m;
+ uint32_t q_tile = context.sg_mat_m;
uint32_t kv_tile =
- std::min(ggml_webgpu_flash_attn_max_kv_tile({ key, context.sg_mat_m, context.sg_mat_n, context.sg_mat_k,
- context.wg_mem_limit_bytes, context.max_subgroup_size }),
+ std::min(ggml_webgpu_flash_attn_max_kv_tile(context),
context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);
- if (key.kv_direct) {
+ if (context.key.use_vec) {
+ q_tile = 1;
+ kv_tile = std::max(context.sg_mat_n, std::min(32u, ggml_webgpu_flash_attn_max_kv_tile(context)));
+ kv_tile = (kv_tile / context.sg_mat_n) * context.sg_mat_n;
+ const uint32_t vec_ne = ggml_webgpu_flash_attn_pick_vec_ne(context.key);
+ defines.push_back(std::string("VEC_NE=") + std::to_string(vec_ne) + "u");
+ }
+ if (context.key.kv_direct) {
+ GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD);
while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) {
kv_tile -= context.sg_mat_n;
}
defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile));
defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile));
- uint32_t wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE);
+ uint32_t wg_size = 0;
+ if (context.key.use_vec) {
+ wg_size = std::max(1u, std::min<uint32_t>(32u, context.max_subgroup_size));
+ } else {
+ wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE);
+ }
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
- auto processed = preprocessor.preprocess(wgsl_flash_attn, defines);
+ const char * shader_src = context.key.use_vec ? wgsl_flash_attn_vec_split : wgsl_flash_attn;
+ webgpu_pipeline pipeline =
+ ggml_webgpu_create_pipeline(device, preprocessor.preprocess(shader_src, defines), variant);
auto decisions = std::make_shared<ggml_webgpu_flash_attn_shader_decisions>();
decisions->q_tile = q_tile;
decisions->kv_tile = kv_tile;
decisions->wg_size = wg_size;
+ pipeline.context = decisions;
+ flash_attn_pipelines[context.key] = pipeline;
+ return flash_attn_pipelines[context.key];
+ }
+
+ webgpu_pipeline get_flash_attn_blk_pipeline(const ggml_webgpu_flash_attn_blk_shader_lib_context & context) {
+ auto it = flash_attn_blk_pipelines.find(context.key);
+ if (it != flash_attn_blk_pipelines.end()) {
+ return it->second;
+ }
+
+ ggml_webgpu_processed_shader processed =
+ ggml_webgpu_preprocess_flash_attn_blk_shader(preprocessor, wgsl_flash_attn_vec_blk, context);
+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed.wgsl, processed.variant);
+ flash_attn_blk_pipelines[context.key] = pipeline;
+ return flash_attn_blk_pipelines[context.key];
+ }
+
+ webgpu_pipeline get_flash_attn_vec_reduce_pipeline(
+ const ggml_webgpu_flash_attn_vec_reduce_shader_lib_context & context) {
+ auto it = flash_attn_vec_reduce_pipelines.find(context.key);
+ if (it != flash_attn_vec_reduce_pipelines.end()) {
+ return it->second;
+ }
- webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant);
- pipeline.context = decisions;
- flash_attn_pipelines[key] = pipeline;
- return flash_attn_pipelines[key];
+ ggml_webgpu_processed_shader processed =
+ ggml_webgpu_preprocess_flash_attn_vec_reduce_shader(preprocessor, wgsl_flash_attn_vec_reduce, context);
+ webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed.wgsl, processed.variant);
+ flash_attn_vec_reduce_pipelines[context.key] = pipeline;
+ return flash_attn_vec_reduce_pipelines[context.key];
}
webgpu_pipeline get_cpy_pipeline(const ggml_webgpu_shader_lib_context & context) {
for (size_t i = 0; i < params_bufs_list.size(); i++) {
ctx->queue.WriteBuffer(params_bufs_list[i], 0, params_list[i].data(), params_list[i].size() * sizeof(uint32_t));
}
-
#ifdef GGML_WEBGPU_GPU_PROFILE
webgpu_gpu_profile_bufs ts_bufs = ctx->timestamp_query_buf_pool.alloc_bufs();
if (ts_bufs.host_buf.GetMapState() == wgpu::BufferMapState::Mapped) {
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y);
}
-#ifndef __EMSCRIPTEN__
static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx,
ggml_tensor * Q,
ggml_tensor * K,
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
- ggml_webgpu_shader_lib_context shader_lib_ctx = {
- .src0 = Q,
- .src1 = K,
- .src2 = V,
- .src3 = mask,
- .src4 = sinks,
- .dst = dst,
- .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
- .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize,
+ const uint32_t k_offset_elems = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type));
+ const uint32_t v_offset_elems = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type));
+ const bool f16_vec4_aligned = (k_offset_elems % 4u == 0u) && (v_offset_elems % 4u == 0u);
+
+ const bool kv_direct = (K->type == GGML_TYPE_F16) && f16_vec4_aligned &&
+ (Q->ne[0] % ctx->global_ctx->capabilities.sg_mat_k == 0) &&
+ (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0);
+
+ const bool kv_vec_type_supported =
+ K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0;
+ const bool use_vec = (Q->ne[1] < 20) && (Q->ne[0] % 32 == 0) && (V->ne[0] % 4 == 0) && kv_vec_type_supported &&
+ (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && (V->type == K->type);
+ const uint32_t vec_nwg_cap =
+ std::max(1u, std::min<uint32_t>(32u, ctx->global_ctx->capabilities.max_subgroup_size));
+ const bool use_blk = use_vec && has_mask;
+
+ ggml_webgpu_flash_attn_pipeline_key key = {
+ .kv_type = K->type,
+ .head_dim_qk = (uint32_t) Q->ne[0],
+ .head_dim_v = (uint32_t) V->ne[0],
+ .kv_direct = kv_direct,
+ .has_mask = static_cast<bool>(has_mask),
+ .has_sinks = static_cast<bool>(has_sinks),
+ .uses_logit_softcap = logit_softcap != 0.0f,
+ .use_vec = use_vec,
+ };
+
+ ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = {
+ .key = key,
.sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m,
.sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n,
.sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k,
+ .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize,
.max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size,
};
-
webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline(shader_lib_ctx);
auto * decisions = static_cast<ggml_webgpu_flash_attn_shader_decisions *>(pipeline.context.get());
uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions->q_tile);
uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches
+
+ wgpu::Buffer blk_buf = {};
+ uint64_t blk_size_bytes = 0;
+ uint32_t blk_nblk0 = 0;
+ uint32_t blk_nblk1 = 0;
+ uint32_t blk_batch_count = 0;
+
+ if (use_vec) {
+ uint32_t nwg = 1u;
+ const uint64_t kv_span = (uint64_t) std::max(1u, decisions->kv_tile);
+ while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) {
+ nwg <<= 1;
+ }
+ nwg = std::min(nwg, vec_nwg_cap);
+ GGML_ASSERT(nwg <= ctx->global_ctx->capabilities.max_subgroup_size);
+ const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3];
+ const bool use_vec_reduce = nwg > 1u;
+ GGML_ASSERT(nrows <= UINT32_MAX);
+
+ uint64_t tmp_stats_base = 0;
+ uint64_t tmp_size_bytes = 0;
+ wgpu::Buffer tmp_buf = {};
+ uint64_t tmp_bind_offset = 0;
+ uint64_t tmp_bind_size = 0;
+ const size_t align_bytes = ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment;
+ const size_t dst_offset = ggml_webgpu_tensor_offset(dst);
+ size_t scratch_offset = ROUNDUP_POW2(dst_offset + ggml_nbytes(dst), align_bytes);
+
+ if (use_vec_reduce) {
+ const uint64_t tmp_data_elems = nrows * (uint64_t) V->ne[0] * nwg;
+ const uint64_t tmp_stats_elems = nrows * 2u * nwg;
+ tmp_stats_base = tmp_data_elems;
+ tmp_size_bytes =
+ ROUNDUP_POW2((tmp_data_elems + tmp_stats_elems) * sizeof(float), WEBGPU_STORAGE_BUF_BINDING_MULT);
+ GGML_ASSERT(tmp_stats_base <= UINT32_MAX);
+ tmp_buf = ggml_webgpu_tensor_buf(dst);
+ tmp_bind_offset = scratch_offset;
+ tmp_bind_size = tmp_size_bytes;
+ scratch_offset = ROUNDUP_POW2(scratch_offset + tmp_size_bytes, align_bytes);
+ } else {
+ // nwg==1 writes final dst directly in vec-split; keep tmp binding valid without extra allocation.
+ tmp_buf = ggml_webgpu_tensor_buf(dst);
+ tmp_bind_offset = ggml_webgpu_tensor_align_offset(ctx, dst);
+ tmp_bind_size = ggml_webgpu_tensor_binding_size(ctx, dst);
+ }
+
+ webgpu_pipeline blk_pipeline;
+ std::vector<uint32_t> blk_params;
+ std::vector<wgpu::BindGroupEntry> blk_entries;
+ if (use_blk) {
+ GGML_ASSERT(has_mask);
+
+ blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], decisions->kv_tile);
+ blk_nblk1 = CEIL_DIV((uint32_t) Q->ne[1], decisions->q_tile);
+ blk_buf = ggml_webgpu_tensor_buf(dst);
+ const uint32_t stride_mask3 = (uint32_t) (mask->nb[3] / ggml_type_size(mask->type));
+ blk_batch_count = stride_mask3 > 0 ? (uint32_t) Q->ne[3] : 1u;
+ const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count;
+ blk_size_bytes = ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT);
+ ggml_webgpu_flash_attn_blk_shader_lib_context blk_shader_ctx = {
+ .key =
+ {
+ .q_tile = decisions->q_tile,
+ .kv_tile = decisions->kv_tile,
+ },
+ .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
+ };
+ blk_pipeline = ctx->shader_lib->get_flash_attn_blk_pipeline(blk_shader_ctx);
+
+ blk_params = {
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)), // offset_mask
+ (uint32_t) Q->ne[1], // seq_len_q
+ (uint32_t) K->ne[1], // seq_len_kv
+ stride_mask3, // stride_mask3
+ blk_nblk0, // nblk0
+ blk_nblk1, // nblk1
+ };
+ blk_entries = {
+ { .binding = 0,
+ .buffer = ggml_webgpu_tensor_buf(mask),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, mask),
+ .size = ggml_webgpu_tensor_binding_size(ctx, mask) },
+ { .binding = 1, .buffer = blk_buf, .offset = scratch_offset, .size = blk_size_bytes },
+ };
+ scratch_offset = ROUNDUP_POW2(scratch_offset + blk_size_bytes, align_bytes);
+ }
+
+ std::vector<uint32_t> split_params = params;
+ if (use_blk) {
+ split_params.push_back(0u); // blk_base
+ split_params.push_back(blk_nblk0); // blk_nblk0
+ split_params.push_back(blk_nblk1); // blk_nblk1
+ }
+ split_params.push_back(0u); // tmp_data_base
+ split_params.push_back((uint32_t) tmp_stats_base); // tmp_stats_base
+ split_params.push_back(nwg); // nwg
+
+ std::vector<wgpu::BindGroupEntry> split_entries = {
+ { .binding = 0,
+ .buffer = ggml_webgpu_tensor_buf(Q),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, Q),
+ .size = ggml_webgpu_tensor_binding_size(ctx, Q) },
+ { .binding = 1,
+ .buffer = ggml_webgpu_tensor_buf(K),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, K),
+ .size = ggml_webgpu_tensor_binding_size(ctx, K) },
+ { .binding = 2,
+ .buffer = ggml_webgpu_tensor_buf(V),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, V),
+ .size = ggml_webgpu_tensor_binding_size(ctx, V) },
+ };
+ uint32_t split_binding_index = 3;
+ if (has_mask) {
+ split_entries.push_back({ .binding = split_binding_index++,
+ .buffer = ggml_webgpu_tensor_buf(mask),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, mask),
+ .size = ggml_webgpu_tensor_binding_size(ctx, mask) });
+ }
+ if (has_sinks) {
+ split_entries.push_back({ .binding = split_binding_index++,
+ .buffer = ggml_webgpu_tensor_buf(sinks),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, sinks),
+ .size = ggml_webgpu_tensor_binding_size(ctx, sinks) });
+ }
+ if (use_blk) {
+ split_entries.push_back(
+ { .binding = split_binding_index++, .buffer = blk_buf, .offset = blk_entries[1].offset, .size = blk_size_bytes });
+ }
+ split_entries.push_back(
+ { .binding = split_binding_index++, .buffer = tmp_buf, .offset = tmp_bind_offset, .size = tmp_bind_size });
+ split_entries.push_back({ .binding = split_binding_index++,
+ .buffer = ggml_webgpu_tensor_buf(dst),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) });
+
+ webgpu_pipeline reduce_pipeline;
+ std::vector<uint32_t> reduce_params;
+ std::vector<wgpu::BindGroupEntry> reduce_entries;
+ if (use_vec_reduce) {
+ const uint32_t reduce_wg_size = std::max(
+ 32u,
+ std::min<uint32_t>(nwg * 32u, ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup));
+ ggml_webgpu_flash_attn_vec_reduce_shader_lib_context reduce_shader_ctx = {
+ .key =
+ {
+ .head_dim_v = (uint32_t) V->ne[0],
+ .wg_size = reduce_wg_size,
+ },
+ .max_wg_size = reduce_wg_size,
+ };
+ reduce_pipeline = ctx->shader_lib->get_flash_attn_vec_reduce_pipeline(reduce_shader_ctx);
+
+ reduce_params = {
+ (uint32_t) nrows, // nrows
+ (uint32_t) Q->ne[1], // seq_len_q
+ (uint32_t) Q->ne[2], // n_heads
+ (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), // offset_dst
+ nwg, // nwg
+ 0u, // tmp_data_base
+ (uint32_t) tmp_stats_base, // tmp_stats_base
+ };
+
+ reduce_entries = {
+ { .binding = 0, .buffer = tmp_buf, .offset = tmp_bind_offset, .size = tmp_size_bytes },
+ { .binding = 1,
+ .buffer = ggml_webgpu_tensor_buf(dst),
+ .offset = ggml_webgpu_tensor_align_offset(ctx, dst),
+ .size = ggml_webgpu_tensor_binding_size(ctx, dst) },
+ };
+ }
+
+ const uint64_t split_wg_total = (uint64_t) wg_x * nwg;
+ GGML_ASSERT(split_wg_total <= UINT32_MAX);
+ std::vector<webgpu_pipeline> pipelines;
+ std::vector<std::vector<uint32_t>> params_list;
+ std::vector<std::vector<wgpu::BindGroupEntry>> entries_list;
+ std::vector<std::pair<uint32_t, uint32_t>> workgroups_list;
+
+ if (use_blk) {
+ pipelines.push_back(blk_pipeline);
+ params_list.push_back(std::move(blk_params));
+ entries_list.push_back(std::move(blk_entries));
+ workgroups_list.push_back({ blk_nblk0, blk_nblk1 * blk_batch_count });
+ }
+ pipelines.push_back(pipeline);
+ params_list.push_back(std::move(split_params));
+ entries_list.push_back(std::move(split_entries));
+ workgroups_list.push_back({ (uint32_t) split_wg_total, 1u });
+ if (use_vec_reduce) {
+ pipelines.push_back(reduce_pipeline);
+ params_list.push_back(std::move(reduce_params));
+ entries_list.push_back(std::move(reduce_entries));
+ workgroups_list.push_back({ (uint32_t) nrows, 1u });
+ }
+
+ return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_buf_pool, pipelines, params_list,
+ entries_list, workgroups_list);
+ }
+
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
}
-#endif
static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
bool is_unary = dst->op == GGML_OP_UNARY;
std::vector<webgpu_submission> subs;
uint32_t num_batched_kernels = 0;
bool contains_set_rows = false;
-
for (int i = 0; i < cgraph->n_nodes; i++) {
if (cgraph->nodes[i]->op == GGML_OP_SET_ROWS) {
contains_set_rows = true;
}
}
break;
+ case GGML_OP_FLASH_ATTN_EXT:
+ {
+ const ggml_tensor * Q = tensor->src[0];
+ const ggml_tensor * K = tensor->src[1];
+ const ggml_tensor * V = tensor->src[2];
+ const ggml_tensor * mask = tensor->src[3];
+ const ggml_tensor * sinks = tensor->src[4];
+ if (Q && K && V) {
+ GGML_UNUSED(sinks);
+ const bool kv_direct = (K->type == GGML_TYPE_F16) &&
+ (Q->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k == 0) &&
+ (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0);
+ const bool kv_vec_type_supported =
+ K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0;
+ const bool use_vec =
+ (Q->ne[1] < 20) && (Q->ne[0] % 32 == 0) && (V->ne[0] % 4 == 0) && kv_vec_type_supported &&
+ (V->type == K->type);
+ if (use_vec) {
+ const uint32_t sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m;
+ const uint32_t sg_mat_n = ctx->webgpu_global_ctx->capabilities.sg_mat_n;
+ const size_t limit_bytes =
+ ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
+ const size_t q_tile = sg_mat_m;
+ const size_t base_q_bytes =
+ (Q->ne[0] + V->ne[0]) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES +
+ 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES;
+ size_t bytes_per_kv = 0;
+ if (!kv_direct) {
+ bytes_per_kv += std::max(Q->ne[0], V->ne[0]);
+ }
+ if (mask != nullptr) {
+ bytes_per_kv += q_tile;
+ }
+ bytes_per_kv += q_tile;
+ bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES;
+ uint32_t kv_tile =
+ ((limit_bytes - base_q_bytes) / bytes_per_kv / sg_mat_n) * sg_mat_n;
+ kv_tile = std::max(sg_mat_n, std::min(32u, kv_tile));
+ kv_tile = (kv_tile / sg_mat_n) * sg_mat_n;
+ if (kv_direct) {
+ GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD);
+ while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) {
+ kv_tile -= sg_mat_n;
+ }
+ }
+
+ const uint32_t vec_nwg_cap = std::max(
+ 1u, std::min<uint32_t>(32u, ctx->webgpu_global_ctx->capabilities.max_subgroup_size));
+ uint32_t nwg = 1u;
+ const uint64_t kv_span = (uint64_t) std::max(1u, kv_tile);
+ while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) {
+ nwg <<= 1;
+ }
+ nwg = std::min(nwg, vec_nwg_cap);
+
+ const size_t align = ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment;
+ const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3];
+ if (nwg > 1u) {
+ const uint64_t tmp_data_elems = nrows * (uint64_t) V->ne[0] * nwg;
+ const uint64_t tmp_stats_elems = nrows * 2u * nwg;
+ const size_t tmp_size_bytes = ROUNDUP_POW2(
+ (tmp_data_elems + tmp_stats_elems) * sizeof(float), WEBGPU_STORAGE_BUF_BINDING_MULT);
+ res += tmp_size_bytes + align;
+ }
+ if (mask != nullptr) {
+ const uint32_t blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], kv_tile);
+ const uint32_t blk_nblk1 = CEIL_DIV((uint32_t) Q->ne[1], 1u);
+ const uint32_t stride_mask3 =
+ (uint32_t) (mask->nb[3] / ggml_type_size(mask->type));
+ const uint32_t blk_batch_count = stride_mask3 > 0 ? (uint32_t) Q->ne[3] : 1u;
+ const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count;
+ const size_t blk_size_bytes =
+ ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT);
+ res += blk_size_bytes + align;
+ }
+ res = ROUNDUP_POW2(res, WEBGPU_STORAGE_BUF_BINDING_MULT);
+ }
+ }
+ }
+ break;
default:
break;
}
--- /dev/null
+diagnostic(off, subgroup_uniformity);
+enable f16;
+
+#define Q_TILE 1
+#define KV_TILE 32
+#define WG_SIZE 32
+
+struct Params {
+ offset_mask: u32,
+ seq_len_q: u32,
+ seq_len_kv: u32,
+ stride_mask3: u32,
+ // Number of KV blocks and Q blocks per batch.
+ // nblk0 = ceil(seq_len_kv / KV_TILE), nblk1 = ceil(seq_len_q / Q_TILE).
+ nblk0: u32,
+ nblk1: u32,
+};
+
+@group(0) @binding(0) var<storage, read> mask: array<f16>;
+@group(0) @binding(1) var<storage, read_write> blk: array<u32>;
+@group(0) @binding(2) var<uniform> params: Params;
+
+const MASK_MIN: f32 = -65504.0;
+const MASK_MAX: f32 = 65504.0;
+var<workgroup> wg_min: array<f32, WG_SIZE>;
+var<workgroup> wg_max: array<f32, WG_SIZE>;
+var<workgroup> wg_any: array<u32, WG_SIZE>;
+
+@compute @workgroup_size(WG_SIZE)
+fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
+ @builtin(local_invocation_id) local_id: vec3<u32>) {
+ // Dispatch mapping:
+ // - x indexes KV blocks
+ // - y flattens (batch_idx, q_blk) as y = batch_idx * nblk1 + q_blk
+ let kv_blk = wg_id.x;
+ let y = wg_id.y;
+ let q_blk = y % params.nblk1;
+ let batch_idx = y / params.nblk1;
+ if (kv_blk >= params.nblk0) {
+ return;
+ }
+
+ let q_start = q_blk * Q_TILE;
+ let k_start = kv_blk * KV_TILE;
+
+ let mask_batch = select(0u, batch_idx, params.stride_mask3 > 0u);
+ let mask_batch_base = params.offset_mask + mask_batch * params.stride_mask3;
+
+ // We keep min/max to classify:
+ // - fully masked (max <= MASK_MIN)
+ // - all-zero mask (min == 0 && max == 0)
+ // - mixed/general mask
+ var local_min = MASK_MAX;
+ var local_max = -MASK_MAX;
+ var local_any = 0u;
+
+ for (var q_rel = 0u; q_rel < Q_TILE; q_rel += 1u) {
+ let q_row = q_start + q_rel;
+ if (q_row >= params.seq_len_q) {
+ continue;
+ }
+ let row_base = mask_batch_base + q_row * params.seq_len_kv;
+ for (var k_rel = local_id.x; k_rel < KV_TILE; k_rel += WG_SIZE) {
+ let k_col = k_start + k_rel;
+ if (k_col >= params.seq_len_kv) {
+ continue;
+ }
+ let mv = f32(mask[row_base + k_col]);
+ local_min = min(local_min, mv);
+ local_max = max(local_max, mv);
+ local_any = 1u;
+ }
+ }
+
+ wg_min[local_id.x] = local_min;
+ wg_max[local_id.x] = local_max;
+ wg_any[local_id.x] = local_any;
+ workgroupBarrier();
+
+ // Thread 0 writes one state per block.
+ if (local_id.x == 0u) {
+ var mmin = wg_min[0];
+ var mmax = wg_max[0];
+ var many = wg_any[0];
+ for (var i = 1u; i < WG_SIZE; i += 1u) {
+ mmin = min(mmin, wg_min[i]);
+ mmax = max(mmax, wg_max[i]);
+ many = max(many, wg_any[i]);
+ }
+
+ var state = 0u;
+ if (many != 0u) {
+ if (mmax <= MASK_MIN) {
+ state = 0u;
+ } else if (mmin == 0.0 && mmax == 0.0) {
+ state = 2u;
+ } else {
+ state = 1u;
+ }
+ }
+
+ let blk_idx = (batch_idx * params.nblk1 + q_blk) * params.nblk0 + kv_blk;
+ blk[blk_idx] = state;
+ }
+}
--- /dev/null
+diagnostic(off, subgroup_uniformity);
+enable f16;
+enable subgroups;
+
+// Default values
+#define HEAD_DIM_V 64
+#define WG_SIZE 128
+
+struct Params {
+ nrows: u32,
+ seq_len_q: u32,
+ n_heads: u32,
+ offset_dst: u32,
+ nwg: u32,
+ tmp_data_base: u32,
+ tmp_stats_base: u32,
+};
+
+@group(0) @binding(0) var<storage, read_write> tmp: array<f32>;
+@group(0) @binding(1) var<storage, read_write> dst: array<vec4<f32>>;
+@group(0) @binding(2) var<uniform> params: Params;
+
+const FLOAT_MIN: f32 = -1.0e9;
+
+@compute @workgroup_size(WG_SIZE)
+fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
+ @builtin(subgroup_id) subgroup_id: u32,
+ @builtin(num_subgroups) num_subgroups: u32,
+ @builtin(subgroup_size) subgroup_size: u32,
+ @builtin(subgroup_invocation_id) sg_inv_id: u32) {
+ let rid = wg_id.x;
+ if (rid >= params.nrows) {
+ return;
+ }
+
+ let rows_per_batch = params.n_heads * params.seq_len_q;
+ let batch_idx = rid / rows_per_batch;
+ let rem = rid % rows_per_batch;
+ let head_idx = rem / params.seq_len_q;
+ let q_row = rem % params.seq_len_q;
+
+ let dst2_stride = HEAD_DIM_V * params.n_heads;
+ let dst3_stride = dst2_stride * params.seq_len_q;
+ let row_base = params.offset_dst + batch_idx * dst3_stride + q_row * dst2_stride + head_idx * HEAD_DIM_V;
+
+ let thread = sg_inv_id;
+ if (params.nwg > subgroup_size) {
+ return;
+ }
+
+ let stats_base = params.tmp_stats_base + rid * (2u * params.nwg);
+ let active_thread = thread < params.nwg;
+ let si = select(0.0, tmp[stats_base + 2u * thread + 0u], active_thread);
+ let mi = select(FLOAT_MIN, tmp[stats_base + 2u * thread + 1u], active_thread);
+ let m = subgroupMax(mi);
+ let ms = select(0.0, exp(mi - m), active_thread);
+ let s = subgroupAdd(si * ms);
+ let inv_s = select(0.0, 1.0 / s, s != 0.0);
+
+ let row_tmp_base = params.tmp_data_base + rid * (HEAD_DIM_V * params.nwg);
+ for (var elem_base = subgroup_id * 4u; elem_base < HEAD_DIM_V; elem_base += num_subgroups * 4u) {
+ var weighted = vec4<f32>(0.0, 0.0, 0.0, 0.0);
+ if (active_thread) {
+ let src = row_tmp_base + thread * HEAD_DIM_V + elem_base;
+ weighted = vec4<f32>(tmp[src + 0u], tmp[src + 1u], tmp[src + 2u], tmp[src + 3u]) * ms;
+ }
+
+ let sum_x = subgroupAdd(weighted.x);
+ let sum_y = subgroupAdd(weighted.y);
+ let sum_z = subgroupAdd(weighted.z);
+ let sum_w = subgroupAdd(weighted.w);
+
+ if (thread == 0u) {
+ let dst_vec_index = (row_base + elem_base) >> 2u;
+ dst[dst_vec_index] = vec4<f32>(sum_x, sum_y, sum_z, sum_w) * inv_s;
+ }
+ }
+}
--- /dev/null
+diagnostic(off, chromium.subgroup_matrix_uniformity);
+diagnostic(off, subgroup_uniformity);
+enable f16;
+enable subgroups;
+enable chromium_experimental_subgroup_matrix;
+
+#ifdef KV_F32
+#define KV_TYPE f32
+#else
+#define KV_TYPE f16
+#endif
+
+#define HEAD_DIM_QK 64
+#define HEAD_DIM_V 64
+
+
+#define SG_MAT_M 8
+#define SG_MAT_N 8
+#define SG_MAT_K 8
+
+#define Q_TILE SG_MAT_M
+#define KV_TILE 16
+#define WG_SIZE 64
+#ifndef VEC_NE
+#define VEC_NE 4u
+#endif
+
+#define KV_BLOCKS (KV_TILE / SG_MAT_N)
+
+#define BLOCK_SIZE 32
+#define BLOCKS_K ((HEAD_DIM_QK + BLOCK_SIZE - 1) / BLOCK_SIZE)
+#define BLOCKS_V ((HEAD_DIM_V + BLOCK_SIZE - 1) / BLOCK_SIZE)
+#if defined(KV_Q4_0)
+#define NQ 16
+#define F16_PER_BLOCK 9
+#define WEIGHTS_PER_F16 4
+#elif defined(KV_Q8_0)
+#define NQ 8
+#define F16_PER_BLOCK 17
+#define WEIGHTS_PER_F16 2
+#endif
+#define F16_PER_THREAD (NQ / WEIGHTS_PER_F16)
+
+fn get_byte(value: u32, index: u32) -> u32 {
+ return (value >> (index * 8)) & 0xFF;
+}
+
+fn get_byte_i32(value: u32, index: u32) -> i32 {
+ return bitcast<i32>(((value >> (index * 8)) & 0xFF) << 24) >> 24;
+}
+
+struct Params {
+ offset_q: u32,
+ offset_k: u32,
+ offset_v: u32,
+ offset_mask: u32,
+ offset_sinks: u32,
+ offset_dst: u32,
+
+ // shapes of Q/K/V
+ n_heads: u32,
+ seq_len_q: u32,
+ seq_len_kv: u32,
+
+ // strides (in elements)
+ stride_q1: u32,
+ stride_q2: u32,
+ stride_q3: u32,
+ stride_k1: u32,
+ stride_k2: u32,
+ stride_k3: u32,
+ stride_v1: u32,
+ stride_v2: u32,
+ stride_v3: u32,
+ stride_mask3: u32,
+
+ // repeat factors for K/V, e.g., MHA vs. MQA vs. GQA
+ q_per_kv: u32,
+
+ // softmax params
+ scale: f32,
+ max_bias: f32,
+ logit_softcap: f32,
+ n_head_log2: f32,
+ m0: f32,
+ m1: f32,
+
+#ifdef BLK
+ blk_base: u32,
+ blk_nblk0: u32,
+ blk_nblk1: u32,
+#endif
+
+ tmp_data_base: u32,
+ tmp_stats_base: u32,
+ nwg: u32,
+};
+
+@group(0) @binding(0) var<storage, read_write> Q: array<f32>;
+#if defined(KV_Q4_0) || defined(KV_Q8_0)
+@group(0) @binding(1) var<storage, read_write> K: array<KV_TYPE>;
+#else
+@group(0) @binding(1) var<storage, read_write> K: array<vec4<KV_TYPE>>;
+#endif
+#if defined(KV_Q4_0) || defined(KV_Q8_0)
+@group(0) @binding(2) var<storage, read_write> V: array<KV_TYPE>;
+#else
+@group(0) @binding(2) var<storage, read_write> V: array<vec4<KV_TYPE>>;
+#endif
+#if defined(MASK) && defined(SINKS)
+@group(0) @binding(3) var<storage, read_write> mask: array<f16>;
+@group(0) @binding(4) var<storage, read_write> sinks: array<f32>;
+#ifdef BLK
+#define BLK_BINDING 5
+#define TMP_BINDING 6
+#define DST_BINDING 7
+#define PARAMS_BINDING 8
+#else
+#define TMP_BINDING 5
+#define DST_BINDING 6
+#define PARAMS_BINDING 7
+#endif
+#elif defined(MASK)
+@group(0) @binding(3) var<storage, read_write> mask: array<f16>;
+#ifdef BLK
+#define BLK_BINDING 4
+#define TMP_BINDING 5
+#define DST_BINDING 6
+#define PARAMS_BINDING 7
+#else
+#define TMP_BINDING 4
+#define DST_BINDING 5
+#define PARAMS_BINDING 6
+#endif
+#elif defined(SINKS)
+@group(0) @binding(3) var<storage, read_write> sinks: array<f32>;
+#define TMP_BINDING 4
+#define DST_BINDING 5
+#define PARAMS_BINDING 6
+#else
+#define TMP_BINDING 3
+#define DST_BINDING 4
+#define PARAMS_BINDING 5
+#endif
+
+#ifdef BLK
+@group(0) @binding(BLK_BINDING) var<storage, read_write> blk: array<u32>;
+#endif
+@group(0) @binding(TMP_BINDING) var<storage, read_write> tmp: array<f32>;
+@group(0) @binding(DST_BINDING) var<storage, read_write> dst: array<vec4<f32>>;
+@group(0) @binding(PARAMS_BINDING) var<uniform> params: Params;
+
+// Just a very small float value.
+const FLOAT_MIN: f32 = -1.0e9;
+
+var<workgroup> q_shmem: array<f16, Q_TILE * HEAD_DIM_QK>;
+
+#ifndef KV_DIRECT
+const kv_shmem_size = KV_TILE * max(HEAD_DIM_QK, HEAD_DIM_V);
+// we can reuse the same shmem for K and V since we only need one at a time
+var<workgroup> kv_shmem: array<f16, kv_shmem_size>;
+#endif
+
+var<workgroup> o_shmem: array<f16, Q_TILE * HEAD_DIM_V>;
+
+#ifdef MASK
+// storage for mask values
+var<workgroup> mask_shmem: array<f16, Q_TILE * KV_TILE>;
+#endif
+
+// note that we reuse the same storage for both since we only need one at a time
+var<workgroup> inter_shmem: array<f16, Q_TILE * KV_TILE>;
+
+// Storage for row max and exp sum during online softmax
+var<workgroup> row_max_shmem: array<f32, Q_TILE>;
+var<workgroup> exp_sum_shmem: array<f32, Q_TILE>;
+var<workgroup> blk_state_wg: u32;
+
+fn calc_softmax_term(kv_idx: u32, q_tile_row: u32, slope: f32, has_bias: bool, apply_mask: bool) -> f32 {
+ var v = select(FLOAT_MIN,
+ f32(inter_shmem[kv_idx + q_tile_row * KV_TILE]) * params.scale,
+ kv_idx < KV_TILE);
+#ifdef LOGIT_SOFTCAP
+ v = params.logit_softcap * tanh(v);
+#endif
+#ifdef MASK
+ if (apply_mask) {
+ var mask_val = select(0.0,f32(mask_shmem[q_tile_row * KV_TILE + kv_idx]), kv_idx < KV_TILE);
+ v += select(mask_val, slope * mask_val, has_bias);
+ }
+#endif
+ return v;
+}
+
+@compute @workgroup_size(WG_SIZE)
+fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
+ @builtin(local_invocation_id) local_id: vec3<u32>,
+ @builtin(subgroup_id) subgroup_id: u32,
+ @builtin(subgroup_size) subgroup_size: u32,
+ @builtin(num_subgroups) num_subgroups: u32,
+ @builtin(subgroup_invocation_id) sg_inv_id: u32) {
+
+ // initialize row max for online softmax
+ for (var i = local_id.x; i < Q_TILE; i += WG_SIZE) {
+ row_max_shmem[i] = FLOAT_MIN;
+ exp_sum_shmem[i] = 0.0;
+ }
+
+ for (var i = local_id.x; i < Q_TILE * HEAD_DIM_V; i += WG_SIZE) {
+ o_shmem[i] = 0.0;
+ }
+
+ // workgroups per head/batch
+ let wg_per_head = (params.seq_len_q + Q_TILE - 1u) / Q_TILE;
+ let wg_per_batch = wg_per_head * params.n_heads;
+
+ let dst2_stride = HEAD_DIM_V * params.n_heads;
+ let dst3_stride = dst2_stride * params.seq_len_q;
+
+ let iwg = wg_id.x % params.nwg;
+ let base_wg_id = wg_id.x / params.nwg;
+
+ // batch index
+ let batch_idx = base_wg_id / wg_per_batch;
+ let q_batch_offset = params.offset_q + batch_idx * params.stride_q3;
+ let k_batch_offset = params.offset_k + batch_idx * params.stride_k3;
+ let v_batch_offset = params.offset_v + batch_idx * params.stride_v3;
+ let wg_in_batch = base_wg_id % wg_per_batch;
+
+ // head index
+ let head_idx = wg_in_batch / wg_per_head;
+ let q_head_offset = q_batch_offset + head_idx * params.stride_q2;
+ let k_head_idx = head_idx / params.q_per_kv;
+ let v_head_idx = k_head_idx;
+ let k_head_offset = k_batch_offset + k_head_idx * params.stride_k2;
+ let v_head_offset = v_batch_offset + v_head_idx * params.stride_v2;
+
+ // starting Q row for this workgroup
+ let wg_in_head = wg_in_batch % wg_per_head;
+ let q_row_start = wg_in_head * Q_TILE;
+
+#ifdef MASK
+ // mask offset
+ let mask_global_offset = params.offset_mask + batch_idx * params.stride_mask3 + q_row_start * params.seq_len_kv;
+#endif
+
+ let head = f32(head_idx);
+ let has_bias = params.max_bias > 0.0;
+ let slope = select(1.0, select(pow(params.m1, 2.0 * (head - params.n_head_log2) + 1.0), pow(params.m0, head + 1.0), head < params.n_head_log2), has_bias);
+
+ // load q tile into shared memory
+ for (var elem_idx = local_id.x; elem_idx < Q_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) {
+ let q_row = elem_idx / HEAD_DIM_QK;
+ let q_col = elem_idx % HEAD_DIM_QK;
+ let head_q_row = q_row_start + q_row;
+ let global_q_row_offset = q_head_offset + head_q_row * params.stride_q1;
+ q_shmem[elem_idx] = f16(select(
+ 0.0,
+ Q[global_q_row_offset + q_col],
+ head_q_row < params.seq_len_q && q_col < HEAD_DIM_QK));
+ }
+
+ for (var kv_tile = iwg * KV_TILE; kv_tile < params.seq_len_kv; kv_tile += KV_TILE * params.nwg) {
+#ifdef BLK
+ let q_blk = q_row_start / Q_TILE;
+ let kv_blk = kv_tile / KV_TILE;
+ let blk_batch = select(0u, batch_idx, params.stride_mask3 > 0u);
+ let blk_idx = params.blk_base + (blk_batch * params.blk_nblk1 + q_blk) * params.blk_nblk0 + kv_blk;
+ let blk_state_local = blk[blk_idx];
+#else
+ let blk_state_local = 1u;
+#endif
+ if (local_id.x == 0u) {
+ blk_state_wg = blk_state_local;
+ }
+ workgroupBarrier();
+ let blk_state = blk_state_wg;
+ let skip_tile = blk_state == 0u;
+ for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) {
+ inter_shmem[elem_idx] = f16(0.0);
+ }
+
+ // load k tile into shared memory
+#if defined(KV_Q4_0)
+ for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) {
+ let blck_idx = elem_idx / BLOCK_SIZE;
+ let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
+ let k_row = blck_idx / BLOCKS_K;
+ let global_k_row = kv_tile + k_row;
+ let block_k = blck_idx % BLOCKS_K;
+ let row_offset = k_row * HEAD_DIM_QK;
+
+ if (global_k_row < params.seq_len_kv) {
+ let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k;
+ let base_idx = global_block_idx * F16_PER_BLOCK;
+ let d = K[base_idx];
+ for (var j = 0u; j < F16_PER_THREAD; j += 2) {
+ let q_0 = K[base_idx + 1u + block_offset + j];
+ let q_1 = K[base_idx + 1u + block_offset + j + 1];
+ let q_packed = bitcast<u32>(vec2(q_0, q_1));
+ for (var k = 0u; k < 4u; k++) {
+ let q_byte = get_byte(q_packed, k);
+ let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
+ let q_lo = (f16(q_byte & 0xF) - 8.0) * d;
+ let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
+ kv_shmem[row_offset + idx] = q_lo;
+ kv_shmem[row_offset + idx + 16u] = q_hi;
+ }
+ }
+ }
+ }
+#elif defined(KV_Q8_0)
+ for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) {
+ let blck_idx = elem_idx / BLOCK_SIZE;
+ let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
+ let k_row = blck_idx / BLOCKS_K;
+ let global_k_row = kv_tile + k_row;
+ let block_k = blck_idx % BLOCKS_K;
+ let row_offset = k_row * HEAD_DIM_QK;
+
+ if (global_k_row < params.seq_len_kv) {
+ let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k;
+ let base_idx = global_block_idx * F16_PER_BLOCK;
+ let d = K[base_idx];
+ for (var j = 0u; j < F16_PER_THREAD; j += 2) {
+ let q_0 = K[base_idx + 1u + block_offset + j];
+ let q_1 = K[base_idx + 1u + block_offset + j + 1];
+ let q_packed = bitcast<u32>(vec2(q_0, q_1));
+ for (var k = 0u; k < 4u; k++) {
+ let q_byte = get_byte_i32(q_packed, k);
+ let q_val = f16(q_byte) * d;
+ let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
+ kv_shmem[row_offset + idx] = q_val;
+ }
+ }
+ }
+ }
+#elif defined(KV_DIRECT)
+ // Direct global loads for KV
+#else
+ for (var elem_idx = local_id.x * 4u; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * 4u) {
+ let k_row = elem_idx / HEAD_DIM_QK;
+ let k_col = elem_idx % HEAD_DIM_QK;
+ let global_k_row = kv_tile + k_row;
+ let global_k_row_offset = k_head_offset + global_k_row * params.stride_k1;
+ let in_bounds = global_k_row < params.seq_len_kv && (k_col + 3u) < HEAD_DIM_QK;
+ let vec_idx = (global_k_row_offset + k_col) >> 2u;
+ let k4 = select(vec4<KV_TYPE>(0.0), K[vec_idx], in_bounds);
+ kv_shmem[elem_idx + 0u] = f16(k4.x);
+ kv_shmem[elem_idx + 1u] = f16(k4.y);
+ kv_shmem[elem_idx + 2u] = f16(k4.z);
+ kv_shmem[elem_idx + 3u] = f16(k4.w);
+ }
+#endif
+
+ workgroupBarrier();
+
+ // accumulate q block * k block into registers across the entire KV tile
+ if (!skip_tile) {
+ let num_of_threads = subgroup_size / VEC_NE;
+ let tx = sg_inv_id % num_of_threads;
+ let ty = sg_inv_id / num_of_threads;
+ for (var q_tile_row = subgroup_id; q_tile_row < Q_TILE; q_tile_row += num_subgroups) {
+ let global_q_row = q_row_start + q_tile_row;
+ if (global_q_row >= params.seq_len_q) {
+ continue;
+ }
+ let local_q_row_offset = q_tile_row * HEAD_DIM_QK;
+
+ for (var kv_base : u32 = 0u; kv_base < KV_TILE; kv_base += VEC_NE) {
+ let kv_idx = kv_base + ty;
+ var partial_sum: f32 = 0.0;
+ let kv_valid = kv_idx < KV_TILE && (kv_tile + kv_idx) < params.seq_len_kv;
+ if (kv_valid) {
+ for (var i = tx; i < (HEAD_DIM_QK / 4u); i += num_of_threads) {
+ let q_off = local_q_row_offset + i * 4u;
+
+ let qv = vec4<f32>(
+ f32(q_shmem[q_off + 0u]),
+ f32(q_shmem[q_off + 1u]),
+ f32(q_shmem[q_off + 2u]),
+ f32(q_shmem[q_off + 3u]));
+#ifdef KV_DIRECT
+ let idx = k_head_offset + (kv_tile + kv_idx) * params.stride_k1 + (i * 4u);
+ let kv = vec4<f32>(K[idx >> 2u]);
+#else
+ let idx = kv_idx * HEAD_DIM_QK + (i * 4u);
+ let kv = vec4<f32>(
+ f32(kv_shmem[idx + 0u]),
+ f32(kv_shmem[idx + 1u]),
+ f32(kv_shmem[idx + 2u]),
+ f32(kv_shmem[idx + 3u]));
+#endif
+ partial_sum += dot(qv, kv);
+ }
+ }
+ var sum = partial_sum;
+ // Reduce over tx threads (NL) for this ty stripe.
+ var tx_delta = num_of_threads >> 1u;
+ loop {
+ if (tx_delta == 0u) {
+ break;
+ }
+ let sh = subgroupShuffleDown(sum, tx_delta);
+ if (tx < tx_delta) {
+ sum += sh;
+ }
+ tx_delta >>= 1u;
+ }
+
+ let sum_bcast = subgroupShuffle(sum, num_of_threads * ty);
+ if (tx == 0u && kv_valid) {
+ let dst_idx = q_tile_row * KV_TILE + kv_idx;
+ inter_shmem[dst_idx] = f16(sum_bcast);
+ }
+ }
+ }
+ }
+
+
+#ifdef MASK
+ let apply_mask = !skip_tile && (blk_state != 2u);
+ if (apply_mask) {
+ // load mask tile into shared memory for this KV block
+ for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) {
+ let mask_row = elem_idx / KV_TILE;
+ let mask_col = elem_idx % KV_TILE;
+ let global_q_row = q_row_start + mask_row;
+ let global_k_col = kv_tile + mask_col;
+ let mask_in_bounds = global_q_row < params.seq_len_q && global_k_col < params.seq_len_kv;
+ let mask_idx = mask_global_offset + mask_row * params.seq_len_kv + global_k_col;
+ mask_shmem[elem_idx] = select(0.0, mask[mask_idx], mask_in_bounds);
+ }
+ }
+#else
+ let apply_mask = false;
+#endif
+
+ workgroupBarrier();
+
+ // online softmax
+ if (!skip_tile) {
+ for (var q_tile_row = subgroup_id; q_tile_row < Q_TILE; q_tile_row += num_subgroups) {
+ let global_q_row = q_row_start + q_tile_row;
+ if (global_q_row >= params.seq_len_q) {
+ break;
+ }
+
+ var prev_max = row_max_shmem[q_tile_row];
+ var final_max = prev_max;
+ // pass 1: compute final max across the full KV tile in chunks
+ for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) {
+ let kv_idx = kv_offset + sg_inv_id;
+ let kv_valid = kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE;
+ let softmax_term = select(FLOAT_MIN,
+ calc_softmax_term(kv_idx, q_tile_row, slope, has_bias, apply_mask),
+ kv_valid);
+ final_max = subgroupMax(max(final_max, softmax_term));
+ }
+
+ var total_exp_term: f32 = 0.0;
+ // pass 2: compute exp sum and write P using final_max
+ for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) {
+ let kv_idx = kv_offset + sg_inv_id;
+ let softmax_term = calc_softmax_term(kv_idx, q_tile_row, slope, has_bias, apply_mask);
+ let cur_p = select(0.0,
+ exp(softmax_term - final_max),
+ kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE);
+ total_exp_term += subgroupAdd(cur_p);
+ if (kv_idx < KV_TILE) {
+ inter_shmem[kv_idx + q_tile_row * KV_TILE] = f16(cur_p);
+ }
+ }
+
+ let cur_exp = exp(prev_max - final_max);
+
+ if (sg_inv_id == 0) {
+ row_max_shmem[q_tile_row] = final_max;
+ exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * cur_exp + total_exp_term;
+ }
+
+ for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) {
+ let idx = q_tile_row * HEAD_DIM_V + elem_idx;
+ o_shmem[idx] = f16(f32(o_shmem[idx]) * cur_exp);
+ }
+ }
+ }
+
+ // load v tile into shared memory
+#if defined(KV_Q4_0)
+ for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) {
+ let blck_idx = elem_idx / BLOCK_SIZE;
+ let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
+ let v_row = blck_idx / BLOCKS_V;
+ let global_v_row = kv_tile + v_row;
+ let block_k = blck_idx % BLOCKS_V;
+ let row_offset = v_row * HEAD_DIM_V;
+
+ if (global_v_row < params.seq_len_kv) {
+ let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k;
+ let base_idx = global_block_idx * F16_PER_BLOCK;
+ let d = V[base_idx];
+ for (var j = 0u; j < F16_PER_THREAD; j += 2) {
+ let q_0 = V[base_idx + 1u + block_offset + j];
+ let q_1 = V[base_idx + 1u + block_offset + j + 1];
+ let q_packed = bitcast<u32>(vec2(q_0, q_1));
+ for (var k = 0u; k < 4u; k++) {
+ let q_byte = get_byte(q_packed, k);
+ let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
+ let q_lo = (f16(q_byte & 0xF) - 8.0) * d;
+ let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
+ kv_shmem[row_offset + idx] = q_lo;
+ kv_shmem[row_offset + idx + 16u] = q_hi;
+ }
+ }
+ }
+ }
+#elif defined(KV_Q8_0)
+ for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) {
+ let blck_idx = elem_idx / BLOCK_SIZE;
+ let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
+ let v_row = blck_idx / BLOCKS_V;
+ let global_v_row = kv_tile + v_row;
+ let block_k = blck_idx % BLOCKS_V;
+ let row_offset = v_row * HEAD_DIM_V;
+
+ if (global_v_row < params.seq_len_kv) {
+ let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k;
+ let base_idx = global_block_idx * F16_PER_BLOCK;
+ let d = V[base_idx];
+ for (var j = 0u; j < F16_PER_THREAD; j += 2) {
+ let q_0 = V[base_idx + 1u + block_offset + j];
+ let q_1 = V[base_idx + 1u + block_offset + j + 1];
+ let q_packed = bitcast<u32>(vec2(q_0, q_1));
+ for (var k = 0u; k < 4u; k++) {
+ let q_byte = get_byte_i32(q_packed, k);
+ let q_val = f16(q_byte) * d;
+ let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
+ kv_shmem[row_offset + idx] = q_val;
+ }
+ }
+ }
+ }
+#elif defined(KV_DIRECT)
+ // Direct global loads for KV
+#else
+ for (var elem_idx = local_id.x * 4u; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * 4u) {
+ let v_row = elem_idx / HEAD_DIM_V;
+ let v_col = elem_idx % HEAD_DIM_V;
+ let global_v_row = kv_tile + v_row;
+ let global_v_row_offset = v_head_offset + global_v_row * params.stride_v1;
+ let in_bounds = global_v_row < params.seq_len_kv && (v_col + 3u) < HEAD_DIM_V;
+ let vec_idx = (global_v_row_offset + v_col) >> 2u;
+ let v4 = select(vec4<KV_TYPE>(0.0), V[vec_idx], in_bounds);
+ kv_shmem[elem_idx + 0u] = f16(v4.x);
+ kv_shmem[elem_idx + 1u] = f16(v4.y);
+ kv_shmem[elem_idx + 2u] = f16(v4.z);
+ kv_shmem[elem_idx + 3u] = f16(v4.w);
+ }
+#endif
+
+ workgroupBarrier();
+
+ if (!skip_tile) {
+ // we have P (Q_TILE x KV_TILE) in inter_shmem and V (KV_TILE x head_dim_v) in kv_shmem
+ // we want to compute O += P * V across the full KV tile
+ let ne_threads : u32 = VEC_NE;
+ let nl_threads = max(1u, subgroup_size / ne_threads);
+ let tx_pv = sg_inv_id % nl_threads;
+ let ty_pv = sg_inv_id / nl_threads;
+ for (var q_tile_row = subgroup_id;
+ q_tile_row < Q_TILE;
+ q_tile_row += num_subgroups) {
+ for (var vec_col = tx_pv; vec_col < (HEAD_DIM_V / 4u); vec_col += nl_threads) {
+ var lo = vec4<f32>(0.0, 0.0, 0.0, 0.0);
+ for (var cc = 0u; cc < KV_TILE / ne_threads; cc += 1u) {
+ let kv_idx = cc * ne_threads + ty_pv;
+ let v_row = kv_tile + kv_idx;
+ if (v_row >= params.seq_len_kv) {
+ continue;
+ }
+
+ let p = f32(inter_shmem[kv_idx + q_tile_row * KV_TILE]);
+#ifdef KV_DIRECT
+ let v_idx = v_head_offset + v_row * params.stride_v1 + vec_col * 4u;
+ let v4 = vec4<f32>(V[v_idx >> 2u]);
+#else
+ let v_idx = kv_idx * HEAD_DIM_V + vec_col * 4u;
+ let v4 = vec4<f32>(
+ f32(kv_shmem[v_idx + 0u]),
+ f32(kv_shmem[v_idx + 1u]),
+ f32(kv_shmem[v_idx + 2u]),
+ f32(kv_shmem[v_idx + 3u]));
+#endif
+ lo += p * v4;
+ }
+
+ var lo_x = lo.x;
+ var lo_y = lo.y;
+ var lo_z = lo.z;
+ var lo_w = lo.w;
+ // Reduce over ty threads (NE) for this tx thread.
+ var ty_delta = ne_threads >> 1u;
+ loop {
+ if (ty_delta == 0u) {
+ break;
+ }
+ let thread_delta = ty_delta * nl_threads;
+ let shx = subgroupShuffleDown(lo_x, thread_delta);
+ let shy = subgroupShuffleDown(lo_y, thread_delta);
+ let shz = subgroupShuffleDown(lo_z, thread_delta);
+ let shw = subgroupShuffleDown(lo_w, thread_delta);
+ if (ty_pv < ty_delta) {
+ lo_x += shx;
+ lo_y += shy;
+ lo_z += shz;
+ lo_w += shw;
+ }
+ ty_delta >>= 1u;
+ }
+
+ if (ty_pv == 0u) {
+ let elem_base = vec_col * 4u;
+ let o_base_idx = q_tile_row * HEAD_DIM_V + elem_base;
+ o_shmem[o_base_idx + 0u] = f16(f32(o_shmem[o_base_idx + 0u]) + lo_x);
+ o_shmem[o_base_idx + 1u] = f16(f32(o_shmem[o_base_idx + 1u]) + lo_y);
+ o_shmem[o_base_idx + 2u] = f16(f32(o_shmem[o_base_idx + 2u]) + lo_z);
+ o_shmem[o_base_idx + 3u] = f16(f32(o_shmem[o_base_idx + 3u]) + lo_w);
+ }
+ }
+ }
+ }
+
+ workgroupBarrier();
+ }
+
+
+#ifdef SINKS
+ // Sinks are global terms and must be applied exactly once across split workgroups.
+ if (iwg == 0u) {
+ for (var q_tile_row = subgroup_id;
+ q_tile_row < Q_TILE;
+ q_tile_row += num_subgroups) {
+ let global_q_row = q_row_start + q_tile_row;
+ if (global_q_row >= params.seq_len_q) {
+ break;
+ }
+
+ var prev_max = row_max_shmem[q_tile_row];
+
+ // for non-sink threads, exp(FLOAT_MIN) effectively zeroes out their contribution to the sum
+ let sink_val = select(FLOAT_MIN, sinks[params.offset_sinks + head_idx], sg_inv_id == 0);
+ let new_max = subgroupMax(max(prev_max, sink_val));
+ let max_exp = exp(prev_max - new_max);
+ let sink_exp = exp(sink_val - new_max);
+
+ let sink_exp_sum = subgroupAdd(sink_exp);
+
+ if (sg_inv_id == 0) {
+ row_max_shmem[q_tile_row] = new_max;
+ exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * max_exp + sink_exp_sum;
+ }
+
+ for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) {
+ let idx = q_tile_row * HEAD_DIM_V + elem_idx;
+ o_shmem[idx] = f16(f32(o_shmem[idx]) * max_exp);
+ }
+ }
+ workgroupBarrier();
+ }
+#endif
+ let rows_per_batch = params.n_heads * params.seq_len_q;
+ for (var q_tile_row = subgroup_id;
+ q_tile_row < Q_TILE;
+ q_tile_row += num_subgroups) {
+
+ let global_q_row = q_row_start + q_tile_row;
+ if (global_q_row >= params.seq_len_q) { break; }
+
+ if (params.nwg == 1u) {
+ let exp_sum = exp_sum_shmem[q_tile_row];
+ let scale = select(0.0, 1.0 / exp_sum, exp_sum != 0.0);
+ let row_base: u32 =
+ params.offset_dst + batch_idx * dst3_stride + global_q_row * dst2_stride + head_idx * HEAD_DIM_V;
+
+ for (var elem_base = sg_inv_id * 4u; elem_base < HEAD_DIM_V; elem_base += subgroup_size * 4u) {
+ let i0 = q_tile_row * HEAD_DIM_V + (elem_base + 0u);
+ let i1 = q_tile_row * HEAD_DIM_V + (elem_base + 1u);
+ let i2 = q_tile_row * HEAD_DIM_V + (elem_base + 2u);
+ let i3 = q_tile_row * HEAD_DIM_V + (elem_base + 3u);
+
+ let v = vec4<f32>(
+ f32(o_shmem[i0]) * scale,
+ f32(o_shmem[i1]) * scale,
+ f32(o_shmem[i2]) * scale,
+ f32(o_shmem[i3]) * scale
+ );
+
+ let dst_vec_index: u32 = (row_base + elem_base) >> 2u;
+ dst[dst_vec_index] = v;
+ }
+ } else {
+ let rid = batch_idx * rows_per_batch + head_idx * params.seq_len_q + global_q_row;
+ let tmp_row_data_base = params.tmp_data_base + rid * (HEAD_DIM_V * params.nwg) + iwg * HEAD_DIM_V;
+ let tmp_row_stats_base = params.tmp_stats_base + rid * (2u * params.nwg) + 2u * iwg;
+
+ for (var elem_base = sg_inv_id * 4u;
+ elem_base < HEAD_DIM_V;
+ elem_base += subgroup_size * 4u) {
+
+ let i0 = q_tile_row * HEAD_DIM_V + (elem_base + 0u);
+ let i1 = q_tile_row * HEAD_DIM_V + (elem_base + 1u);
+ let i2 = q_tile_row * HEAD_DIM_V + (elem_base + 2u);
+ let i3 = q_tile_row * HEAD_DIM_V + (elem_base + 3u);
+
+ let tbase = tmp_row_data_base + elem_base;
+ tmp[tbase + 0u] = f32(o_shmem[i0]);
+ tmp[tbase + 1u] = f32(o_shmem[i1]);
+ tmp[tbase + 2u] = f32(o_shmem[i2]);
+ tmp[tbase + 3u] = f32(o_shmem[i3]);
+ }
+
+ if (sg_inv_id == 0u) {
+ tmp[tmp_row_stats_base + 0u] = exp_sum_shmem[q_tile_row];
+ tmp[tmp_row_stats_base + 1u] = row_max_shmem[q_tile_row];
+ }
+ }
+ }
+}