ggml-webgpu: add vectorized flash attention (#20709)
author    Zheyuan Chen <redacted>
          Thu, 2 Apr 2026 17:40:42 +0000 (10:40 -0700)
committer GitHub <redacted>
          Thu, 2 Apr 2026 17:40:42 +0000 (10:40 -0700)
* naive vectorized version

* add vectorized flash attention

* update vec version

* remove unused path and shader

* remove unused helper functions

* add comments

* remove pad path

* ggml-webgpu: fix flash-attn vec nwg=1 path and tighten vec specialization

* change back to vec4

* enable multi split

* enable vec path (see the sketch after this log) when:
- Q->ne[1] < 20
- Q->ne[0] % 32 == 0
- V->ne[0] % 4 == 0
- K->type == f16

* update flash_attn_vec_split.wgsl to reduce redundant workgroup barrier usage and use select

* enable vec path for q4 and q8

* flash-attn vec nwg=1 fast path (skip tmp/reduce staging)

* use packed f16 K loads in flash-attn vec split

* use packed f16 K loads in flash-attn vec split on host side

* tune flash-attn vec f16 VEC_NE by head dim

* cleanup

* cleanup

* keep host side clean

* cleanup host side

* change back to original host wait/submit behavior

* formatting

* reverted param-buffer pool refactor

* add helper functions

* ggml-webgpu: move flash-attn vec pipeline caching back into shader lib

* ggml-webgpu: remove duplicate functions

* ggml-webgpu: reserve flash-attn vec scratch in dst buffer allocation

* ggml-webgpu: revert unrelated change

* ggml-webgpu: revert deleted comment

* disable uniformity check

* remove unnecessary change

* Update ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl

* Update ggml/src/ggml-webgpu/ggml-webgpu.cpp

---------

Co-authored-by: Reese Levine <redacted>
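
For reference, the vec-path gate listed in the log can be written as a small host-side predicate. The following is a minimal sketch, assuming plain integers in place of the ggml tensor fields (q_ne0 = Q->ne[0], q_ne1 = Q->ne[1], v_ne0 = V->ne[0]); the committed check in ggml-webgpu.cpp additionally requires vec4-aligned buffer offsets when K/V are f16.

#include <cstdint>

// Hedged sketch of the vec-path enablement rule described in the log above.
enum class kv_type_t { f16, q4_0, q8_0, other };

bool use_flash_attn_vec(int64_t q_ne0, int64_t q_ne1, int64_t v_ne0,
                        kv_type_t k_type, kv_type_t v_type,
                        bool f16_vec4_aligned) {
    const bool kv_type_ok =
        k_type == kv_type_t::f16 || k_type == kv_type_t::q4_0 || k_type == kv_type_t::q8_0;
    return q_ne1 < 20                                     // few query rows (decode-like shapes)
        && q_ne0 % 32 == 0                                // head_dim_qk is a multiple of 32
        && v_ne0 % 4 == 0                                 // head_dim_v supports vec4 stores
        && kv_type_ok                                     // f16/q4_0/q8_0 K caches only
        && (k_type != kv_type_t::f16 || f16_vec4_aligned) // packed f16 loads need alignment
        && v_type == k_type;                              // K and V share a type
}
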
ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
ggml/src/ggml-webgpu/ggml-webgpu.cpp
ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl [new file with mode: 0644]
ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl [new file with mode: 0644]
ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl [new file with mode: 0644]

index a194ce84e2556f217db5cd1aa483c79ee3e8fca0..1c56c689312f4d643ecbbc13d2a73e2721484163 100644 (file)
@@ -95,6 +95,12 @@ struct ggml_webgpu_generic_shader_decisions {
     uint32_t wg_size = 0;
 };
 
+struct ggml_webgpu_processed_shader {
+    std::string           wgsl;
+    std::string           variant;
+    std::shared_ptr<void> decisions;
+};
+
 struct ggml_webgpu_ssm_conv_shader_decisions {
     uint32_t block_size;
     uint32_t tokens_per_wg;
@@ -384,11 +390,12 @@ struct ggml_webgpu_flash_attn_pipeline_key {
     bool      has_mask;
     bool      has_sinks;
     bool      uses_logit_softcap;
+    bool      use_vec;
 
     bool operator==(const ggml_webgpu_flash_attn_pipeline_key & other) const {
         return kv_type == other.kv_type && head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v &&
                kv_direct == other.kv_direct && has_mask == other.has_mask && has_sinks == other.has_sinks &&
-               uses_logit_softcap == other.uses_logit_softcap;
+               uses_logit_softcap == other.uses_logit_softcap && use_vec == other.use_vec;
     }
 };
 
@@ -402,6 +409,7 @@ struct ggml_webgpu_flash_attn_pipeline_key_hash {
         ggml_webgpu_hash_combine(seed, key.has_mask);
         ggml_webgpu_hash_combine(seed, key.has_sinks);
         ggml_webgpu_hash_combine(seed, key.uses_logit_softcap);
+        ggml_webgpu_hash_combine(seed, key.use_vec);
         return seed;
     }
 };
@@ -421,6 +429,115 @@ struct ggml_webgpu_flash_attn_shader_decisions {
     uint32_t wg_size = 0;
 };
 
+inline uint32_t ggml_webgpu_flash_attn_pick_vec_ne(const ggml_webgpu_flash_attn_pipeline_key & key) {
+    // Keep conservative defaults unless this is the f16 vec-split shape family.
+    if (key.kv_type != GGML_TYPE_F16 || key.head_dim_qk != key.head_dim_v) {
+        return 1u;
+    }
+
+    // Head-dim specializations used by the tuned vec f16 path.
+    switch (key.head_dim_qk) {
+        case 64: return 2u;
+        case 96: return 4u;
+        case 128: return 1u;
+        case 192: return 2u;
+        case 576: return 2u;
+        default: return 1u;
+    }
+}
+
+struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key {
+    uint32_t head_dim_v;
+    uint32_t wg_size;
+};
+
+struct ggml_webgpu_flash_attn_vec_reduce_pipeline_key_hash {
+    size_t operator()(const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & key) const {
+        size_t seed = 0;
+        ggml_webgpu_hash_combine(seed, key.head_dim_v);
+        ggml_webgpu_hash_combine(seed, key.wg_size);
+        return seed;
+    }
+};
+
+inline bool operator==(const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & lhs,
+                       const ggml_webgpu_flash_attn_vec_reduce_pipeline_key & rhs) {
+    return lhs.head_dim_v == rhs.head_dim_v && lhs.wg_size == rhs.wg_size;
+}
+
+struct ggml_webgpu_flash_attn_vec_reduce_shader_lib_context {
+    ggml_webgpu_flash_attn_vec_reduce_pipeline_key key;
+    uint32_t                                       max_wg_size;
+};
+
+inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_vec_reduce_shader(
+    pre_wgsl::Preprocessor &                                     preprocessor,
+    const char *                                                 shader_src,
+    const ggml_webgpu_flash_attn_vec_reduce_shader_lib_context & context) {
+    std::vector<std::string> defines;
+    std::string              variant = "flash_attn_vec_reduce";
+
+    defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.key.head_dim_v));
+    variant += std::string("_hsv") + std::to_string(context.key.head_dim_v);
+
+    defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
+    variant += std::string("_wg") + std::to_string(context.max_wg_size);
+
+    ggml_webgpu_processed_shader result;
+    result.wgsl    = preprocessor.preprocess(shader_src, defines);
+    result.variant = variant;
+    return result;
+}
+
+struct ggml_webgpu_flash_attn_blk_pipeline_key {
+    uint32_t q_tile;
+    uint32_t kv_tile;
+
+    bool operator==(const ggml_webgpu_flash_attn_blk_pipeline_key & other) const {
+        return q_tile == other.q_tile && kv_tile == other.kv_tile;
+    }
+};
+
+struct ggml_webgpu_flash_attn_blk_pipeline_key_hash {
+    size_t operator()(const ggml_webgpu_flash_attn_blk_pipeline_key & key) const {
+        size_t seed = 0;
+        ggml_webgpu_hash_combine(seed, key.q_tile);
+        ggml_webgpu_hash_combine(seed, key.kv_tile);
+        return seed;
+    }
+};
+
+struct ggml_webgpu_flash_attn_blk_shader_lib_context {
+    ggml_webgpu_flash_attn_blk_pipeline_key key;
+    uint32_t                                max_wg_size;
+};
+
+inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_blk_shader(
+    pre_wgsl::Preprocessor &                                    preprocessor,
+    const char *                                                shader_src,
+    const ggml_webgpu_flash_attn_blk_shader_lib_context &       context) {
+    std::vector<std::string> defines;
+    std::string              variant = "flash_attn_vec_blk";
+
+    defines.push_back(std::string("Q_TILE=") + std::to_string(context.key.q_tile));
+    variant += std::string("_qt") + std::to_string(context.key.q_tile);
+
+    defines.push_back(std::string("KV_TILE=") + std::to_string(context.key.kv_tile));
+    variant += std::string("_kvt") + std::to_string(context.key.kv_tile);
+
+    uint32_t wg_size = 1;
+    while ((wg_size << 1) <= context.max_wg_size) {
+        wg_size <<= 1;
+    }
+    defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
+    variant += std::string("_wg") + std::to_string(wg_size);
+
+    ggml_webgpu_processed_shader result;
+    result.wgsl    = preprocessor.preprocess(shader_src, defines);
+    result.variant = variant;
+    return result;
+}
+
 // This is exposed because it's necessary in supports_op
 inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile,
                                                   uint32_t kv_tile,
@@ -659,6 +776,14 @@ class ggml_webgpu_shader_lib {
         repeat_pipelines;           // type
     std::unordered_map<ggml_webgpu_flash_attn_pipeline_key, webgpu_pipeline, ggml_webgpu_flash_attn_pipeline_key_hash>
         flash_attn_pipelines;
+    std::unordered_map<ggml_webgpu_flash_attn_vec_reduce_pipeline_key,
+                       webgpu_pipeline,
+                       ggml_webgpu_flash_attn_vec_reduce_pipeline_key_hash>
+        flash_attn_vec_reduce_pipelines;
+    std::unordered_map<ggml_webgpu_flash_attn_blk_pipeline_key,
+                       webgpu_pipeline,
+                       ggml_webgpu_flash_attn_blk_pipeline_key_hash>
+        flash_attn_blk_pipelines;
     std::unordered_map<ggml_webgpu_legacy_mul_mat_pipeline_key,
                        webgpu_pipeline,
                        ggml_webgpu_legacy_mul_mat_pipeline_key_hash>
@@ -1673,24 +1798,8 @@ class ggml_webgpu_shader_lib {
         return repeat_pipelines[key];
     }
 
-    webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context) {
-        const bool has_mask  = context.src3 != nullptr;
-        const bool has_sinks = context.src4 != nullptr;
-
-        bool kv_direct = (context.src1->type == GGML_TYPE_F16) && (context.src0->ne[0] % context.sg_mat_k == 0) &&
-                         (context.src1->ne[1] % context.sg_mat_n == 0);
-
-        ggml_webgpu_flash_attn_pipeline_key key = {
-            .kv_type            = context.src1->type,
-            .head_dim_qk        = (uint32_t) context.src0->ne[0],
-            .head_dim_v         = (uint32_t) context.src2->ne[0],
-            .kv_direct          = kv_direct,
-            .has_mask           = has_mask,
-            .has_sinks          = has_sinks,
-            .uses_logit_softcap = (*(float *) &context.dst->op_params[2]) != 0.0f,
-        };
-
-        auto it = flash_attn_pipelines.find(key);
+    webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_flash_attn_shader_lib_context & context) {
+        auto it = flash_attn_pipelines.find(context.key);
         if (it != flash_attn_pipelines.end()) {
             return it->second;
         }
@@ -1698,7 +1807,7 @@ class ggml_webgpu_shader_lib {
         std::vector<std::string> defines;
         std::string              variant = "flash_attn";
 
-        switch (key.kv_type) {
+        switch (context.key.kv_type) {
             case GGML_TYPE_F32:
                 defines.push_back("KV_F32");
                 break;
@@ -1714,41 +1823,52 @@ class ggml_webgpu_shader_lib {
             default:
                 GGML_ABORT("Unsupported KV type for flash attention shader");
         }
-        variant += std::string("_") + ggml_type_name(key.kv_type);
+        variant += std::string("_") + ggml_type_name(context.key.kv_type);
 
-        if (key.has_mask) {
+        if (context.key.has_mask) {
             defines.push_back("MASK");
             variant += "_mask";
         }
-        if (key.has_sinks) {
+        if (context.key.has_sinks) {
             defines.push_back("SINKS");
             variant += "_sinks";
         }
-        if (key.uses_logit_softcap) {
+        if (context.key.uses_logit_softcap) {
             defines.push_back("LOGIT_SOFTCAP");
             variant += "_lgsc";
         }
-        if (key.kv_direct) {
+        if (context.key.kv_direct) {
             defines.push_back("KV_DIRECT");
             variant += "_kvdirect";
         }
+        if (context.key.has_mask && context.key.use_vec) {
+            defines.push_back("BLK");
+            variant += "_blk";
+        }
 
-        defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(key.head_dim_qk));
-        variant += std::string("_hsqk") + std::to_string(key.head_dim_qk);
+        defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(context.key.head_dim_qk));
+        variant += std::string("_hsqk") + std::to_string(context.key.head_dim_qk);
 
-        defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v));
-        variant += std::string("_hsv") + std::to_string(key.head_dim_v);
+        defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.key.head_dim_v));
+        variant += std::string("_hsv") + std::to_string(context.key.head_dim_v);
 
         defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m));
         defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n));
         defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k));
 
-        uint32_t q_tile = context.sg_mat_m;
+        uint32_t q_tile  = context.sg_mat_m;
         uint32_t kv_tile =
-            std::min(ggml_webgpu_flash_attn_max_kv_tile({ key, context.sg_mat_m, context.sg_mat_n, context.sg_mat_k,
-                                                          context.wg_mem_limit_bytes, context.max_subgroup_size }),
+            std::min(ggml_webgpu_flash_attn_max_kv_tile(context),
                      context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);
-        if (key.kv_direct) {
+        if (context.key.use_vec) {
+            q_tile  = 1;
+            kv_tile = std::max(context.sg_mat_n, std::min(32u, ggml_webgpu_flash_attn_max_kv_tile(context)));
+            kv_tile = (kv_tile / context.sg_mat_n) * context.sg_mat_n;
+            const uint32_t vec_ne = ggml_webgpu_flash_attn_pick_vec_ne(context.key);
+            defines.push_back(std::string("VEC_NE=") + std::to_string(vec_ne) + "u");
+        }
+        if (context.key.kv_direct) {
+            GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD);
             while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) {
                 kv_tile -= context.sg_mat_n;
             }
@@ -1757,19 +1877,51 @@ class ggml_webgpu_shader_lib {
         defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile));
         defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile));
 
-        uint32_t wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE);
+        uint32_t wg_size = 0;
+        if (context.key.use_vec) {
+            wg_size = std::max(1u, std::min<uint32_t>(32u, context.max_subgroup_size));
+        } else {
+            wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE);
+        }
         defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
 
-        auto processed     = preprocessor.preprocess(wgsl_flash_attn, defines);
+        const char * shader_src = context.key.use_vec ? wgsl_flash_attn_vec_split : wgsl_flash_attn;
+        webgpu_pipeline pipeline =
+            ggml_webgpu_create_pipeline(device, preprocessor.preprocess(shader_src, defines), variant);
         auto decisions     = std::make_shared<ggml_webgpu_flash_attn_shader_decisions>();
         decisions->q_tile  = q_tile;
         decisions->kv_tile = kv_tile;
         decisions->wg_size = wg_size;
+        pipeline.context   = decisions;
+        flash_attn_pipelines[context.key] = pipeline;
+        return flash_attn_pipelines[context.key];
+    }
+
+    webgpu_pipeline get_flash_attn_blk_pipeline(const ggml_webgpu_flash_attn_blk_shader_lib_context & context) {
+        auto it = flash_attn_blk_pipelines.find(context.key);
+        if (it != flash_attn_blk_pipelines.end()) {
+            return it->second;
+        }
+
+        ggml_webgpu_processed_shader processed =
+            ggml_webgpu_preprocess_flash_attn_blk_shader(preprocessor, wgsl_flash_attn_vec_blk, context);
+        webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed.wgsl, processed.variant);
+        flash_attn_blk_pipelines[context.key] = pipeline;
+        return flash_attn_blk_pipelines[context.key];
+    }
+
+    webgpu_pipeline get_flash_attn_vec_reduce_pipeline(
+        const ggml_webgpu_flash_attn_vec_reduce_shader_lib_context & context) {
+        auto it = flash_attn_vec_reduce_pipelines.find(context.key);
+        if (it != flash_attn_vec_reduce_pipelines.end()) {
+            return it->second;
+        }
 
-        webgpu_pipeline pipeline  = ggml_webgpu_create_pipeline(device, processed, variant);
-        pipeline.context          = decisions;
-        flash_attn_pipelines[key] = pipeline;
-        return flash_attn_pipelines[key];
+        ggml_webgpu_processed_shader processed =
+            ggml_webgpu_preprocess_flash_attn_vec_reduce_shader(preprocessor, wgsl_flash_attn_vec_reduce, context);
+        webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed.wgsl, processed.variant);
+        flash_attn_vec_reduce_pipelines[context.key] = pipeline;
+        return flash_attn_vec_reduce_pipelines[context.key];
     }
 
     webgpu_pipeline get_cpy_pipeline(const ggml_webgpu_shader_lib_context & context) {
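
To make the vec specialization above concrete, here is a standalone sketch of the tile and workgroup-size choices made in the use_vec branch of get_flash_attn_pipeline. The device parameters and the max_kv_tile value are assumptions for illustration; the real kv_tile bound comes from ggml_webgpu_flash_attn_max_kv_tile() and the workgroup-storage limit.

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t sg_mat_n          = 8;   // assumed subgroup-matrix N
    const uint32_t max_subgroup_size = 32;  // assumed device subgroup size
    const uint32_t max_kv_tile       = 64;  // stand-in for ggml_webgpu_flash_attn_max_kv_tile()

    const uint32_t q_tile = 1;  // the vec path processes one query row per workgroup
    uint32_t kv_tile = std::max(sg_mat_n, std::min(32u, max_kv_tile));
    kv_tile = (kv_tile / sg_mat_n) * sg_mat_n;  // keep kv_tile a multiple of sg_mat_n
    const uint32_t wg_size = std::max(1u, std::min(32u, max_subgroup_size));

    printf("q_tile=%u kv_tile=%u wg_size=%u\n", q_tile, kv_tile, wg_size);  // 1 32 32
    return 0;
}
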
index 1aa15b0507cc41b2beb722c5360dbf22a23512f8..e53281bfbbd4df39348dcf1769624e8b419ea3c3 100644 (file)
@@ -658,7 +658,6 @@ static webgpu_command ggml_backend_webgpu_build_multi(
     for (size_t i = 0; i < params_bufs_list.size(); i++) {
         ctx->queue.WriteBuffer(params_bufs_list[i], 0, params_list[i].data(), params_list[i].size() * sizeof(uint32_t));
     }
-
 #ifdef GGML_WEBGPU_GPU_PROFILE
     webgpu_gpu_profile_bufs ts_bufs = ctx->timestamp_query_buf_pool.alloc_bufs();
     if (ts_bufs.host_buf.GetMapState() == wgpu::BufferMapState::Mapped) {
@@ -1481,7 +1480,6 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
     return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y);
 }
 
-#ifndef __EMSCRIPTEN__
 static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx,
                                              ggml_tensor *    Q,
                                              ggml_tensor *    K,
@@ -1565,30 +1563,248 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx,
                         .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
                         .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });
 
-    ggml_webgpu_shader_lib_context shader_lib_ctx = {
-        .src0               = Q,
-        .src1               = K,
-        .src2               = V,
-        .src3               = mask,
-        .src4               = sinks,
-        .dst                = dst,
-        .max_wg_size        = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
-        .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize,
+    const uint32_t k_offset_elems   = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type));
+    const uint32_t v_offset_elems   = (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type));
+    const bool     f16_vec4_aligned = (k_offset_elems % 4u == 0u) && (v_offset_elems % 4u == 0u);
+
+    const bool kv_direct = (K->type == GGML_TYPE_F16) && f16_vec4_aligned &&
+                           (Q->ne[0] % ctx->global_ctx->capabilities.sg_mat_k == 0) &&
+                           (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0);
+
+    const bool kv_vec_type_supported =
+        K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0;
+    const bool use_vec = (Q->ne[1] < 20) && (Q->ne[0] % 32 == 0) && (V->ne[0] % 4 == 0) && kv_vec_type_supported &&
+                         (K->type != GGML_TYPE_F16 || f16_vec4_aligned) && (V->type == K->type);
+    const uint32_t vec_nwg_cap =
+        std::max(1u, std::min<uint32_t>(32u, ctx->global_ctx->capabilities.max_subgroup_size));
+    const bool use_blk = use_vec && has_mask;
+
+    ggml_webgpu_flash_attn_pipeline_key key = {
+        .kv_type            = K->type,
+        .head_dim_qk        = (uint32_t) Q->ne[0],
+        .head_dim_v         = (uint32_t) V->ne[0],
+        .kv_direct          = kv_direct,
+        .has_mask           = static_cast<bool>(has_mask),
+        .has_sinks          = static_cast<bool>(has_sinks),
+        .uses_logit_softcap = logit_softcap != 0.0f,
+        .use_vec            = use_vec,
+    };
+
+    ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = {
+        .key                = key,
         .sg_mat_m           = ctx->global_ctx->capabilities.sg_mat_m,
         .sg_mat_n           = ctx->global_ctx->capabilities.sg_mat_n,
         .sg_mat_k           = ctx->global_ctx->capabilities.sg_mat_k,
+        .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize,
         .max_subgroup_size  = ctx->global_ctx->capabilities.max_subgroup_size,
     };
-
     webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline(shader_lib_ctx);
 
     auto * decisions = static_cast<ggml_webgpu_flash_attn_shader_decisions *>(pipeline.context.get());
 
     uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions->q_tile);
     uint32_t wg_x        = wg_per_head * Q->ne[2] * Q->ne[3];  // wg per head * number of heads * number of batches
+
+    wgpu::Buffer blk_buf         = {};
+    uint64_t     blk_size_bytes  = 0;
+    uint32_t     blk_nblk0       = 0;
+    uint32_t     blk_nblk1       = 0;
+    uint32_t     blk_batch_count = 0;
+
+    if (use_vec) {
+        uint32_t       nwg     = 1u;
+        const uint64_t kv_span = (uint64_t) std::max(1u, decisions->kv_tile);
+        while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) {
+            nwg <<= 1;
+        }
+        nwg = std::min(nwg, vec_nwg_cap);
+        GGML_ASSERT(nwg <= ctx->global_ctx->capabilities.max_subgroup_size);
+        const uint64_t nrows          = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3];
+        const bool     use_vec_reduce = nwg > 1u;
+        GGML_ASSERT(nrows <= UINT32_MAX);
+
+        uint64_t     tmp_stats_base  = 0;
+        uint64_t     tmp_size_bytes  = 0;
+        wgpu::Buffer tmp_buf         = {};
+        uint64_t     tmp_bind_offset = 0;
+        uint64_t     tmp_bind_size   = 0;
+        const size_t align_bytes     = ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment;
+        const size_t dst_offset      = ggml_webgpu_tensor_offset(dst);
+        size_t       scratch_offset  = ROUNDUP_POW2(dst_offset + ggml_nbytes(dst), align_bytes);
+
+        if (use_vec_reduce) {
+            const uint64_t tmp_data_elems  = nrows * (uint64_t) V->ne[0] * nwg;
+            const uint64_t tmp_stats_elems = nrows * 2u * nwg;
+            tmp_stats_base                 = tmp_data_elems;
+            tmp_size_bytes =
+                ROUNDUP_POW2((tmp_data_elems + tmp_stats_elems) * sizeof(float), WEBGPU_STORAGE_BUF_BINDING_MULT);
+            GGML_ASSERT(tmp_stats_base <= UINT32_MAX);
+            tmp_buf         = ggml_webgpu_tensor_buf(dst);
+            tmp_bind_offset = scratch_offset;
+            tmp_bind_size   = tmp_size_bytes;
+            scratch_offset  = ROUNDUP_POW2(scratch_offset + tmp_size_bytes, align_bytes);
+        } else {
+            // nwg==1 writes final dst directly in vec-split; keep tmp binding valid without extra allocation.
+            tmp_buf         = ggml_webgpu_tensor_buf(dst);
+            tmp_bind_offset = ggml_webgpu_tensor_align_offset(ctx, dst);
+            tmp_bind_size   = ggml_webgpu_tensor_binding_size(ctx, dst);
+        }
+
+        webgpu_pipeline                   blk_pipeline;
+        std::vector<uint32_t>             blk_params;
+        std::vector<wgpu::BindGroupEntry> blk_entries;
+        if (use_blk) {
+            GGML_ASSERT(has_mask);
+
+            blk_nblk0       = CEIL_DIV((uint32_t) K->ne[1], decisions->kv_tile);
+            blk_nblk1       = CEIL_DIV((uint32_t) Q->ne[1], decisions->q_tile);
+            blk_buf         = ggml_webgpu_tensor_buf(dst);
+            const uint32_t stride_mask3 = (uint32_t) (mask->nb[3] / ggml_type_size(mask->type));
+            blk_batch_count             = stride_mask3 > 0 ? (uint32_t) Q->ne[3] : 1u;
+            const uint64_t blk_elems    = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count;
+            blk_size_bytes              = ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT);
+            ggml_webgpu_flash_attn_blk_shader_lib_context blk_shader_ctx = {
+                .key =
+                    {
+                        .q_tile  = decisions->q_tile,
+                        .kv_tile = decisions->kv_tile,
+                    },
+                .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup,
+            };
+            blk_pipeline = ctx->shader_lib->get_flash_attn_blk_pipeline(blk_shader_ctx);
+
+            blk_params = {
+                (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)),  // offset_mask
+                (uint32_t) Q->ne[1],                                                                   // seq_len_q
+                (uint32_t) K->ne[1],                                                                   // seq_len_kv
+                stride_mask3,                                                                          // stride_mask3
+                blk_nblk0,                                                                             // nblk0
+                blk_nblk1,                                                                             // nblk1
+            };
+            blk_entries = {
+                { .binding = 0,
+                 .buffer  = ggml_webgpu_tensor_buf(mask),
+                 .offset  = ggml_webgpu_tensor_align_offset(ctx, mask),
+                 .size    = ggml_webgpu_tensor_binding_size(ctx, mask) },
+                { .binding = 1, .buffer = blk_buf, .offset = scratch_offset, .size = blk_size_bytes },
+            };
+            scratch_offset = ROUNDUP_POW2(scratch_offset + blk_size_bytes, align_bytes);
+        }
+
+        std::vector<uint32_t> split_params = params;
+        if (use_blk) {
+            split_params.push_back(0u);                     // blk_base
+            split_params.push_back(blk_nblk0);              // blk_nblk0
+            split_params.push_back(blk_nblk1);              // blk_nblk1
+        }
+        split_params.push_back(0u);                         // tmp_data_base
+        split_params.push_back((uint32_t) tmp_stats_base);  // tmp_stats_base
+        split_params.push_back(nwg);                        // nwg
+
+        std::vector<wgpu::BindGroupEntry> split_entries = {
+            { .binding = 0,
+             .buffer  = ggml_webgpu_tensor_buf(Q),
+             .offset  = ggml_webgpu_tensor_align_offset(ctx, Q),
+             .size    = ggml_webgpu_tensor_binding_size(ctx, Q) },
+            { .binding = 1,
+             .buffer  = ggml_webgpu_tensor_buf(K),
+             .offset  = ggml_webgpu_tensor_align_offset(ctx, K),
+             .size    = ggml_webgpu_tensor_binding_size(ctx, K) },
+            { .binding = 2,
+             .buffer  = ggml_webgpu_tensor_buf(V),
+             .offset  = ggml_webgpu_tensor_align_offset(ctx, V),
+             .size    = ggml_webgpu_tensor_binding_size(ctx, V) },
+        };
+        uint32_t split_binding_index = 3;
+        if (has_mask) {
+            split_entries.push_back({ .binding = split_binding_index++,
+                                      .buffer  = ggml_webgpu_tensor_buf(mask),
+                                      .offset  = ggml_webgpu_tensor_align_offset(ctx, mask),
+                                      .size    = ggml_webgpu_tensor_binding_size(ctx, mask) });
+        }
+        if (has_sinks) {
+            split_entries.push_back({ .binding = split_binding_index++,
+                                      .buffer  = ggml_webgpu_tensor_buf(sinks),
+                                      .offset  = ggml_webgpu_tensor_align_offset(ctx, sinks),
+                                      .size    = ggml_webgpu_tensor_binding_size(ctx, sinks) });
+        }
+        if (use_blk) {
+            split_entries.push_back(
+                { .binding = split_binding_index++, .buffer = blk_buf, .offset = blk_entries[1].offset, .size = blk_size_bytes });
+        }
+        split_entries.push_back(
+            { .binding = split_binding_index++, .buffer = tmp_buf, .offset = tmp_bind_offset, .size = tmp_bind_size });
+        split_entries.push_back({ .binding = split_binding_index++,
+                                  .buffer  = ggml_webgpu_tensor_buf(dst),
+                                  .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
+                                  .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });
+
+        webgpu_pipeline                   reduce_pipeline;
+        std::vector<uint32_t>             reduce_params;
+        std::vector<wgpu::BindGroupEntry> reduce_entries;
+        if (use_vec_reduce) {
+            const uint32_t reduce_wg_size = std::max(
+                32u,
+                std::min<uint32_t>(nwg * 32u, ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup));
+            ggml_webgpu_flash_attn_vec_reduce_shader_lib_context reduce_shader_ctx = {
+                .key =
+                    {
+                        .head_dim_v = (uint32_t) V->ne[0],
+                        .wg_size    = reduce_wg_size,
+                    },
+                .max_wg_size = reduce_wg_size,
+            };
+            reduce_pipeline = ctx->shader_lib->get_flash_attn_vec_reduce_pipeline(reduce_shader_ctx);
+
+            reduce_params = {
+                (uint32_t) nrows,                                                                    // nrows
+                (uint32_t) Q->ne[1],                                                                 // seq_len_q
+                (uint32_t) Q->ne[2],                                                                 // n_heads
+                (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),  // offset_dst
+                nwg,                                                                                 // nwg
+                0u,                                                                                  // tmp_data_base
+                (uint32_t) tmp_stats_base,                                                           // tmp_stats_base
+            };
+
+            reduce_entries = {
+                { .binding = 0, .buffer = tmp_buf, .offset = tmp_bind_offset, .size = tmp_size_bytes },
+                { .binding = 1,
+                 .buffer  = ggml_webgpu_tensor_buf(dst),
+                 .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
+                 .size    = ggml_webgpu_tensor_binding_size(ctx, dst) },
+            };
+        }
+
+        const uint64_t split_wg_total = (uint64_t) wg_x * nwg;
+        GGML_ASSERT(split_wg_total <= UINT32_MAX);
+        std::vector<webgpu_pipeline>                   pipelines;
+        std::vector<std::vector<uint32_t>>             params_list;
+        std::vector<std::vector<wgpu::BindGroupEntry>> entries_list;
+        std::vector<std::pair<uint32_t, uint32_t>>     workgroups_list;
+
+        if (use_blk) {
+            pipelines.push_back(blk_pipeline);
+            params_list.push_back(std::move(blk_params));
+            entries_list.push_back(std::move(blk_entries));
+            workgroups_list.push_back({ blk_nblk0, blk_nblk1 * blk_batch_count });
+        }
+        pipelines.push_back(pipeline);
+        params_list.push_back(std::move(split_params));
+        entries_list.push_back(std::move(split_entries));
+        workgroups_list.push_back({ (uint32_t) split_wg_total, 1u });
+        if (use_vec_reduce) {
+            pipelines.push_back(reduce_pipeline);
+            params_list.push_back(std::move(reduce_params));
+            entries_list.push_back(std::move(reduce_entries));
+            workgroups_list.push_back({ (uint32_t) nrows, 1u });
+        }
+
+        return ggml_backend_webgpu_build_multi(ctx->global_ctx, ctx->param_buf_pool, pipelines, params_list,
+                                               entries_list, workgroups_list);
+    }
+
     return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
 }
-#endif
 
 static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
     bool is_unary = dst->op == GGML_OP_UNARY;
@@ -2559,7 +2775,6 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
     std::vector<webgpu_submission> subs;
     uint32_t                       num_batched_kernels = 0;
     bool                           contains_set_rows   = false;
-
     for (int i = 0; i < cgraph->n_nodes; i++) {
         if (cgraph->nodes[i]->op == GGML_OP_SET_ROWS) {
             contains_set_rows = true;
@@ -2834,6 +3049,86 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer
                 }
             }
             break;
+        case GGML_OP_FLASH_ATTN_EXT:
+            {
+                const ggml_tensor * Q     = tensor->src[0];
+                const ggml_tensor * K     = tensor->src[1];
+                const ggml_tensor * V     = tensor->src[2];
+                const ggml_tensor * mask  = tensor->src[3];
+                const ggml_tensor * sinks = tensor->src[4];
+                if (Q && K && V) {
+                    GGML_UNUSED(sinks);
+                    const bool kv_direct = (K->type == GGML_TYPE_F16) &&
+                                           (Q->ne[0] % ctx->webgpu_global_ctx->capabilities.sg_mat_k == 0) &&
+                                           (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0);
+                    const bool kv_vec_type_supported =
+                        K->type == GGML_TYPE_F16 || K->type == GGML_TYPE_Q4_0 || K->type == GGML_TYPE_Q8_0;
+                    const bool use_vec =
+                        (Q->ne[1] < 20) && (Q->ne[0] % 32 == 0) && (V->ne[0] % 4 == 0) && kv_vec_type_supported &&
+                        (V->type == K->type);
+                    if (use_vec) {
+                        const uint32_t sg_mat_m = ctx->webgpu_global_ctx->capabilities.sg_mat_m;
+                        const uint32_t sg_mat_n = ctx->webgpu_global_ctx->capabilities.sg_mat_n;
+                        const size_t   limit_bytes =
+                            ctx->webgpu_global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize;
+                        const size_t q_tile = sg_mat_m;
+                        const size_t base_q_bytes =
+                            (Q->ne[0] + V->ne[0]) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES +
+                            2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES;
+                        size_t bytes_per_kv = 0;
+                        if (!kv_direct) {
+                            bytes_per_kv += std::max(Q->ne[0], V->ne[0]);
+                        }
+                        if (mask != nullptr) {
+                            bytes_per_kv += q_tile;
+                        }
+                        bytes_per_kv += q_tile;
+                        bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES;
+                        uint32_t kv_tile =
+                            ((limit_bytes - base_q_bytes) / bytes_per_kv / sg_mat_n) * sg_mat_n;
+                        kv_tile = std::max(sg_mat_n, std::min(32u, kv_tile));
+                        kv_tile = (kv_tile / sg_mat_n) * sg_mat_n;
+                        if (kv_direct) {
+                            GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD);
+                            while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) {
+                                kv_tile -= sg_mat_n;
+                            }
+                        }
+
+                        const uint32_t vec_nwg_cap = std::max(
+                            1u, std::min<uint32_t>(32u, ctx->webgpu_global_ctx->capabilities.max_subgroup_size));
+                        uint32_t       nwg         = 1u;
+                        const uint64_t kv_span     = (uint64_t) std::max(1u, kv_tile);
+                        while ((2u * nwg * kv_span) < (uint64_t) K->ne[1] && nwg < vec_nwg_cap) {
+                            nwg <<= 1;
+                        }
+                        nwg = std::min(nwg, vec_nwg_cap);
+
+                        const size_t align = ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment;
+                        const uint64_t nrows = (uint64_t) Q->ne[1] * Q->ne[2] * Q->ne[3];
+                        if (nwg > 1u) {
+                            const uint64_t tmp_data_elems  = nrows * (uint64_t) V->ne[0] * nwg;
+                            const uint64_t tmp_stats_elems = nrows * 2u * nwg;
+                            const size_t tmp_size_bytes = ROUNDUP_POW2(
+                                (tmp_data_elems + tmp_stats_elems) * sizeof(float), WEBGPU_STORAGE_BUF_BINDING_MULT);
+                            res += tmp_size_bytes + align;
+                        }
+                        if (mask != nullptr) {
+                            const uint32_t blk_nblk0 = CEIL_DIV((uint32_t) K->ne[1], kv_tile);
+                            const uint32_t blk_nblk1 = CEIL_DIV((uint32_t) Q->ne[1], 1u);  // q_tile == 1 in the vec path
+                            const uint32_t stride_mask3 =
+                                (uint32_t) (mask->nb[3] / ggml_type_size(mask->type));
+                            const uint32_t blk_batch_count = stride_mask3 > 0 ? (uint32_t) Q->ne[3] : 1u;
+                            const uint64_t blk_elems = (uint64_t) blk_nblk0 * blk_nblk1 * blk_batch_count;
+                            const size_t blk_size_bytes =
+                                ROUNDUP_POW2(blk_elems * sizeof(uint32_t), WEBGPU_STORAGE_BUF_BINDING_MULT);
+                            res += blk_size_bytes + align;
+                        }
+                        res = ROUNDUP_POW2(res, WEBGPU_STORAGE_BUF_BINDING_MULT);
+                    }
+                }
+            }
+            break;
         default:
             break;
     }
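
As a sanity check of the reservation logic above, this standalone sketch reproduces the nwg selection and the tmp-scratch sizing for a hypothetical decode shape. The 256-byte rounding granularity stands in for WEBGPU_STORAGE_BUF_BINDING_MULT and, like the shape values, is an assumption for illustration only.

#include <algorithm>
#include <cstdint>
#include <cstdio>

static uint64_t roundup_pow2(uint64_t v, uint64_t m) { return (v + m - 1) & ~(m - 1); }

int main() {
    const uint64_t seq_len_kv   = 4096;  // K->ne[1]
    const uint64_t nrows        = 8;     // Q->ne[1] * Q->ne[2] * Q->ne[3]
    const uint64_t head_dim_v   = 128;   // V->ne[0]
    const uint32_t kv_tile      = 32;
    const uint32_t sg_size      = 32;    // max_subgroup_size
    const uint64_t binding_mult = 256;   // stand-in for WEBGPU_STORAGE_BUF_BINDING_MULT

    // Split factor: double while each split still covers at least two KV tiles,
    // capped at min(32, subgroup size).
    const uint32_t cap = std::max(1u, std::min(32u, sg_size));
    uint32_t nwg = 1;
    while (2ull * nwg * kv_tile < seq_len_kv && nwg < cap) {
        nwg <<= 1;
    }

    // Per-split partial outputs plus (exp-sum, max) stats per row; only reserved when nwg > 1.
    const uint64_t tmp_elems = nrows * head_dim_v * nwg + nrows * 2 * nwg;
    const uint64_t tmp_bytes = roundup_pow2(tmp_elems * sizeof(float), binding_mult);

    printf("nwg=%u tmp=%llu bytes\n", nwg, (unsigned long long) tmp_bytes);  // nwg=32, tmp=133120
    return 0;
}
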
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl
new file mode 100644 (file)
index 0000000..82d072b
--- /dev/null
@@ -0,0 +1,105 @@
+diagnostic(off, subgroup_uniformity);
+enable f16;
+
+#define Q_TILE 1
+#define KV_TILE 32
+#define WG_SIZE 32
+
+struct Params {
+    offset_mask: u32,
+    seq_len_q: u32,
+    seq_len_kv: u32,
+    stride_mask3: u32,
+    // Number of KV blocks and Q blocks per batch.
+    // nblk0 = ceil(seq_len_kv / KV_TILE), nblk1 = ceil(seq_len_q / Q_TILE).
+    nblk0: u32,
+    nblk1: u32,
+};
+
+@group(0) @binding(0) var<storage, read> mask: array<f16>;
+@group(0) @binding(1) var<storage, read_write> blk: array<u32>;
+@group(0) @binding(2) var<uniform> params: Params;
+
+const MASK_MIN: f32 = -65504.0;
+const MASK_MAX: f32 = 65504.0;
+var<workgroup> wg_min: array<f32, WG_SIZE>;
+var<workgroup> wg_max: array<f32, WG_SIZE>;
+var<workgroup> wg_any: array<u32, WG_SIZE>;
+
+@compute @workgroup_size(WG_SIZE)
+fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
+        @builtin(local_invocation_id) local_id: vec3<u32>) {
+    // Dispatch mapping:
+    //  - x indexes KV blocks
+    //  - y flattens (batch_idx, q_blk) as y = batch_idx * nblk1 + q_blk
+    let kv_blk = wg_id.x;
+    let y = wg_id.y;
+    let q_blk = y % params.nblk1;
+    let batch_idx = y / params.nblk1;
+    if (kv_blk >= params.nblk0) {
+        return;
+    }
+
+    let q_start = q_blk * Q_TILE;
+    let k_start = kv_blk * KV_TILE;
+
+    let mask_batch = select(0u, batch_idx, params.stride_mask3 > 0u);
+    let mask_batch_base = params.offset_mask + mask_batch * params.stride_mask3;
+
+    // We keep min/max to classify:
+    //  - fully masked (max <= MASK_MIN)
+    //  - all-zero mask (min == 0 && max == 0)
+    //  - mixed/general mask
+    var local_min = MASK_MAX;
+    var local_max = -MASK_MAX;
+    var local_any = 0u;
+
+    for (var q_rel = 0u; q_rel < Q_TILE; q_rel += 1u) {
+        let q_row = q_start + q_rel;
+        if (q_row >= params.seq_len_q) {
+            continue;
+        }
+        let row_base = mask_batch_base + q_row * params.seq_len_kv;
+        for (var k_rel = local_id.x; k_rel < KV_TILE; k_rel += WG_SIZE) {
+            let k_col = k_start + k_rel;
+            if (k_col >= params.seq_len_kv) {
+                continue;
+            }
+            let mv = f32(mask[row_base + k_col]);
+            local_min = min(local_min, mv);
+            local_max = max(local_max, mv);
+            local_any = 1u;
+        }
+    }
+
+    wg_min[local_id.x] = local_min;
+    wg_max[local_id.x] = local_max;
+    wg_any[local_id.x] = local_any;
+    workgroupBarrier();
+
+    // Thread 0 writes one state per block.
+    if (local_id.x == 0u) {
+        var mmin = wg_min[0];
+        var mmax = wg_max[0];
+        var many = wg_any[0];
+        for (var i = 1u; i < WG_SIZE; i += 1u) {
+            mmin = min(mmin, wg_min[i]);
+            mmax = max(mmax, wg_max[i]);
+            many = max(many, wg_any[i]);
+        }
+
+        var state = 0u;
+        if (many != 0u) {
+            if (mmax <= MASK_MIN) {
+                state = 0u;
+            } else if (mmin == 0.0 && mmax == 0.0) {
+                state = 2u;
+            } else {
+                state = 1u;
+            }
+        }
+
+        let blk_idx = (batch_idx * params.nblk1 + q_blk) * params.nblk0 + kv_blk;
+        blk[blk_idx] = state;
+    }
+}
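
The state written per tile above is consumed by flash_attn_vec_split.wgsl as: 0 = fully masked (skip the tile entirely), 1 = general mask (load and apply it), 2 = all-zero mask (compute the tile but skip the mask add). A host-side reference of the same classification, as a hedged sketch:

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

// Mirrors the min/max classification in flash_attn_vec_blk.wgsl.
uint32_t classify_mask_tile(const std::vector<float> & tile_vals) {
    const float MASK_MIN = -65504.0f;  // most negative finite f16
    if (tile_vals.empty()) {
        return 0;  // no in-bounds values: treat the tile as masked out
    }
    float mmin = tile_vals[0];
    float mmax = tile_vals[0];
    for (float v : tile_vals) {
        mmin = std::min(mmin, v);
        mmax = std::max(mmax, v);
    }
    if (mmax <= MASK_MIN) {
        return 0;  // every score in the tile would be -inf-like
    }
    if (mmin == 0.0f && mmax == 0.0f) {
        return 2;  // the mask is a no-op for this tile
    }
    return 1;      // mixed/general mask
}

int main() {
    std::vector<float> zeros(32, 0.0f);
    printf("all-zero tile -> state %u\n", classify_mask_tile(zeros));  // 2
    return 0;
}
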
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl
new file mode 100644 (file)
index 0000000..9a0de82
--- /dev/null
@@ -0,0 +1,78 @@
+diagnostic(off, subgroup_uniformity);
+enable f16;
+enable subgroups;
+
+// Default values
+#define HEAD_DIM_V 64
+#define WG_SIZE 128
+
+struct Params {
+    nrows: u32,
+    seq_len_q: u32,
+    n_heads: u32,
+    offset_dst: u32,
+    nwg: u32,
+    tmp_data_base: u32,
+    tmp_stats_base: u32,
+};
+
+@group(0) @binding(0) var<storage, read_write> tmp: array<f32>;
+@group(0) @binding(1) var<storage, read_write> dst: array<vec4<f32>>;
+@group(0) @binding(2) var<uniform> params: Params;
+
+const FLOAT_MIN: f32 = -1.0e9;
+
+@compute @workgroup_size(WG_SIZE)
+fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
+        @builtin(subgroup_id) subgroup_id: u32,
+        @builtin(num_subgroups) num_subgroups: u32,
+        @builtin(subgroup_size) subgroup_size: u32,
+        @builtin(subgroup_invocation_id) sg_inv_id: u32) {
+    let rid = wg_id.x;
+    if (rid >= params.nrows) {
+        return;
+    }
+
+    let rows_per_batch = params.n_heads * params.seq_len_q;
+    let batch_idx = rid / rows_per_batch;
+    let rem = rid % rows_per_batch;
+    let head_idx = rem / params.seq_len_q;
+    let q_row = rem % params.seq_len_q;
+
+    let dst2_stride = HEAD_DIM_V * params.n_heads;
+    let dst3_stride = dst2_stride * params.seq_len_q;
+    let row_base = params.offset_dst + batch_idx * dst3_stride + q_row * dst2_stride + head_idx * HEAD_DIM_V;
+
+    let thread = sg_inv_id;
+    if (params.nwg > subgroup_size) {
+        return;
+    }
+
+    let stats_base = params.tmp_stats_base + rid * (2u * params.nwg);
+    let active_thread = thread < params.nwg;
+    let si = select(0.0, tmp[stats_base + 2u * thread + 0u], active_thread);
+    let mi = select(FLOAT_MIN, tmp[stats_base + 2u * thread + 1u], active_thread);
+    let m = subgroupMax(mi);
+    let ms = select(0.0, exp(mi - m), active_thread);
+    let s = subgroupAdd(si * ms);
+    let inv_s = select(0.0, 1.0 / s, s != 0.0);
+
+    let row_tmp_base = params.tmp_data_base + rid * (HEAD_DIM_V * params.nwg);
+    for (var elem_base = subgroup_id * 4u; elem_base < HEAD_DIM_V; elem_base += num_subgroups * 4u) {
+        var weighted = vec4<f32>(0.0, 0.0, 0.0, 0.0);
+        if (active_thread) {
+            let src = row_tmp_base + thread * HEAD_DIM_V + elem_base;
+            weighted = vec4<f32>(tmp[src + 0u], tmp[src + 1u], tmp[src + 2u], tmp[src + 3u]) * ms;
+        }
+
+        let sum_x = subgroupAdd(weighted.x);
+        let sum_y = subgroupAdd(weighted.y);
+        let sum_z = subgroupAdd(weighted.z);
+        let sum_w = subgroupAdd(weighted.w);
+
+        if (thread == 0u) {
+            let dst_vec_index = (row_base + elem_base) >> 2u;
+            dst[dst_vec_index] = vec4<f32>(sum_x, sum_y, sum_z, sum_w) * inv_s;
+        }
+    }
+}
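
The reduce kernel above merges per-split online-softmax partials: each split i leaves an exp-sum s_i, a running max m_i, and an unnormalized output o_i in tmp, and the merged row is (sum_i exp(m_i - m) * o_i) / (sum_i exp(m_i - m) * s_i) with m = max_i m_i. A scalar C++ reference of that merge, as a sketch:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Scalar reference for the per-row merge performed by flash_attn_vec_reduce.wgsl.
std::vector<float> merge_splits(const std::vector<float> &              s,   // exp-sums per split
                                const std::vector<float> &              m,   // running maxes per split
                                const std::vector<std::vector<float>> & o) { // unnormalized outputs
    const size_t nwg        = s.size();
    const size_t head_dim_v = o[0].size();

    float gmax = m[0];
    for (float mi : m) {
        gmax = std::max(gmax, mi);
    }

    float denom = 0.0f;
    std::vector<float> out(head_dim_v, 0.0f);
    for (size_t i = 0; i < nwg; i++) {
        const float ms = std::exp(m[i] - gmax);  // rescale split i to the global max
        denom += s[i] * ms;
        for (size_t d = 0; d < head_dim_v; d++) {
            out[d] += o[i][d] * ms;
        }
    }
    const float inv = denom != 0.0f ? 1.0f / denom : 0.0f;
    for (float & v : out) {
        v *= inv;
    }
    return out;
}

int main() {
    const std::vector<float> s = { 2.0f, 1.0f };
    const std::vector<float> m = { 0.5f, 1.5f };
    const std::vector<std::vector<float>> o = { { 1.0f, 0.0f }, { 0.0f, 1.0f } };
    const std::vector<float> out = merge_splits(s, m, o);
    printf("%f %f\n", out[0], out[1]);
    return 0;
}
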
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl
new file mode 100644 (file)
index 0000000..a525758
--- /dev/null
@@ -0,0 +1,729 @@
+diagnostic(off, chromium.subgroup_matrix_uniformity);
+diagnostic(off, subgroup_uniformity);
+enable f16;
+enable subgroups;
+enable chromium_experimental_subgroup_matrix;
+
+#ifdef KV_F32
+#define KV_TYPE f32
+#else
+#define KV_TYPE f16
+#endif
+
+#define HEAD_DIM_QK 64
+#define HEAD_DIM_V 64
+
+
+#define SG_MAT_M 8
+#define SG_MAT_N 8
+#define SG_MAT_K 8
+
+#define Q_TILE SG_MAT_M
+#define KV_TILE 16
+#define WG_SIZE 64
+#ifndef VEC_NE
+#define VEC_NE 4u
+#endif
+
+#define KV_BLOCKS (KV_TILE / SG_MAT_N)
+
+#define BLOCK_SIZE 32
+#define BLOCKS_K ((HEAD_DIM_QK + BLOCK_SIZE - 1) / BLOCK_SIZE)
+#define BLOCKS_V ((HEAD_DIM_V + BLOCK_SIZE - 1) / BLOCK_SIZE)
+#if defined(KV_Q4_0)
+#define NQ 16
+#define F16_PER_BLOCK 9
+#define WEIGHTS_PER_F16 4
+#elif defined(KV_Q8_0)
+#define NQ 8
+#define F16_PER_BLOCK 17
+#define WEIGHTS_PER_F16 2
+#endif
+#define F16_PER_THREAD (NQ / WEIGHTS_PER_F16)
+
+fn get_byte(value: u32, index: u32) -> u32 {
+    return (value >> (index * 8)) & 0xFF;
+}
+
+fn get_byte_i32(value: u32, index: u32) -> i32 {
+    return bitcast<i32>(((value >> (index * 8)) & 0xFF) << 24) >> 24;
+}
+
+struct Params {
+    offset_q: u32,
+    offset_k: u32,
+    offset_v: u32,
+    offset_mask: u32,
+    offset_sinks: u32,
+    offset_dst: u32,
+
+    // shapes of Q/K/V
+    n_heads: u32,
+    seq_len_q: u32,
+    seq_len_kv: u32,
+
+    // strides (in elements)
+    stride_q1: u32,
+    stride_q2: u32,
+    stride_q3: u32,
+    stride_k1: u32,
+    stride_k2: u32,
+    stride_k3: u32,
+    stride_v1: u32,
+    stride_v2: u32,
+    stride_v3: u32,
+    stride_mask3: u32,
+
+    // repeat factors for K/V, e.g., MHA vs. MQA vs. GQA
+    q_per_kv: u32,
+
+    // softmax params
+    scale: f32,
+    max_bias: f32,
+    logit_softcap: f32,
+    n_head_log2: f32,
+    m0: f32,
+    m1: f32,
+
+#ifdef BLK
+    blk_base: u32,
+    blk_nblk0: u32,
+    blk_nblk1: u32,
+#endif
+
+    tmp_data_base: u32,
+    tmp_stats_base: u32,
+    nwg: u32,
+};
+
+@group(0) @binding(0) var<storage, read_write> Q: array<f32>;
+#if defined(KV_Q4_0) || defined(KV_Q8_0)
+@group(0) @binding(1) var<storage, read_write> K: array<KV_TYPE>;
+#else
+@group(0) @binding(1) var<storage, read_write> K: array<vec4<KV_TYPE>>;
+#endif
+#if defined(KV_Q4_0) || defined(KV_Q8_0)
+@group(0) @binding(2) var<storage, read_write> V: array<KV_TYPE>;
+#else
+@group(0) @binding(2) var<storage, read_write> V: array<vec4<KV_TYPE>>;
+#endif
+#if defined(MASK) && defined(SINKS)
+@group(0) @binding(3) var<storage, read_write> mask: array<f16>;
+@group(0) @binding(4) var<storage, read_write> sinks: array<f32>;
+#ifdef BLK
+#define BLK_BINDING 5
+#define TMP_BINDING 6
+#define DST_BINDING 7
+#define PARAMS_BINDING 8
+#else
+#define TMP_BINDING 5
+#define DST_BINDING 6
+#define PARAMS_BINDING 7
+#endif
+#elif defined(MASK)
+@group(0) @binding(3) var<storage, read_write> mask: array<f16>;
+#ifdef BLK
+#define BLK_BINDING 4
+#define TMP_BINDING 5
+#define DST_BINDING 6
+#define PARAMS_BINDING 7
+#else
+#define TMP_BINDING 4
+#define DST_BINDING 5
+#define PARAMS_BINDING 6
+#endif
+#elif defined(SINKS)
+@group(0) @binding(3) var<storage, read_write> sinks: array<f32>;
+#define TMP_BINDING 4
+#define DST_BINDING 5
+#define PARAMS_BINDING 6
+#else
+#define TMP_BINDING 3
+#define DST_BINDING 4
+#define PARAMS_BINDING 5
+#endif
+
+#ifdef BLK
+@group(0) @binding(BLK_BINDING) var<storage, read_write> blk: array<u32>;
+#endif
+@group(0) @binding(TMP_BINDING) var<storage, read_write> tmp: array<f32>;
+@group(0) @binding(DST_BINDING) var<storage, read_write> dst: array<vec4<f32>>;
+@group(0) @binding(PARAMS_BINDING) var<uniform> params: Params;
+
+// Just a very small float value.
+const FLOAT_MIN: f32 = -1.0e9;
+
+var<workgroup> q_shmem: array<f16, Q_TILE * HEAD_DIM_QK>;
+
+#ifndef KV_DIRECT
+const kv_shmem_size = KV_TILE * max(HEAD_DIM_QK, HEAD_DIM_V);
+// we can reuse the same shmem for K and V since we only need one at a time
+var<workgroup> kv_shmem: array<f16, kv_shmem_size>;
+#endif
+
+var<workgroup> o_shmem: array<f16, Q_TILE * HEAD_DIM_V>;
+
+#ifdef MASK
+// storage for mask values
+var<workgroup> mask_shmem: array<f16, Q_TILE * KV_TILE>;
+#endif
+
+// note that we reuse the same storage for both since we only need one at a time
+var<workgroup> inter_shmem: array<f16, Q_TILE * KV_TILE>;
+
+// Storage for row max and exp sum during online softmax
+var<workgroup> row_max_shmem: array<f32, Q_TILE>;
+var<workgroup> exp_sum_shmem: array<f32, Q_TILE>;
+var<workgroup> blk_state_wg: u32;
+
+fn calc_softmax_term(kv_idx: u32, q_tile_row: u32, slope: f32, has_bias: bool, apply_mask: bool) -> f32 {
+    var v = select(FLOAT_MIN,
+                   f32(inter_shmem[kv_idx + q_tile_row * KV_TILE]) * params.scale,
+                   kv_idx < KV_TILE);
+#ifdef LOGIT_SOFTCAP
+    v = params.logit_softcap * tanh(v);
+#endif
+#ifdef MASK
+    if (apply_mask) {
+        var mask_val = select(0.0, f32(mask_shmem[q_tile_row * KV_TILE + kv_idx]), kv_idx < KV_TILE);
+        v += select(mask_val, slope * mask_val, has_bias);
+    }
+#endif
+    return v;
+}
+
+@compute @workgroup_size(WG_SIZE)
+fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
+    @builtin(local_invocation_id) local_id: vec3<u32>,
+    @builtin(subgroup_id) subgroup_id: u32,
+    @builtin(subgroup_size) subgroup_size: u32,
+    @builtin(num_subgroups) num_subgroups: u32,
+    @builtin(subgroup_invocation_id) sg_inv_id: u32) {
+
+    // initialize row max for online softmax
+    for (var i = local_id.x; i < Q_TILE; i += WG_SIZE) {
+        row_max_shmem[i] = FLOAT_MIN;
+        exp_sum_shmem[i] = 0.0;
+    }
+
+    for (var i = local_id.x; i < Q_TILE * HEAD_DIM_V; i += WG_SIZE) {
+        o_shmem[i] = 0.0;
+    }
+
+    // workgroups per head/batch
+    let wg_per_head = (params.seq_len_q + Q_TILE - 1u) / Q_TILE;
+    let wg_per_batch = wg_per_head * params.n_heads;
+
+    let dst2_stride = HEAD_DIM_V * params.n_heads;
+    let dst3_stride = dst2_stride * params.seq_len_q;
+
+    let iwg = wg_id.x % params.nwg;
+    let base_wg_id = wg_id.x / params.nwg;
+
+    // batch index
+    let batch_idx = base_wg_id / wg_per_batch;
+    let q_batch_offset = params.offset_q + batch_idx * params.stride_q3;
+    let k_batch_offset = params.offset_k + batch_idx * params.stride_k3;
+    let v_batch_offset = params.offset_v + batch_idx * params.stride_v3;
+    let wg_in_batch = base_wg_id % wg_per_batch;
+
+    // head index
+    let head_idx = wg_in_batch / wg_per_head;
+    let q_head_offset = q_batch_offset + head_idx * params.stride_q2;
+    let k_head_idx = head_idx / params.q_per_kv;
+    let v_head_idx = k_head_idx;
+    let k_head_offset = k_batch_offset + k_head_idx * params.stride_k2;
+    let v_head_offset = v_batch_offset + v_head_idx * params.stride_v2;
+
+    // starting Q row for this workgroup
+    let wg_in_head = wg_in_batch % wg_per_head;
+    let q_row_start = wg_in_head * Q_TILE;
+
+#ifdef MASK
+    // mask offset
+    let mask_global_offset = params.offset_mask + batch_idx * params.stride_mask3 + q_row_start * params.seq_len_kv;
+#endif
+
+    let head = f32(head_idx);
+    let has_bias = params.max_bias > 0.0;
+    let slope = select(1.0, select(pow(params.m1, 2.0 * (head - params.n_head_log2) + 1.0), pow(params.m0, head + 1.0), head < params.n_head_log2), has_bias);
+
+    // load q tile into shared memory
+    for (var elem_idx = local_id.x; elem_idx < Q_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) {
+        let q_row = elem_idx / HEAD_DIM_QK;
+        let q_col = elem_idx % HEAD_DIM_QK;
+        let head_q_row = q_row_start + q_row;
+        let global_q_row_offset = q_head_offset + head_q_row * params.stride_q1;
+        q_shmem[elem_idx] = f16(select(
+            0.0,
+            Q[global_q_row_offset + q_col],
+            head_q_row < params.seq_len_q && q_col < HEAD_DIM_QK));
+    }
+
+    for (var kv_tile = iwg * KV_TILE; kv_tile < params.seq_len_kv; kv_tile += KV_TILE * params.nwg) {
+#ifdef BLK
+        let q_blk = q_row_start / Q_TILE;
+        let kv_blk = kv_tile / KV_TILE;
+        let blk_batch = select(0u, batch_idx, params.stride_mask3 > 0u);
+        let blk_idx = params.blk_base + (blk_batch * params.blk_nblk1 + q_blk) * params.blk_nblk0 + kv_blk;
+        let blk_state_local = blk[blk_idx];
+#else
+        let blk_state_local = 1u;
+#endif
+        if (local_id.x == 0u) {
+            blk_state_wg = blk_state_local;
+        }
+        workgroupBarrier();
+        let blk_state = blk_state_wg;
+        let skip_tile = blk_state == 0u;
+        for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) {
+            inter_shmem[elem_idx] = f16(0.0);
+        }
+
+      // load k tile into shared memory
+#if defined(KV_Q4_0)
+      for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) {
+          let blck_idx = elem_idx / BLOCK_SIZE;
+          let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
+          let k_row = blck_idx / BLOCKS_K;
+          let global_k_row = kv_tile + k_row;
+          let block_k = blck_idx % BLOCKS_K;
+          let row_offset = k_row * HEAD_DIM_QK;
+
+          if (global_k_row < params.seq_len_kv) {
+              let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k;
+              let base_idx = global_block_idx * F16_PER_BLOCK;
+              let d = K[base_idx];
+              for (var j = 0u; j < F16_PER_THREAD; j += 2) {
+                  let q_0 = K[base_idx + 1u + block_offset + j];
+                  let q_1 = K[base_idx + 1u + block_offset + j + 1];
+                  let q_packed = bitcast<u32>(vec2(q_0, q_1));
+                  for (var k = 0u; k < 4u; k++) {
+                      let q_byte = get_byte(q_packed, k);
+                      let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
+                      let q_lo = (f16(q_byte & 0xF) - 8.0) * d;
+                      let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
+                      kv_shmem[row_offset + idx] = q_lo;
+                      kv_shmem[row_offset + idx + 16u] = q_hi;
+                  }
+              }
+          }
+      }
+#elif defined(KV_Q8_0)
+      for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) {
+          let blck_idx = elem_idx / BLOCK_SIZE;
+          let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
+          let k_row = blck_idx / BLOCKS_K;
+          let global_k_row = kv_tile + k_row;
+          let block_k = blck_idx % BLOCKS_K;
+          let row_offset = k_row * HEAD_DIM_QK;
+
+          if (global_k_row < params.seq_len_kv) {
+              let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k;
+              let base_idx = global_block_idx * F16_PER_BLOCK;
+              let d = K[base_idx];
+              for (var j = 0u; j < F16_PER_THREAD; j += 2) {
+                  let q_0 = K[base_idx + 1u + block_offset + j];
+                  let q_1 = K[base_idx + 1u + block_offset + j + 1];
+                  let q_packed = bitcast<u32>(vec2(q_0, q_1));
+                  for (var k = 0u; k < 4u; k++) {
+                      let q_byte = get_byte_i32(q_packed, k);
+                      let q_val = f16(q_byte) * d;
+                      let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
+                      kv_shmem[row_offset + idx] = q_val;
+                  }
+              }
+          }
+      }
+#elif defined(KV_DIRECT)
+      // Direct global loads for KV
+#else
+      for (var elem_idx = local_id.x * 4u; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * 4u) {
+          let k_row = elem_idx / HEAD_DIM_QK;
+          let k_col = elem_idx % HEAD_DIM_QK;
+          let global_k_row = kv_tile + k_row;
+          let global_k_row_offset = k_head_offset + global_k_row * params.stride_k1;
+          let in_bounds = global_k_row < params.seq_len_kv && (k_col + 3u) < HEAD_DIM_QK;
+          let vec_idx = (global_k_row_offset + k_col) >> 2u;
+          let k4 = select(vec4<KV_TYPE>(0.0), K[vec_idx], in_bounds);
+          kv_shmem[elem_idx + 0u] = f16(k4.x);
+          kv_shmem[elem_idx + 1u] = f16(k4.y);
+          kv_shmem[elem_idx + 2u] = f16(k4.z);
+          kv_shmem[elem_idx + 3u] = f16(k4.w);
+      }
+#endif
+
+      workgroupBarrier();
+
+      // accumulate q block * k block into registers across the entire KV tile
+      if (!skip_tile) {
+          let num_of_threads = subgroup_size / VEC_NE;
+          let tx = sg_inv_id % num_of_threads;
+          let ty = sg_inv_id / num_of_threads;
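+          // lanes form a (num_of_threads x VEC_NE) grid: tx strides across the
+          // head dimension while ty strides across KV rows, so each lane
+          // accumulates a disjoint slice of one Q·K dot product.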
+          for (var q_tile_row = subgroup_id; q_tile_row < Q_TILE; q_tile_row += num_subgroups) {
+              let global_q_row = q_row_start + q_tile_row;
+              if (global_q_row >= params.seq_len_q) {
+                  continue;
+              }
+              let local_q_row_offset = q_tile_row * HEAD_DIM_QK;
+
+              for (var kv_base : u32 = 0u; kv_base < KV_TILE; kv_base += VEC_NE) {
+                  let kv_idx = kv_base + ty;
+                  var partial_sum: f32 = 0.0;
+                  let kv_valid = kv_idx < KV_TILE && (kv_tile + kv_idx) < params.seq_len_kv;
+                  if (kv_valid) {
+                    for (var i = tx; i < (HEAD_DIM_QK / 4u); i += num_of_threads) {
+                        let q_off = local_q_row_offset + i * 4u;
+
+                        let qv = vec4<f32>(
+                            f32(q_shmem[q_off + 0u]),
+                            f32(q_shmem[q_off + 1u]),
+                            f32(q_shmem[q_off + 2u]),
+                            f32(q_shmem[q_off + 3u]));
+#ifdef KV_DIRECT
+                        let idx = k_head_offset + (kv_tile + kv_idx) * params.stride_k1 + (i * 4u);
+                        let kv = vec4<f32>(K[idx >> 2u]);
+#else
+                        let idx = kv_idx * HEAD_DIM_QK + (i * 4u);
+                        let kv = vec4<f32>(
+                            f32(kv_shmem[idx + 0u]),
+                            f32(kv_shmem[idx + 1u]),
+                            f32(kv_shmem[idx + 2u]),
+                            f32(kv_shmem[idx + 3u]));
+#endif
+                        partial_sum += dot(qv, kv);
+                    }
+                  }
+                  var sum = partial_sum;
+                  // Reduce over the tx lanes (num_of_threads of them) within this ty stripe.
+                  var tx_delta = num_of_threads >> 1u;
+                  loop {
+                      if (tx_delta == 0u) {
+                          break;
+                      }
+                      let sh = subgroupShuffleDown(sum, tx_delta);
+                      if (tx < tx_delta) {
+                          sum += sh;
+                      }
+                      tx_delta >>= 1u;
+                  }
+
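+                  // after the shuffle-down tree, the tx == 0 lane of each ty
+                  // stripe holds the complete dot product for its KV row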
+                  let sum_bcast = subgroupShuffle(sum, num_of_threads * ty);
+                  if (tx == 0u && kv_valid) {
+                      let dst_idx = q_tile_row * KV_TILE + kv_idx;
+                      inter_shmem[dst_idx] = f16(sum_bcast);
+                  }
+              }
+          }
+      }
+
+#ifdef MASK
+      let apply_mask = !skip_tile && (blk_state != 2u);
+      if (apply_mask) {
+          // load mask tile into shared memory for this KV block
+          for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) {
+              let mask_row = elem_idx / KV_TILE;
+              let mask_col = elem_idx % KV_TILE;
+              let global_q_row = q_row_start + mask_row;
+              let global_k_col = kv_tile + mask_col;
+              let mask_in_bounds = global_q_row < params.seq_len_q && global_k_col < params.seq_len_kv;
+              let mask_idx = mask_global_offset + mask_row * params.seq_len_kv + global_k_col;
+              mask_shmem[elem_idx] = select(0.0, mask[mask_idx], mask_in_bounds);
+          }
+      }
+#else
+      let apply_mask = false;
+#endif
+
+      workgroupBarrier();
+
+      // online softmax
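+      // two passes per Q row: pass 1 finds the new running max over this KV
+      // tile; pass 2 writes P = exp(score - final_max) into inter_shmem and
+      // accumulates the exp sum. exp(prev_max - final_max) then rescales the
+      // running exp sum and partial output from earlier KV tiles.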
+      if (!skip_tile) {
+          for (var q_tile_row = subgroup_id; q_tile_row < Q_TILE; q_tile_row += num_subgroups) {
+              let global_q_row = q_row_start + q_tile_row;
+              if (global_q_row >= params.seq_len_q) {
+                  break;
+              }
+
+              let prev_max = row_max_shmem[q_tile_row];
+              var final_max = prev_max;
+              // pass 1: compute final max across the full KV tile in chunks
+              for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) {
+                  let kv_idx = kv_offset + sg_inv_id;
+                  let kv_valid = kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE;
+                  let softmax_term = select(FLOAT_MIN,
+                                            calc_softmax_term(kv_idx, q_tile_row, slope, has_bias, apply_mask),
+                                            kv_valid);
+                  final_max = subgroupMax(max(final_max, softmax_term));
+              }
+
+              var total_exp_term: f32 = 0.0;
+              // pass 2: compute exp sum and write P using final_max
+              for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) {
+                  let kv_idx = kv_offset + sg_inv_id;
+                  let softmax_term = calc_softmax_term(kv_idx, q_tile_row, slope, has_bias, apply_mask);
+                  let cur_p = select(0.0,
+                                     exp(softmax_term - final_max),
+                                     kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE);
+                  total_exp_term += subgroupAdd(cur_p);
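+                  // subgroupAdd returns the same sum to every lane, so
+                  // total_exp_term stays subgroup-uniform for the write below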
+                  if (kv_idx < KV_TILE) {
+                      inter_shmem[kv_idx + q_tile_row * KV_TILE] = f16(cur_p);
+                  }
+              }
+
+              let cur_exp = exp(prev_max - final_max);
+
+              if (sg_inv_id == 0) {
+                  row_max_shmem[q_tile_row] = final_max;
+                  exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * cur_exp + total_exp_term;
+              }
+
+              for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) {
+                  let idx = q_tile_row * HEAD_DIM_V + elem_idx;
+                  o_shmem[idx] = f16(f32(o_shmem[idx]) * cur_exp);
+              }
+          }
+      }
+
+      // load v tile into shared memory
+#if defined(KV_Q4_0)
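+      // same q4_0 unpacking as the K tile above, but over HEAD_DIM_V columns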
+      for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) {
+          let blck_idx = elem_idx / BLOCK_SIZE;
+          let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
+          let v_row = blck_idx / BLOCKS_V;
+          let global_v_row = kv_tile + v_row;
+          let block_k = blck_idx % BLOCKS_V;
+          let row_offset = v_row * HEAD_DIM_V;
+
+          if (global_v_row < params.seq_len_kv) {
+              let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k;
+              let base_idx = global_block_idx * F16_PER_BLOCK;
+              let d = V[base_idx];
+              for (var j = 0u; j < F16_PER_THREAD; j += 2) {
+                  let q_0 = V[base_idx + 1u + block_offset + j];
+                  let q_1 = V[base_idx + 1u + block_offset + j + 1];
+                  let q_packed = bitcast<u32>(vec2(q_0, q_1));
+                  for (var k = 0u; k < 4u; k++) {
+                      let q_byte = get_byte(q_packed, k);
+                      let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
+                      let q_lo = (f16(q_byte & 0xF) - 8.0) * d;
+                      let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
+                      kv_shmem[row_offset + idx] = q_lo;
+                      kv_shmem[row_offset + idx + 16u] = q_hi;
+                  }
+              }
+          }
+      }
+#elif defined(KV_Q8_0)
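+      // same q8_0 unpacking as the K tile above, but over HEAD_DIM_V columns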
+      for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) {
+          let blck_idx = elem_idx / BLOCK_SIZE;
+          let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
+          let v_row = blck_idx / BLOCKS_V;
+          let global_v_row = kv_tile + v_row;
+          let block_k = blck_idx % BLOCKS_V;
+          let row_offset = v_row * HEAD_DIM_V;
+
+          if (global_v_row < params.seq_len_kv) {
+              let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k;
+              let base_idx = global_block_idx * F16_PER_BLOCK;
+              let d = V[base_idx];
+              for (var j = 0u; j < F16_PER_THREAD; j += 2) {
+                  let q_0 = V[base_idx + 1u + block_offset + j];
+                  let q_1 = V[base_idx + 1u + block_offset + j + 1];
+                  let q_packed = bitcast<u32>(vec2(q_0, q_1));
+                  for (var k = 0u; k < 4u; k++) {
+                      let q_byte = get_byte_i32(q_packed, k);
+                      let q_val = f16(q_byte) * d;
+                      let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
+                      kv_shmem[row_offset + idx] = q_val;
+                  }
+              }
+          }
+      }
+#elif defined(KV_DIRECT)
+      // KV_DIRECT: V is read straight from global memory in the P*V loop below,
+      // so no staging into shared memory is needed
+#else
+      for (var elem_idx = local_id.x * 4u; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * 4u) {
+          let v_row = elem_idx / HEAD_DIM_V;
+          let v_col = elem_idx % HEAD_DIM_V;
+          let global_v_row = kv_tile + v_row;
+          let global_v_row_offset = v_head_offset + global_v_row * params.stride_v1;
+          let in_bounds = global_v_row < params.seq_len_kv && (v_col + 3u) < HEAD_DIM_V;
+          let vec_idx = (global_v_row_offset + v_col) >> 2u;
+          let v4 = select(vec4<KV_TYPE>(0.0), V[vec_idx], in_bounds);
+          kv_shmem[elem_idx + 0u] = f16(v4.x);
+          kv_shmem[elem_idx + 1u] = f16(v4.y);
+          kv_shmem[elem_idx + 2u] = f16(v4.z);
+          kv_shmem[elem_idx + 3u] = f16(v4.w);
+      }
+#endif
+
+      workgroupBarrier();
+
+      if (!skip_tile) {
+          // we have P (Q_TILE x KV_TILE) in inter_shmem and V (KV_TILE x HEAD_DIM_V) in kv_shmem;
+          // we want to compute O += P * V across the full KV tile
+          let ne_threads: u32 = VEC_NE;
+          let nl_threads = max(1u, subgroup_size / ne_threads);
+          let tx_pv = sg_inv_id % nl_threads;
+          let ty_pv = sg_inv_id / nl_threads;
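+          // lanes form an (nl_threads x ne_threads) grid, mirroring the QK
+          // stage: tx_pv strides vec4 columns of the output row while ty_pv
+          // strides KV rows of P and V.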
+          for (var q_tile_row = subgroup_id;
+               q_tile_row < Q_TILE;
+               q_tile_row += num_subgroups) {
+              for (var vec_col = tx_pv; vec_col < (HEAD_DIM_V / 4u); vec_col += nl_threads) {
+                  var lo = vec4<f32>(0.0, 0.0, 0.0, 0.0);
+                  for (var cc = 0u; cc < KV_TILE / ne_threads; cc += 1u) {
+                      let kv_idx = cc * ne_threads + ty_pv;
+                      let v_row = kv_tile + kv_idx;
+                      if (v_row >= params.seq_len_kv) {
+                          continue;
+                      }
+
+                      let p = f32(inter_shmem[kv_idx + q_tile_row * KV_TILE]);
+#ifdef KV_DIRECT
+                      let v_idx = v_head_offset + v_row * params.stride_v1 + vec_col * 4u;
+                      let v4 = vec4<f32>(V[v_idx >> 2u]);
+#else
+                      let v_idx = kv_idx * HEAD_DIM_V + vec_col * 4u;
+                      let v4 = vec4<f32>(
+                          f32(kv_shmem[v_idx + 0u]),
+                          f32(kv_shmem[v_idx + 1u]),
+                          f32(kv_shmem[v_idx + 2u]),
+                          f32(kv_shmem[v_idx + 3u]));
+#endif
+                      lo += p * v4;
+                  }
+
+                  var lo_x = lo.x;
+                  var lo_y = lo.y;
+                  var lo_z = lo.z;
+                  var lo_w = lo.w;
+                  // Reduce over the ty_pv lanes (ne_threads of them) for this tx_pv column.
+                  var ty_delta = ne_threads >> 1u;
+                  loop {
+                      if (ty_delta == 0u) {
+                          break;
+                      }
+                      let thread_delta = ty_delta * nl_threads;
+                      let shx = subgroupShuffleDown(lo_x, thread_delta);
+                      let shy = subgroupShuffleDown(lo_y, thread_delta);
+                      let shz = subgroupShuffleDown(lo_z, thread_delta);
+                      let shw = subgroupShuffleDown(lo_w, thread_delta);
+                      if (ty_pv < ty_delta) {
+                          lo_x += shx;
+                          lo_y += shy;
+                          lo_z += shz;
+                          lo_w += shw;
+                      }
+                      ty_delta >>= 1u;
+                  }
+
+                  if (ty_pv == 0u) {
+                      let elem_base = vec_col * 4u;
+                      let o_base_idx = q_tile_row * HEAD_DIM_V + elem_base;
+                      o_shmem[o_base_idx + 0u] = f16(f32(o_shmem[o_base_idx + 0u]) + lo_x);
+                      o_shmem[o_base_idx + 1u] = f16(f32(o_shmem[o_base_idx + 1u]) + lo_y);
+                      o_shmem[o_base_idx + 2u] = f16(f32(o_shmem[o_base_idx + 2u]) + lo_z);
+                      o_shmem[o_base_idx + 3u] = f16(f32(o_shmem[o_base_idx + 3u]) + lo_w);
+                  }
+              }
+          }
+      }
+
+      workgroupBarrier();
+    }
+
+#ifdef SINKS
+    // Sinks are global terms and must be applied exactly once across split workgroups.
+    if (iwg == 0u) {
+        for (var q_tile_row = subgroup_id;
+             q_tile_row < Q_TILE;
+             q_tile_row += num_subgroups) {
+            let global_q_row = q_row_start + q_tile_row;
+            if (global_q_row >= params.seq_len_q) {
+                break;
+            }
+
+            let prev_max = row_max_shmem[q_tile_row];
+
+            // for non-sink threads, exp(FLOAT_MIN) effectively zeroes out their contribution to the sum
+            let sink_val = select(FLOAT_MIN, sinks[params.offset_sinks + head_idx], sg_inv_id == 0);
+            let new_max = subgroupMax(max(prev_max, sink_val));
+            let max_exp = exp(prev_max - new_max);
+            let sink_exp = exp(sink_val - new_max);
+
+            let sink_exp_sum = subgroupAdd(sink_exp);
+
+            if (sg_inv_id == 0) {
+                row_max_shmem[q_tile_row] = new_max;
+                exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * max_exp + sink_exp_sum;
+            }
+
+            for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) {
+                let idx = q_tile_row * HEAD_DIM_V + elem_idx;
+                o_shmem[idx] = f16(f32(o_shmem[idx]) * max_exp);
+            }
+        }
+        workgroupBarrier();
+    }
+#endif
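+    // epilogue: with a single workgroup per row (nwg == 1) the output can be
+    // normalized and written to dst directly; otherwise each split stages its
+    // partial results for a later cross-split reduction.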
+    let rows_per_batch = params.n_heads * params.seq_len_q;
+    for (var q_tile_row = subgroup_id;
+         q_tile_row < Q_TILE;
+         q_tile_row += num_subgroups) {
+        let global_q_row = q_row_start + q_tile_row;
+        if (global_q_row >= params.seq_len_q) { break; }
+
+        if (params.nwg == 1u) {
+            let exp_sum = exp_sum_shmem[q_tile_row];
+            let scale = select(0.0, 1.0 / exp_sum, exp_sum != 0.0);
+            let row_base: u32 =
+                params.offset_dst + batch_idx * dst3_stride + global_q_row * dst2_stride + head_idx * HEAD_DIM_V;
+
+            for (var elem_base = sg_inv_id * 4u; elem_base < HEAD_DIM_V; elem_base += subgroup_size * 4u) {
+                let i0 = q_tile_row * HEAD_DIM_V + (elem_base + 0u);
+                let i1 = q_tile_row * HEAD_DIM_V + (elem_base + 1u);
+                let i2 = q_tile_row * HEAD_DIM_V + (elem_base + 2u);
+                let i3 = q_tile_row * HEAD_DIM_V + (elem_base + 3u);
+
+                let v = vec4<f32>(
+                    f32(o_shmem[i0]) * scale,
+                    f32(o_shmem[i1]) * scale,
+                    f32(o_shmem[i2]) * scale,
+                    f32(o_shmem[i3]) * scale
+                );
+
+                let dst_vec_index: u32 = (row_base + elem_base) >> 2u;
+                dst[dst_vec_index] = v;
+            }
+        } else {
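+            // multi-split path: stage this split's partial O at tmp_data_base
+            // and its (exp_sum, row_max) pair at tmp_stats_base; a later pass
+            // merges the nwg splits for each row.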
+            let rid = batch_idx * rows_per_batch + head_idx * params.seq_len_q + global_q_row;
+            let tmp_row_data_base = params.tmp_data_base + rid * (HEAD_DIM_V * params.nwg) + iwg * HEAD_DIM_V;
+            let tmp_row_stats_base = params.tmp_stats_base + rid * (2u * params.nwg) + 2u * iwg;
+
+            for (var elem_base = sg_inv_id * 4u;
+                elem_base < HEAD_DIM_V;
+                elem_base += subgroup_size * 4u) {
+
+                let i0 = q_tile_row * HEAD_DIM_V + (elem_base + 0u);
+                let i1 = q_tile_row * HEAD_DIM_V + (elem_base + 1u);
+                let i2 = q_tile_row * HEAD_DIM_V + (elem_base + 2u);
+                let i3 = q_tile_row * HEAD_DIM_V + (elem_base + 3u);
+
+                let tbase = tmp_row_data_base + elem_base;
+                tmp[tbase + 0u] = f32(o_shmem[i0]);
+                tmp[tbase + 1u] = f32(o_shmem[i1]);
+                tmp[tbase + 2u] = f32(o_shmem[i2]);
+                tmp[tbase + 3u] = f32(o_shmem[i3]);
+            }
+
+            if (sg_inv_id == 0u) {
+                tmp[tmp_row_stats_base + 0u] = exp_sum_shmem[q_tile_row];
+                tmp[tmp_row_stats_base + 1u] = row_max_shmem[q_tile_row];
+            }
+        }
+    }
+}