]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ggml webgpu: initial flashattention implementation (llama/18610)
authorReese Levine <redacted>
Thu, 8 Jan 2026 16:23:39 +0000 (08:23 -0800)
committerGeorgi Gerganov <redacted>
Wed, 14 Jan 2026 07:11:59 +0000 (09:11 +0200)
* FlashAttention (llama/13)

* Add inplace softmax

* Move rms_norm to split row approach

* Update debug for supports_op

* clean up debug statements

* neg f16xf32xip builds and runs, havent actually ran a model that uses neg kernel yet though

* neg passes backend test

* unary operators pass ggml tests

* rms_norm double declaration bug atoned

* abides by editor-config

* removed vestigial files

* fixed autoconfig

* All operators (inlcluding xielu) working

* removed unnecesarry checking if node->src[1] exists for unary operators

* responded and dealt with PR comments

* implemented REPL_Template support and removed bug in unary operators kernel

* formatted embed wgsl and ggml-webgpu.cpp

* Faster tensors (llama/8)

Add fast matrix and matrix/vector multiplication.

* Use map for shader replacements instead of pair of strings

* Wasm (llama/9)

* webgpu : fix build on emscripten

* more debugging stuff

* test-backend-ops: force single thread on wasm

* fix single-thread case for init_tensor_uniform

* use jspi

* add pthread

* test: remember to set n_thread for cpu backend

* Add buffer label and enable dawn-specific toggles to turn off some checks

* Intermediate state

* Fast working f16/f32 vec4

* Working float fast mul mat

* Clean up naming of mul_mat to match logical model, start work on q mul_mat

* Setup for subgroup matrix mat mul

* Basic working subgroup matrix

* Working subgroup matrix tiling

* Handle weirder sg matrix sizes (but still % sg matrix size)

* Working start to gemv

* working f16 accumulation with shared memory staging

* Print out available subgroup matrix configurations

* Vectorize dst stores for sg matrix shader

* Gemv working scalar

* Minor set_rows optimization (llama/4)

* updated optimization, fixed errors

* non vectorized version now dispatches one thread per element

* Simplify

* Change logic for set_rows pipelines

---------

Co-authored-by: Neha Abbas <redacted>
Co-authored-by: Neha Abbas <redacted>
Co-authored-by: Reese Levine <redacted>
* Comment on dawn toggles

* Working subgroup matrix code for (semi)generic sizes

* Remove some comments

* Cleanup code

* Update dawn version and move to portable subgroup size

* Try to fix new dawn release

* Update subgroup size comment

* Only check for subgroup matrix configs if they are supported

* Add toggles for subgroup matrix/f16 support on nvidia+vulkan

* Make row/col naming consistent

* Refactor shared memory loading

* Move sg matrix stores to correct file

* Working q4_0

* Formatting

* Work with emscripten builds

* Fix test-backend-ops emscripten for f16/quantized types

* Use emscripten memory64 to support get_memory

* Add build flags and try ci

---------

Co-authored-by: Xuan Son Nguyen <redacted>
* Remove extra whitespace

* Move wasm single-thread logic out of test-backend-ops for cpu backend

* Disable multiple threads for emscripten single-thread builds in ggml_graph_plan

* Refactored pipelines and workgroup calculations (llama/10)

* refactored pipelines

* refactored workgroup calculation

* removed commented out block of prior maps

* Clean up ceiling division pattern

---------

Co-authored-by: Neha Abbas <redacted>
Co-authored-by: Reese Levine <redacted>
* Start work on flash attention

* Shader structure set up (many bugs still)

* debugging

* Working first test

* Working with head grouping, head sizes to 128, logit softcap, mask/sinks enabled, f32

* Generalize softmax to work with multiple subgroups, f16 accumulation, mask shared memory tiling

* Start work on integrating pre-wgsl

* Separate structs/initial shader compilation library into separate files

* Work on compilation choices for flashattention

* Work on subgroup matrix/tile size portability

* subgroup size agnostic online softmax

* Cleanups, quantization types

* more cleanup

* fix wasm build

* Refactor flashattention to increase parallelism, use direct loads for KV in somce cases

* Checkpoint

* formatting

* Update to account for default kv cache padding

* formatting shader

* Add workflow for ggml-ci webgpu

* Try passing absolute path to dawn in ggml-ci

* Avoid error on device destruction, add todos for proper cleanup

* Fix unused warning

* Forgot one parameter unused

* Move some flashattn computation to f32 for correctness

ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp [new file with mode: 0644]
ggml/src/ggml-webgpu/ggml-webgpu.cpp
ggml/src/ggml-webgpu/pre_wgsl.hpp [new file with mode: 0644]
ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl [new file with mode: 0644]

diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp
new file mode 100644 (file)
index 0000000..7fdb4c8
--- /dev/null
@@ -0,0 +1,169 @@
+#ifndef GGML_WEBGPU_SHADER_LIB_HPP
+#define GGML_WEBGPU_SHADER_LIB_HPP
+
+#include "ggml.h"
+#include "pre_wgsl.hpp"
+
+#include <string>
+#include <vector>
+
+#define GGML_WEBGPU_F16_SIZE_BYTES                   2
+#define GGML_WEBGPU_F32_SIZE_BYTES                   4
+#define GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES 8u
+#define GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE     128u
+// Matches GGML_PAD(..., 256) in src/llama-context.cpp for KV cache sizing.
+#define GGML_WEBGPU_KV_SEQ_PAD                       256u
+
+struct ggml_webgpu_flash_attn_shader_lib_context {
+    ggml_type kv_type;
+    uint32_t  head_dim_qk;
+    uint32_t  head_dim_v;
+    bool      kv_direct;
+    bool      has_mask;
+    bool      has_sinks;
+    bool      uses_logit_softcap;
+    uint32_t  sg_mat_m;
+    uint32_t  sg_mat_n;
+    uint32_t  sg_mat_k;
+    size_t    wg_mem_limit_bytes;
+    uint32_t  max_subgroup_size;
+};
+
+struct ggml_webgpu_flash_attn_shader_decisions {
+    uint32_t q_tile  = 0;
+    uint32_t kv_tile = 0;
+    uint32_t wg_size = 0;
+};
+
+struct ggml_webgpu_processed_shader {
+    std::string                             wgsl;
+    std::string                             variant;
+    ggml_webgpu_flash_attn_shader_decisions decisions;
+};
+
+// This is exposed because it's necessary in supports_op
+inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile,
+                                                  uint32_t kv_tile,
+                                                  uint32_t head_dim_qk,
+                                                  uint32_t head_dim_v,
+                                                  bool     has_mask,
+                                                  bool     kv_direct) {
+    const uint32_t max_head_dim = std::max(head_dim_qk, head_dim_v);
+    size_t         f16_elems    = 0;
+    size_t         f32_elems    = 0;
+    f16_elems += q_tile * head_dim_qk;        // q_shmem
+    if (!kv_direct) {
+        f16_elems += kv_tile * max_head_dim;  // kv_shmem
+    }
+    f16_elems += q_tile * head_dim_v;         // o_shmem
+    if (has_mask) {
+        f16_elems += q_tile * kv_tile;        // mask_shmem
+    }
+    f16_elems += q_tile * kv_tile;            // inter_shmem
+    f32_elems += q_tile;                      // row_max_shmem
+    f32_elems += q_tile;                      // exp_sum_shmem
+    return f16_elems * GGML_WEBGPU_F16_SIZE_BYTES + f32_elems * GGML_WEBGPU_F32_SIZE_BYTES;
+}
+
+static uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_flash_attn_shader_lib_context & context) {
+    const size_t limit_bytes  = context.wg_mem_limit_bytes;
+    const size_t q_tile       = context.sg_mat_m;
+    const size_t base_q_bytes = (context.head_dim_qk + context.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES +
+                                2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES;
+    size_t bytes_per_kv = 0;
+    if (!context.kv_direct) {
+        bytes_per_kv += std::max(context.head_dim_qk, context.head_dim_v);
+    }
+    if (context.has_mask) {
+        bytes_per_kv += q_tile;
+    }
+    bytes_per_kv += q_tile;
+    bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES;
+    const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv;
+    return (max_kv_tile / context.sg_mat_n) * context.sg_mat_n;
+}
+
+inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader(
+    pre_wgsl::Preprocessor &                          preprocessor,
+    const char *                                      shader_src,
+    const ggml_webgpu_flash_attn_shader_lib_context & context) {
+    std::vector<std::string> defines;
+    std::string              variant = "flash_attn";
+
+    switch (context.kv_type) {
+        case GGML_TYPE_F32:
+            defines.push_back("KV_F32");
+            break;
+        case GGML_TYPE_F16:
+            defines.push_back("KV_F16");
+            break;
+        case GGML_TYPE_Q4_0:
+            defines.push_back("KV_Q4_0");
+            break;
+        case GGML_TYPE_Q8_0:
+            defines.push_back("KV_Q8_0");
+            break;
+        default:
+            GGML_ABORT("Unsupported KV type for flash attention shader");
+    }
+    variant += std::string("_") + ggml_type_name(context.kv_type);
+
+    if (context.has_mask) {
+        defines.push_back("MASK");
+        variant += "_mask";
+    }
+    if (context.has_sinks) {
+        defines.push_back("SINKS");
+        variant += "_sinks";
+    }
+    if (context.uses_logit_softcap) {
+        defines.push_back("LOGIT_SOFTCAP");
+        variant += "_lgsc";
+    }
+
+    if (context.kv_direct) {
+        defines.push_back("KV_DIRECT");
+        variant += "_kvdirect";
+    }
+
+    defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(context.head_dim_qk));
+    variant += std::string("_hsqk") + std::to_string(context.head_dim_qk);
+
+    defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.head_dim_v));
+    variant += std::string("_hsv") + std::to_string(context.head_dim_v);
+
+    // For now these are not part of the variant name
+    defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m));
+    defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n));
+    defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k));
+
+    // Add chosen Q/KV tile sizes
+    uint32_t q_tile  = context.sg_mat_m;
+    uint32_t kv_tile = std::min(ggml_webgpu_flash_attn_max_kv_tile(context),
+                                context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);
+    if (context.kv_direct) {
+        GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD);
+        // Avoids having to use bounds-checks and decreasing performance for direct KV loads
+        while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) {
+            kv_tile -= context.sg_mat_n;
+        }
+    }
+
+    defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile));
+    defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile));
+
+    // workgroup size
+    uint32_t wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE);
+
+    defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
+
+    ggml_webgpu_processed_shader result;
+    result.wgsl              = preprocessor.preprocess(shader_src, defines);
+    result.variant           = variant;
+    result.decisions.q_tile  = q_tile;
+    result.decisions.kv_tile = kv_tile;
+    result.decisions.wg_size = wg_size;
+    return result;
+}
+
+#endif  // GGML_WEBGPU_SHADER_LIB_HPP
index c7afdfb8e92f3700d51052f1591fdfbf90bdac94..f64f94b96f04f3bb2e4213adf79c6f001c4a6011 100644 (file)
@@ -7,7 +7,9 @@
 
 #include "ggml-backend-impl.h"
 #include "ggml-impl.h"
+#include "ggml-webgpu-shader-lib.hpp"
 #include "ggml-wgsl-shaders.hpp"
+#include "pre_wgsl.hpp"
 
 #ifdef __EMSCRIPTEN__
 #    include <emscripten/emscripten.h>
@@ -30,7 +32,7 @@
 
 #ifdef GGML_WEBGPU_DEBUG
 #    define WEBGPU_LOG_DEBUG(msg)  std::cout << msg << std::endl
-#    define WEBGPU_DEBUG_BUF_ELEMS 32
+#    define WEBGPU_DEBUG_BUF_ELEMS 512
 #else
 #    define WEBGPU_LOG_DEBUG(msg) ((void) 0)
 #endif  // GGML_WEBGPU_DEBUG
@@ -251,6 +253,7 @@ struct webgpu_gpu_profile_buf_pool {
 struct webgpu_pipeline {
     wgpu::ComputePipeline pipeline;
     std::string           name;
+    void *                context = nullptr;
 };
 
 struct webgpu_command {
@@ -263,6 +266,46 @@ struct webgpu_command {
 #endif
 };
 
+struct flash_attn_pipeline_key {
+    int      q_type;
+    int      kv_type;
+    int      dst_type;
+    uint32_t head_dim_qk;
+    uint32_t head_dim_v;
+    bool     kv_direct;
+    bool     has_mask;
+    bool     has_sinks;
+    bool     uses_logit_softcap;
+
+    bool operator==(const flash_attn_pipeline_key & other) const {
+        return q_type == other.q_type && kv_type == other.kv_type && dst_type == other.dst_type &&
+               head_dim_qk == other.head_dim_qk && head_dim_v == other.head_dim_v && kv_direct == other.kv_direct &&
+               has_mask == other.has_mask && has_sinks == other.has_sinks &&
+               uses_logit_softcap == other.uses_logit_softcap;
+    }
+};
+
+// Same hash combine function as in boost
+template <typename T> inline void ggml_webgpu_hash_combine(size_t & seed, const T & value) {
+    seed ^= std::hash<T>{}(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
+}
+
+struct flash_attn_pipeline_key_hash {
+    size_t operator()(const flash_attn_pipeline_key & key) const {
+        size_t seed = 0;
+        ggml_webgpu_hash_combine(seed, key.q_type);
+        ggml_webgpu_hash_combine(seed, key.kv_type);
+        ggml_webgpu_hash_combine(seed, key.dst_type);
+        ggml_webgpu_hash_combine(seed, key.head_dim_qk);
+        ggml_webgpu_hash_combine(seed, key.head_dim_v);
+        ggml_webgpu_hash_combine(seed, key.kv_direct);
+        ggml_webgpu_hash_combine(seed, key.has_mask);
+        ggml_webgpu_hash_combine(seed, key.has_sinks);
+        ggml_webgpu_hash_combine(seed, key.uses_logit_softcap);
+        return seed;
+    }
+};
+
 // All the base objects needed to run operations on a WebGPU device
 struct webgpu_context_struct {
     wgpu::Instance instance;
@@ -271,12 +314,12 @@ struct webgpu_context_struct {
     wgpu::Queue    queue;
     wgpu::Limits   limits;
 
-    uint32_t subgroup_size;
+    uint32_t max_subgroup_size;
 
-#ifndef __EMSCRIPTEN__
-    bool                       supports_subgroup_matrix = false;
-    wgpu::SubgroupMatrixConfig subgroup_matrix_config;
-#endif
+    bool     supports_subgroup_matrix = false;
+    uint32_t sg_mat_m;
+    uint32_t sg_mat_n;
+    uint32_t sg_mat_k;
 
     std::recursive_mutex mutex;
     std::atomic_uint     inflight_threads = 0;
@@ -284,20 +327,24 @@ struct webgpu_context_struct {
     webgpu_buf_pool param_buf_pool;
     webgpu_buf_pool set_rows_error_buf_pool;
 
+    pre_wgsl::Preprocessor p;
+
     std::map<int, webgpu_pipeline> memset_pipelines;                                 // variant or type index
 
     std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> mul_mat_pipelines;  // src0_type, src1_type, vectorized
     std::map<int, std::map<int, std::map<int, webgpu_pipeline>>>
         mul_mat_vec_pipelines;                                                       // src0_type, src1_type, vectorized
 
-    std::map<int, std::map<int, webgpu_pipeline>> set_rows_pipelines;                // dst_type, vectorized
-    std::map<int, std::map<int, webgpu_pipeline>> get_rows_pipelines;                // src_type, vectorized
+    std::unordered_map<flash_attn_pipeline_key, webgpu_pipeline, flash_attn_pipeline_key_hash> flash_attn_pipelines;
+
+    std::map<int, std::map<int, webgpu_pipeline>> set_rows_pipelines;                 // dst_type, vectorized
+    std::map<int, std::map<int, webgpu_pipeline>> get_rows_pipelines;                 // src_type, vectorized
 
-    std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines;                     // src_type, dst_type
-    std::map<int, std::map<int, webgpu_pipeline>> add_pipelines;                     // type, inplace
-    std::map<int, std::map<int, webgpu_pipeline>> sub_pipelines;                     // type, inplace
-    std::map<int, std::map<int, webgpu_pipeline>> mul_pipelines;                     // type, inplace
-    std::map<int, std::map<int, webgpu_pipeline>> div_pipelines;                     // type, inplace
+    std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines;                      // src_type, dst_type
+    std::map<int, std::map<int, webgpu_pipeline>> add_pipelines;                      // type, inplace
+    std::map<int, std::map<int, webgpu_pipeline>> sub_pipelines;                      // type, inplace
+    std::map<int, std::map<int, webgpu_pipeline>> mul_pipelines;                      // type, inplace
+    std::map<int, std::map<int, webgpu_pipeline>> div_pipelines;                      // type, inplace
 
     std::map<int, webgpu_pipeline>                               rms_norm_pipelines;  // inplace
     std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> rope_pipelines;      // type, ff, inplace
@@ -361,8 +408,6 @@ struct ggml_backend_webgpu_buffer_context {
         label(std::move(lbl)) {}
 };
 
-/* End struct definitions */
-
 /* WebGPU object initializations */
 
 // Process a WGSL shader string, replacing tokens of the form {{KEY}} with
@@ -484,14 +529,9 @@ static void ggml_backend_webgpu_debug(webgpu_context & ctx) {
     encoder.CopyBufferToBuffer(ctx->debug_dev_buf, 0, ctx->debug_host_buf, 0, ctx->debug_host_buf.GetSize());
     wgpu::CommandBuffer commands = encoder.Finish();
     ctx->queue.Submit(1, &commands);
-
     ggml_backend_webgpu_map_buffer(ctx, ctx->debug_host_buf, wgpu::MapMode::Read, 0, ctx->debug_host_buf.GetSize());
-    const uint32_t * debug_data = (const uint32_t *) ctx->debug_host_buf.GetConstMappedRange();
-    std::cout << "debug data:";
-    for (size_t i = 0; i < WEBGPU_DEBUG_BUF_ELEMS; i++) {
-        std::cout << "  " << i << ": " << debug_data[i];
-    }
-    std::cout << "\n";
+    const float * debug_data = (const float *) ctx->debug_host_buf.GetConstMappedRange();
+    std::cout << "debug[0]: " << debug_data[0] << "\n";
     ctx->debug_host_buf.Unmap();
 }
 #endif
@@ -673,6 +713,7 @@ static const char * ggml_backend_webgpu_name(ggml_backend_t backend) {
     return ctx->name.c_str();
 }
 
+// TODO: implement proper cleanup
 static void ggml_backend_webgpu_free(ggml_backend_t backend) {
     ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *) backend->context;
     WEBGPU_LOG_DEBUG("ggml_backend_webgpu_free(" << ctx->name << ")");
@@ -730,12 +771,12 @@ static wgpu::Buffer ggml_webgpu_tensor_buf(const ggml_tensor * tensor) {
     return ctx->buffer;
 }
 
-static size_t ggml_webgpu_tensor_misalignment(webgpu_context & ctx, ggml_tensor * t) {
+static size_t ggml_webgpu_tensor_misalignment(webgpu_context & ctx, const ggml_tensor * t) {
     size_t offset = ggml_webgpu_tensor_offset(t);
     return offset & (ctx->limits.minStorageBufferOffsetAlignment - 1);
 }
 
-static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, ggml_tensor * t) {
+static size_t ggml_webgpu_tensor_align_offset(webgpu_context & ctx, const ggml_tensor * t) {
     size_t offset = ggml_webgpu_tensor_offset(t);
     return offset & ~(ctx->limits.minStorageBufferOffsetAlignment - 1);
 }
@@ -964,12 +1005,10 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
 #ifndef __EMSCRIPTEN__
             if (ctx->supports_subgroup_matrix) {
                 // The total number of subgroups/workgroups needed per matrix.
-                uint32_t wg_m_sg_tile =
-                    WEBGPU_MUL_MAT_SUBGROUP_M * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M * ctx->subgroup_matrix_config.M;
-                wg_m = CEIL_DIV(dst->ne[0], wg_m_sg_tile);
-                uint32_t wg_n_sg_tile =
-                    WEBGPU_MUL_MAT_SUBGROUP_N * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N * ctx->subgroup_matrix_config.N;
-                wg_n = CEIL_DIV(dst->ne[1], wg_n_sg_tile);
+                uint32_t wg_m_sg_tile = WEBGPU_MUL_MAT_SUBGROUP_M * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M * ctx->sg_mat_m;
+                wg_m                  = CEIL_DIV(dst->ne[0], wg_m_sg_tile);
+                uint32_t wg_n_sg_tile = WEBGPU_MUL_MAT_SUBGROUP_N * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N * ctx->sg_mat_n;
+                wg_n                  = CEIL_DIV(dst->ne[1], wg_n_sg_tile);
             } else {
 #endif
                 uint32_t tile_m_s = WEBGPU_MUL_MAT_TILE_M * WEBGPU_MUL_MAT_WG_SIZE_M;
@@ -986,6 +1025,146 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx,
     return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x, wg_y);
 }
 
+static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx,
+                                             ggml_tensor *    Q,
+                                             ggml_tensor *    K,
+                                             ggml_tensor *    V,
+                                             ggml_tensor *    mask,
+                                             ggml_tensor *    sinks,
+                                             ggml_tensor *    dst) {
+    float scale = *(float *) dst->op_params;
+    float max_bias;
+    memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
+    float logit_softcap;
+    memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float));
+    if (logit_softcap != 0.0f) {
+        scale /= logit_softcap;
+    }
+    float n_head_log2 = float(1u << (uint32_t) floor(log2(Q->ne[2])));
+    float m0          = powf(2.0f, -(max_bias) / n_head_log2);
+    float m1          = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+    const int has_mask  = (mask != nullptr);
+    const int has_sinks = (sinks != nullptr);
+
+    std::vector<uint32_t> params = {
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, Q) / ggml_type_size(Q->type)),
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, K) / ggml_type_size(K->type)),
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, V) / ggml_type_size(V->type)),
+        has_mask ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, mask) / ggml_type_size(mask->type)) : 0,
+        has_sinks ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, sinks) / ggml_type_size(sinks->type)) : 0,
+        (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
+        (uint32_t) Q->ne[2],                              // number of heads
+        (uint32_t) Q->ne[1],                              // sequence length (Q)
+        (uint32_t) K->ne[1],                              // sequence length (K/V)
+        (uint32_t) (Q->nb[1] / ggml_type_size(Q->type)),  // stride (elements/blocks) of Q in dimension 1
+        (uint32_t) (Q->nb[2] / ggml_type_size(Q->type)),  // stride (elements/blocks) of Q in dimension 2
+        (uint32_t) (Q->nb[3] / ggml_type_size(Q->type)),  // stride (elements/blocks) of Q in dimension 3
+        (uint32_t) (K->nb[1] / ggml_type_size(K->type)),  // stride (elements/blocks) of K in dimension 1
+        (uint32_t) (K->nb[2] / ggml_type_size(K->type)),  // stride (elements/blocks) of K in dimension 2
+        (uint32_t) (K->nb[3] / ggml_type_size(K->type)),  // stride (elements/blocks) of K in dimension 3
+        (uint32_t) (V->nb[1] / ggml_type_size(V->type)),  // stride (elements/blocks) of V in dimension 1
+        (uint32_t) (V->nb[2] / ggml_type_size(V->type)),  // stride (elements/blocks) of V in dimension 2
+        (uint32_t) (V->nb[3] / ggml_type_size(V->type)),  // stride (elements/blocks) of V in dimension 3
+        has_mask ? (uint32_t) (mask->nb[3] / ggml_type_size(mask->type)) : 0,  // stride of mask dim 3
+        (uint32_t) (Q->ne[2] / K->ne[2]),  // repeat factor for K/V in dim 2 (MHA/MQA/GQA)
+        *(uint32_t *) &scale,              // scale (possibly adjusted for logit softcap)
+        *(uint32_t *) &max_bias,
+        *(uint32_t *) &logit_softcap,
+        *(uint32_t *) &n_head_log2,
+        *(uint32_t *) &m0,
+        *(uint32_t *) &m1
+
+    };
+    std::vector<wgpu::BindGroupEntry> entries = {
+        { .binding = 0,
+         .buffer  = ggml_webgpu_tensor_buf(Q),
+         .offset  = ggml_webgpu_tensor_align_offset(ctx, Q),
+         .size    = ggml_webgpu_tensor_binding_size(ctx, Q) },
+        { .binding = 1,
+         .buffer  = ggml_webgpu_tensor_buf(K),
+         .offset  = ggml_webgpu_tensor_align_offset(ctx, K),
+         .size    = ggml_webgpu_tensor_binding_size(ctx, K) },
+        { .binding = 2,
+         .buffer  = ggml_webgpu_tensor_buf(V),
+         .offset  = ggml_webgpu_tensor_align_offset(ctx, V),
+         .size    = ggml_webgpu_tensor_binding_size(ctx, V) }
+    };
+    uint32_t binding_index = 3;
+    if (has_mask) {
+        entries.push_back({ .binding = binding_index++,
+                            .buffer  = ggml_webgpu_tensor_buf(mask),
+                            .offset  = ggml_webgpu_tensor_align_offset(ctx, mask),
+                            .size    = ggml_webgpu_tensor_binding_size(ctx, mask) });
+    }
+    if (has_sinks) {
+        entries.push_back({ .binding = binding_index++,
+                            .buffer  = ggml_webgpu_tensor_buf(sinks),
+                            .offset  = ggml_webgpu_tensor_align_offset(ctx, sinks),
+                            .size    = ggml_webgpu_tensor_binding_size(ctx, sinks) });
+    }
+    entries.push_back({ .binding = binding_index++,
+                        .buffer  = ggml_webgpu_tensor_buf(dst),
+                        .offset  = ggml_webgpu_tensor_align_offset(ctx, dst),
+                        .size    = ggml_webgpu_tensor_binding_size(ctx, dst) });
+
+    bool kv_direct =
+        (K->type == GGML_TYPE_F16) && (Q->ne[0] % ctx->sg_mat_k == 0) && (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0);
+
+    flash_attn_pipeline_key key = {
+        .q_type             = Q->type,
+        .kv_type            = K->type,
+        .dst_type           = dst->type,
+        .head_dim_qk        = (uint32_t) Q->ne[0],
+        .head_dim_v         = (uint32_t) V->ne[0],
+        .kv_direct          = kv_direct,
+        .has_mask           = static_cast<bool>(has_mask),
+        .has_sinks          = static_cast<bool>(has_sinks),
+        .uses_logit_softcap = logit_softcap != 0.0f,
+    };
+
+    webgpu_pipeline                         pipeline;
+    ggml_webgpu_flash_attn_shader_decisions decisions = {};
+
+    auto it = ctx->flash_attn_pipelines.find(key);
+    if (it != ctx->flash_attn_pipelines.end()) {
+        pipeline  = it->second;
+        decisions = *static_cast<ggml_webgpu_flash_attn_shader_decisions *>(pipeline.context);
+    } else {
+        std::lock_guard<std::recursive_mutex> lock(ctx->mutex);
+        it = ctx->flash_attn_pipelines.find(key);
+        if (it != ctx->flash_attn_pipelines.end()) {
+            pipeline  = it->second;
+            decisions = *static_cast<ggml_webgpu_flash_attn_shader_decisions *>(pipeline.context);
+        } else {
+            ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = { .kv_type     = K->type,
+                                                                         .head_dim_qk = (uint32_t) Q->ne[0],
+                                                                         .head_dim_v  = (uint32_t) V->ne[0],
+                                                                         .kv_direct   = kv_direct,
+                                                                         .has_mask    = static_cast<bool>(has_mask),
+                                                                         .has_sinks   = static_cast<bool>(has_sinks),
+                                                                         .uses_logit_softcap = logit_softcap != 0.0f,
+                                                                         .sg_mat_m           = ctx->sg_mat_m,
+                                                                         .sg_mat_n           = ctx->sg_mat_n,
+                                                                         .sg_mat_k           = ctx->sg_mat_k,
+                                                                         .wg_mem_limit_bytes =
+                                                                             ctx->limits.maxComputeWorkgroupStorageSize,
+                                                                         .max_subgroup_size = ctx->max_subgroup_size };
+
+            ggml_webgpu_processed_shader processed =
+                ggml_webgpu_preprocess_flash_attn_shader(ctx->p, wgsl_flash_attn, shader_lib_ctx);
+            pipeline = ggml_webgpu_create_pipeline(ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
+            pipeline.context = new ggml_webgpu_flash_attn_shader_decisions(processed.decisions);
+            ctx->flash_attn_pipelines.emplace(key, pipeline);
+            decisions = processed.decisions;
+        }
+    }
+
+    uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions.q_tile);
+    uint32_t wg_x        = wg_per_head * Q->ne[2] * Q->ne[3];  // wg per head * number of heads * number of batches
+    return ggml_backend_webgpu_build(ctx, pipeline, params, entries, wg_x);
+}
+
 static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
     uint32_t      ne       = (uint32_t) ggml_nelements(dst);
     ggml_unary_op unary_op = ggml_get_unary_op(dst);
@@ -1397,6 +1576,8 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
             return ggml_webgpu_get_rows(ctx, src0, src1, node);
         case GGML_OP_MUL_MAT:
             return ggml_webgpu_mul_mat(ctx, src0, src1, node);
+        case GGML_OP_FLASH_ATTN_EXT:
+            return ggml_webgpu_flash_attn(ctx, src0, src1, src2, node->src[3], node->src[4], node);
         case GGML_OP_ADD:
             {
                 int inplace = ggml_webgpu_tensor_equal(src0, node);
@@ -1466,6 +1647,7 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str
         webgpu_submission_futures new_futures = ggml_backend_webgpu_submit(ctx, commands);
         futures.push_back(new_futures);
     }
+
     ggml_backend_webgpu_wait(ctx, futures);
     ctx->inflight_threads--;
     WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx);
@@ -1808,15 +1990,15 @@ static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) {
 #ifndef __EMSCRIPTEN__
     if (webgpu_ctx->supports_subgroup_matrix) {
         std::map<std::string, std::string> sg_matrix_repls;
-        sg_matrix_repls["WEBGPU_MAX_SUBGROUP_SIZE"] = std::to_string(webgpu_ctx->subgroup_size);
+        sg_matrix_repls["WEBGPU_MAX_SUBGROUP_SIZE"] = std::to_string(webgpu_ctx->max_subgroup_size);
         sg_matrix_repls["WEBGPU_TILE_K"]            = std::to_string(WEBGPU_MUL_MAT_TILE_K);
         sg_matrix_repls["WEBGPU_SUBGROUP_M"]        = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M);
         sg_matrix_repls["WEBGPU_SUBGROUP_N"]        = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N);
         sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M);
         sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N);
-        sg_matrix_repls["WEBGPU_SG_MAT_M_SIZE"]     = std::to_string(webgpu_ctx->subgroup_matrix_config.M);
-        sg_matrix_repls["WEBGPU_SG_MAT_N_SIZE"]     = std::to_string(webgpu_ctx->subgroup_matrix_config.N);
-        sg_matrix_repls["WEBGPU_SG_MAT_K_SIZE"]     = std::to_string(webgpu_ctx->subgroup_matrix_config.K);
+        sg_matrix_repls["WEBGPU_SG_MAT_M_SIZE"]     = std::to_string(webgpu_ctx->sg_mat_m);
+        sg_matrix_repls["WEBGPU_SG_MAT_N_SIZE"]     = std::to_string(webgpu_ctx->sg_mat_n);
+        sg_matrix_repls["WEBGPU_SG_MAT_K_SIZE"]     = std::to_string(webgpu_ctx->sg_mat_k);
 
         proc_mul_mat_f32_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32, sg_matrix_repls);
         proc_mul_mat_f32_f32_vec =
@@ -2328,6 +2510,7 @@ static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) {
         webgpu_ctx->device, wgsl_soft_max_f32_mask_f16_sink_inplace, "soft_max_f32_mask_f16_sink_inplace", constants);
 }
 
+// TODO: move most initialization logic here
 static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) {
     GGML_UNUSED(params);
 
@@ -2489,6 +2672,29 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
                 }
                 break;
             }
+        case GGML_OP_FLASH_ATTN_EXT:
+            {
+                if (!webgpu_ctx->supports_subgroup_matrix) {
+                    break;
+                }
+                // Head dimensions must fit in workgroup memory with minimum tile sizes
+                size_t     limit_bytes = webgpu_ctx->limits.maxComputeWorkgroupStorageSize;
+                const bool has_mask    = op->src[3] != nullptr;
+                const bool kv_direct   = src1->type == GGML_TYPE_F16 && (src0->ne[0] % webgpu_ctx->sg_mat_k) == 0 &&
+                                       (src1->ne[1] % GGML_WEBGPU_KV_SEQ_PAD) == 0;
+                const size_t min_bytes = ggml_webgpu_flash_attn_wg_mem_bytes(
+                    webgpu_ctx->sg_mat_m, webgpu_ctx->sg_mat_n, (uint32_t) src0->ne[0], (uint32_t) src2->ne[0],
+                    has_mask, kv_direct);
+                if (min_bytes > limit_bytes) {
+                    break;
+                }
+
+                supports_op = src0->type == GGML_TYPE_F32 &&
+                              (src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16 ||
+                               src1->type == GGML_TYPE_Q4_0 || src1->type == GGML_TYPE_Q8_0) &&
+                              src2->type == src1->type && op->type == GGML_TYPE_F32;
+                break;
+            }
         case GGML_OP_RMS_NORM:
             supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32;
             break;
@@ -2606,6 +2812,7 @@ static size_t ggml_backend_webgpu_reg_get_device_count(ggml_backend_reg_t reg) {
 }
 
 // TODO: Does this need to be thread safe? Is it only called once?
+// TODO: move most logic to device_init function so backend can be freed/initialized properly
 // Only one device is supported for now
 static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t reg, size_t index) {
     GGML_ASSERT(index == 0);
@@ -2665,7 +2872,9 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
             if (config.M == config.N && config.N == config.K && (config.K == 8 || config.K == 16) &&
                 config.componentType == wgpu::SubgroupMatrixComponentType::F16 &&
                 config.resultComponentType == wgpu::SubgroupMatrixComponentType::F16) {
-                ctx->subgroup_matrix_config  = config;
+                ctx->sg_mat_m                = config.M;
+                ctx->sg_mat_n                = config.N;
+                ctx->sg_mat_k                = config.K;
                 valid_subgroup_matrix_config = true;
                 break;
             }
@@ -2676,7 +2885,7 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
 #endif
     // For subgroup matrix code to be the most efficient, we would like the subgroup size to be consistent and accurate.
     // Unfortunately, that is not possible, so we use the maximum subgroup size reported by the adapter.
-    ctx->subgroup_size = info.subgroupMaxSize;
+    ctx->max_subgroup_size = info.subgroupMaxSize;
 
     // Initialize device
     std::vector<wgpu::FeatureName> required_features = { wgpu::FeatureName::ShaderF16 };
@@ -2701,8 +2910,11 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
         wgpu::CallbackMode::AllowSpontaneous,
         [](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
             GGML_UNUSED(device);
-            GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason),
-                           std::string(message).c_str());
+            GGML_UNUSED(reason);
+            GGML_UNUSED(message);
+            //TODO: uncomment once proper free logic is in place
+            //GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason),
+            //std::string(message).c_str());
         });
     dev_desc.SetUncapturedErrorCallback(
         [](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) {
diff --git a/ggml/src/ggml-webgpu/pre_wgsl.hpp b/ggml/src/ggml-webgpu/pre_wgsl.hpp
new file mode 100644 (file)
index 0000000..4d43594
--- /dev/null
@@ -0,0 +1,778 @@
+#ifndef PRE_WGSL_HPP
+#define PRE_WGSL_HPP
+
+#include <cctype>
+#include <fstream>
+#include <sstream>
+#include <stdexcept>
+#include <string>
+#include <string_view>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+namespace pre_wgsl {
+
+//==============================================================
+// Options
+//==============================================================
+struct Options {
+    std::string              include_path = ".";
+    std::vector<std::string> macros;
+};
+
+//==============================================================
+// Utility: trim
+//==============================================================
+static std::string trim(const std::string & s) {
+    size_t a = 0;
+    while (a < s.size() && std::isspace((unsigned char) s[a])) {
+        a++;
+    }
+    size_t b = s.size();
+    while (b > a && std::isspace((unsigned char) s[b - 1])) {
+        b--;
+    }
+    return s.substr(a, b - a);
+}
+
+static std::string trim_value(std::istream & is) {
+    std::string str;
+    std::getline(is, str);
+    return trim(str);
+}
+
+static bool isIdentChar(char c) {
+    return std::isalnum(static_cast<unsigned char>(c)) || c == '_';
+}
+
+static std::string expandMacrosRecursiveInternal(const std::string &                                  line,
+                                                 const std::unordered_map<std::string, std::string> & macros,
+                                                 std::unordered_set<std::string> &                    visiting);
+
+static std::string expandMacroValue(const std::string &                                  name,
+                                    const std::unordered_map<std::string, std::string> & macros,
+                                    std::unordered_set<std::string> &                    visiting) {
+    if (visiting.count(name)) {
+        throw std::runtime_error("Recursive macro: " + name);
+    }
+    visiting.insert(name);
+
+    auto it = macros.find(name);
+    if (it == macros.end()) {
+        visiting.erase(name);
+        return name;
+    }
+
+    const std::string & value = it->second;
+    if (value.empty()) {
+        visiting.erase(name);
+        return "";
+    }
+
+    std::string expanded = expandMacrosRecursiveInternal(value, macros, visiting);
+    visiting.erase(name);
+    return expanded;
+}
+
+static std::string expandMacrosRecursiveInternal(const std::string &                                  line,
+                                                 const std::unordered_map<std::string, std::string> & macros,
+                                                 std::unordered_set<std::string> &                    visiting) {
+    std::string result;
+    result.reserve(line.size());
+
+    size_t i = 0;
+    while (i < line.size()) {
+        if (isIdentChar(line[i])) {
+            size_t start = i;
+            while (i < line.size() && isIdentChar(line[i])) {
+                i++;
+            }
+            std::string token = line.substr(start, i - start);
+
+            auto it = macros.find(token);
+            if (it != macros.end()) {
+                result += expandMacroValue(token, macros, visiting);
+            } else {
+                result += token;
+            }
+        } else {
+            result += line[i];
+            i++;
+        }
+    }
+
+    return result;
+}
+
+static std::string expandMacrosRecursive(const std::string &                                  line,
+                                         const std::unordered_map<std::string, std::string> & macros) {
+    std::unordered_set<std::string> visiting;
+    return expandMacrosRecursiveInternal(line, macros, visiting);
+}
+
+//==============================================================
+// Tokenizer for expressions in #if/#elif
+//==============================================================
+class ExprLexer {
+  public:
+    enum Kind { END, IDENT, NUMBER, OP, LPAREN, RPAREN };
+
+    struct Tok {
+        Kind        kind;
+        std::string text;
+    };
+
+    explicit ExprLexer(std::string_view sv) : src(sv), pos(0) {}
+
+    Tok next() {
+        skipWS();
+        if (pos >= src.size()) {
+            return { END, "" };
+        }
+
+        char c = src[pos];
+
+        // number
+        if (std::isdigit((unsigned char) c)) {
+            size_t start = pos;
+            while (pos < src.size() && std::isdigit((unsigned char) src[pos])) {
+                pos++;
+            }
+            return { NUMBER, std::string(src.substr(start, pos - start)) };
+        }
+
+        // identifier
+        if (std::isalpha((unsigned char) c) || c == '_') {
+            size_t start = pos;
+            while (pos < src.size() && (std::isalnum((unsigned char) src[pos]) || src[pos] == '_')) {
+                pos++;
+            }
+            return { IDENT, std::string(src.substr(start, pos - start)) };
+        }
+
+        if (c == '(') {
+            pos++;
+            return { LPAREN, "(" };
+        }
+        if (c == ')') {
+            pos++;
+            return { RPAREN, ")" };
+        }
+
+        // multi-char operators
+        static const char * two_ops[] = { "==", "!=", "<=", ">=", "&&", "||", "<<", ">>" };
+        for (auto op : two_ops) {
+            if (src.substr(pos, 2) == op) {
+                pos += 2;
+                return { OP, std::string(op) };
+            }
+        }
+
+        // single-char operators
+        if (std::string("+-*/%<>!").find(c) != std::string::npos) {
+            pos++;
+            return { OP, std::string(1, c) };
+        }
+
+        // unexpected
+        pos++;
+        return { END, "" };
+    }
+
+  private:
+    std::string_view src;
+    size_t           pos;
+
+    void skipWS() {
+        while (pos < src.size() && std::isspace((unsigned char) src[pos])) {
+            pos++;
+        }
+    }
+};
+
+//==============================================================
+// Expression Parser (recursive descent)
+//==============================================================
+class ExprParser {
+  public:
+    ExprParser(std::string_view                                     expr,
+               const std::unordered_map<std::string, std::string> & macros,
+               std::unordered_set<std::string> &                    visiting) :
+        lex(expr),
+        macros(macros),
+        visiting(visiting) {
+        advance();
+    }
+
+    int parse() { return parseLogicalOr(); }
+
+  private:
+    ExprLexer                                            lex;
+    ExprLexer::Tok                                       tok;
+    const std::unordered_map<std::string, std::string> & macros;
+    std::unordered_set<std::string> &                    visiting;
+
+    void advance() { tok = lex.next(); }
+
+    bool acceptOp(const std::string & s) {
+        if (tok.kind == ExprLexer::OP && tok.text == s) {
+            advance();
+            return true;
+        }
+        return false;
+    }
+
+    bool acceptKind(ExprLexer::Kind k) {
+        if (tok.kind == k) {
+            advance();
+            return true;
+        }
+        return false;
+    }
+
+    int parseLogicalOr() {
+        int v = parseLogicalAnd();
+        while (acceptOp("||")) {
+            int rhs = parseLogicalAnd();
+            v       = (v || rhs);
+        }
+        return v;
+    }
+
+    int parseLogicalAnd() {
+        int v = parseEquality();
+        while (acceptOp("&&")) {
+            int rhs = parseEquality();
+            v       = (v && rhs);
+        }
+        return v;
+    }
+
+    int parseEquality() {
+        int v = parseRelational();
+        for (;;) {
+            if (acceptOp("==")) {
+                int rhs = parseRelational();
+                v       = (v == rhs);
+            } else if (acceptOp("!=")) {
+                int rhs = parseRelational();
+                v       = (v != rhs);
+            } else {
+                break;
+            }
+        }
+        return v;
+    }
+
+    int parseRelational() {
+        int v = parseShift();
+        for (;;) {
+            if (acceptOp("<")) {
+                int rhs = parseShift();
+                v       = (v < rhs);
+            } else if (acceptOp(">")) {
+                int rhs = parseShift();
+                v       = (v > rhs);
+            } else if (acceptOp("<=")) {
+                int rhs = parseShift();
+                v       = (v <= rhs);
+            } else if (acceptOp(">=")) {
+                int rhs = parseShift();
+                v       = (v >= rhs);
+            } else {
+                break;
+            }
+        }
+        return v;
+    }
+
+    int parseShift() {
+        int v = parseAdd();
+        for (;;) {
+            if (acceptOp("<<")) {
+                int rhs = parseAdd();
+                v       = (v << rhs);
+            } else if (acceptOp(">>")) {
+                int rhs = parseAdd();
+                v       = (v >> rhs);
+            } else {
+                break;
+            }
+        }
+        return v;
+    }
+
+    int parseAdd() {
+        int v = parseMult();
+        for (;;) {
+            if (acceptOp("+")) {
+                int rhs = parseMult();
+                v       = (v + rhs);
+            } else if (acceptOp("-")) {
+                int rhs = parseMult();
+                v       = (v - rhs);
+            } else {
+                break;
+            }
+        }
+        return v;
+    }
+
+    int parseMult() {
+        int v = parseUnary();
+        for (;;) {
+            if (acceptOp("*")) {
+                int rhs = parseUnary();
+                v       = (v * rhs);
+            } else if (acceptOp("/")) {
+                int rhs = parseUnary();
+                v       = (rhs == 0 ? 0 : v / rhs);
+            } else if (acceptOp("%")) {
+                int rhs = parseUnary();
+                v       = (rhs == 0 ? 0 : v % rhs);
+            } else {
+                break;
+            }
+        }
+        return v;
+    }
+
+    int parseUnary() {
+        if (acceptOp("!")) {
+            return !parseUnary();
+        }
+        if (acceptOp("-")) {
+            return -parseUnary();
+        }
+        if (acceptOp("+")) {
+            return +parseUnary();
+        }
+        return parsePrimary();
+    }
+
+    int parsePrimary() {
+        // '(' expr ')'
+        if (acceptKind(ExprLexer::LPAREN)) {
+            int v = parse();
+            if (!acceptKind(ExprLexer::RPAREN)) {
+                throw std::runtime_error("missing ')'");
+            }
+            return v;
+        }
+
+        // number
+        if (tok.kind == ExprLexer::NUMBER) {
+            int v = std::stoi(tok.text);
+            advance();
+            return v;
+        }
+
+        // defined(identifier)
+        if (tok.kind == ExprLexer::IDENT && tok.text == "defined") {
+            advance();
+            if (acceptKind(ExprLexer::LPAREN)) {
+                if (tok.kind != ExprLexer::IDENT) {
+                    throw std::runtime_error("expected identifier in defined()");
+                }
+                std::string name = tok.text;
+                advance();
+                if (!acceptKind(ExprLexer::RPAREN)) {
+                    throw std::runtime_error("missing ) in defined()");
+                }
+                return macros.count(name) ? 1 : 0;
+            } else {
+                // defined NAME
+                if (tok.kind != ExprLexer::IDENT) {
+                    throw std::runtime_error("expected identifier in defined NAME");
+                }
+                std::string name = tok.text;
+                advance();
+                return macros.count(name) ? 1 : 0;
+            }
+        }
+
+        // identifier -> treat as integer, if defined use its value else 0
+        if (tok.kind == ExprLexer::IDENT) {
+            std::string name = tok.text;
+            advance();
+            auto it = macros.find(name);
+            if (it == macros.end()) {
+                return 0;
+            }
+            if (it->second.empty()) {
+                return 1;
+            }
+            return evalMacroExpression(name, it->second);
+        }
+
+        // unexpected
+        return 0;
+    }
+
+    int evalMacroExpression(const std::string & name, const std::string & value) {
+        if (visiting.count(name)) {
+            throw std::runtime_error("Recursive macro: " + name);
+        }
+
+        visiting.insert(name);
+        ExprParser ep(value, macros, visiting);
+        int        v = ep.parse();
+        visiting.erase(name);
+        return v;
+    }
+};
+
+//==============================================================
+// Preprocessor
+//==============================================================
+class Preprocessor {
+  public:
+    explicit Preprocessor(Options opts = {}) : opts_(std::move(opts)) {
+        // Treat empty include path as current directory
+        if (opts_.include_path.empty()) {
+            opts_.include_path = ".";
+        }
+        parseMacroDefinitions(opts_.macros);
+    }
+
+    std::string preprocess_file(const std::string & filename, const std::vector<std::string> & additional_macros = {}) {
+        std::unordered_map<std::string, std::string> macros;
+        std::unordered_set<std::string>              predefined;
+        std::unordered_set<std::string>              include_stack;
+        buildMacros(additional_macros, macros, predefined);
+
+        std::string result = processFile(filename, macros, predefined, include_stack, DirectiveMode::All);
+        return result;
+    }
+
+    std::string preprocess(const std::string & contents, const std::vector<std::string> & additional_macros = {}) {
+        std::unordered_map<std::string, std::string> macros;
+        std::unordered_set<std::string>              predefined;
+        std::unordered_set<std::string>              include_stack;
+        buildMacros(additional_macros, macros, predefined);
+
+        std::string result = processString(contents, macros, predefined, include_stack, DirectiveMode::All);
+        return result;
+    }
+
+    std::string preprocess_includes_file(const std::string & filename) {
+        std::unordered_map<std::string, std::string> macros;
+        std::unordered_set<std::string>              predefined;
+        std::unordered_set<std::string>              include_stack;
+        std::string result = processFile(filename, macros, predefined, include_stack, DirectiveMode::IncludesOnly);
+        return result;
+    }
+
+    std::string preprocess_includes(const std::string & contents) {
+        std::unordered_map<std::string, std::string> macros;
+        std::unordered_set<std::string>              predefined;
+        std::unordered_set<std::string>              include_stack;
+        std::string result = processString(contents, macros, predefined, include_stack, DirectiveMode::IncludesOnly);
+        return result;
+    }
+
+  private:
+    Options                                      opts_;
+    std::unordered_map<std::string, std::string> global_macros;
+
+    enum class DirectiveMode { All, IncludesOnly };
+
+    struct Cond {
+        bool parent_active;
+        bool active;
+        bool taken;
+    };
+
+    //----------------------------------------------------------
+    // Parse macro definitions into global_macros
+    //----------------------------------------------------------
+    void parseMacroDefinitions(const std::vector<std::string> & macro_defs) {
+        for (const auto & def : macro_defs) {
+            size_t eq_pos = def.find('=');
+            if (eq_pos != std::string::npos) {
+                // Format: NAME=VALUE
+                std::string name    = trim(def.substr(0, eq_pos));
+                std::string value   = trim(def.substr(eq_pos + 1));
+                global_macros[name] = value;
+            } else {
+                // Format: NAME
+                std::string name    = trim(def);
+                global_macros[name] = "";
+            }
+        }
+    }
+
+    //----------------------------------------------------------
+    // Build combined macro map and predefined set for a preprocessing operation
+    //----------------------------------------------------------
+    void buildMacros(const std::vector<std::string> &               additional_macros,
+                     std::unordered_map<std::string, std::string> & macros,
+                     std::unordered_set<std::string> &              predefined) {
+        macros = global_macros;
+        predefined.clear();
+
+        for (const auto & [name, value] : global_macros) {
+            predefined.insert(name);
+        }
+
+        for (const auto & def : additional_macros) {
+            size_t      eq_pos = def.find('=');
+            std::string name, value;
+            if (eq_pos != std::string::npos) {
+                name  = trim(def.substr(0, eq_pos));
+                value = trim(def.substr(eq_pos + 1));
+            } else {
+                name  = trim(def);
+                value = "";
+            }
+
+            // Add to macros map (will override global if same name)
+            macros[name] = value;
+            predefined.insert(name);
+        }
+    }
+
+    //----------------------------------------------------------
+    // Helpers
+    //----------------------------------------------------------
+    std::string loadFile(const std::string & fname) {
+        std::ifstream f(fname);
+        if (!f.is_open()) {
+            throw std::runtime_error("Could not open file: " + fname);
+        }
+        std::stringstream ss;
+        ss << f.rdbuf();
+        return ss.str();
+    }
+
+    bool condActive(const std::vector<Cond> & cond) const {
+        if (cond.empty()) {
+            return true;
+        }
+        return cond.back().active;
+    }
+
+    //----------------------------------------------------------
+    // Process a file
+    //----------------------------------------------------------
+    std::string processFile(const std::string &                            name,
+                            std::unordered_map<std::string, std::string> & macros,
+                            const std::unordered_set<std::string> &        predefined_macros,
+                            std::unordered_set<std::string> &              include_stack,
+                            DirectiveMode                                  mode) {
+        if (include_stack.count(name)) {
+            throw std::runtime_error("Recursive include: " + name);
+        }
+
+        include_stack.insert(name);
+        std::string shader_code = loadFile(name);
+        std::string out         = processString(shader_code, macros, predefined_macros, include_stack, mode);
+        include_stack.erase(name);
+        return out;
+    }
+
+    std::string processIncludeFile(const std::string &                            fname,
+                                   std::unordered_map<std::string, std::string> & macros,
+                                   const std::unordered_set<std::string> &        predefined_macros,
+                                   std::unordered_set<std::string> &              include_stack,
+                                   DirectiveMode                                  mode) {
+        std::string full_path = opts_.include_path + "/" + fname;
+        return processFile(full_path, macros, predefined_macros, include_stack, mode);
+    }
+
+    //----------------------------------------------------------
+    // Process text
+    //----------------------------------------------------------
+    std::string processString(const std::string &                            shader_code,
+                              std::unordered_map<std::string, std::string> & macros,
+                              const std::unordered_set<std::string> &        predefined_macros,
+                              std::unordered_set<std::string> &              include_stack,
+                              DirectiveMode                                  mode) {
+        std::vector<Cond>  cond;  // Conditional stack for this shader
+        std::stringstream  out;
+        std::istringstream in(shader_code);
+        std::string        line;
+
+        while (std::getline(in, line)) {
+            std::string t = trim(line);
+
+            if (!t.empty() && t[0] == '#') {
+                bool handled = handleDirective(t, out, macros, predefined_macros, cond, include_stack, mode);
+                if (mode == DirectiveMode::IncludesOnly && !handled) {
+                    out << line << "\n";
+                }
+            } else {
+                if (mode == DirectiveMode::IncludesOnly) {
+                    out << line << "\n";
+                } else if (condActive(cond)) {
+                    // Expand macros in the line before outputting
+                    std::string expanded = expandMacrosRecursive(line, macros);
+                    out << expanded << "\n";
+                }
+            }
+        }
+
+        if (mode == DirectiveMode::All && !cond.empty()) {
+            throw std::runtime_error("Unclosed #if directive");
+        }
+
+        return out.str();
+    }
+
+    //----------------------------------------------------------
+    // Directive handler
+    //----------------------------------------------------------
+    bool handleDirective(const std::string &                            t,
+                         std::stringstream &                            out,
+                         std::unordered_map<std::string, std::string> & macros,
+                         const std::unordered_set<std::string> &        predefined_macros,
+                         std::vector<Cond> &                            cond,
+                         std::unordered_set<std::string> &              include_stack,
+                         DirectiveMode                                  mode) {
+        // split into tokens
+        std::string        body = t.substr(1);
+        std::istringstream iss(body);
+        std::string        cmd;
+        iss >> cmd;
+
+        if (cmd == "include") {
+            if (mode == DirectiveMode::All && !condActive(cond)) {
+                return true;
+            }
+            std::string file;
+            iss >> file;
+            if (file.size() >= 2 && file.front() == '"' && file.back() == '"') {
+                file = file.substr(1, file.size() - 2);
+            }
+            out << processIncludeFile(file, macros, predefined_macros, include_stack, mode);
+            return true;
+        }
+
+        if (mode == DirectiveMode::IncludesOnly) {
+            return false;
+        }
+
+        if (cmd == "define") {
+            if (!condActive(cond)) {
+                return true;
+            }
+            std::string name;
+            iss >> name;
+            // Don't override predefined macros from options
+            if (predefined_macros.count(name)) {
+                return true;
+            }
+            std::string value = trim_value(iss);
+            macros[name]      = value;
+            return true;
+        }
+
+        if (cmd == "undef") {
+            if (!condActive(cond)) {
+                return true;
+            }
+            std::string name;
+            iss >> name;
+            // Don't undef predefined macros from options
+            if (predefined_macros.count(name)) {
+                return true;
+            }
+            macros.erase(name);
+            return true;
+        }
+
+        if (cmd == "ifdef") {
+            std::string name;
+            iss >> name;
+            bool p = condActive(cond);
+            bool v = macros.count(name);
+            cond.push_back({ p, p && v, p && v });
+            return true;
+        }
+
+        if (cmd == "ifndef") {
+            std::string name;
+            iss >> name;
+            bool p = condActive(cond);
+            bool v = !macros.count(name);
+            cond.push_back({ p, p && v, p && v });
+            return true;
+        }
+
+        if (cmd == "if") {
+            std::string expr = trim_value(iss);
+            bool        p    = condActive(cond);
+            bool        v    = false;
+            if (p) {
+                std::unordered_set<std::string> visiting;
+                ExprParser                      ep(expr, macros, visiting);
+                v = ep.parse() != 0;
+            }
+            cond.push_back({ p, p && v, p && v });
+            return true;
+        }
+
+        if (cmd == "elif") {
+            std::string expr = trim_value(iss);
+
+            if (cond.empty()) {
+                throw std::runtime_error("#elif without #if");
+            }
+
+            Cond & c = cond.back();
+            if (!c.parent_active) {
+                c.active = false;
+                return true;
+            }
+
+            if (c.taken) {
+                c.active = false;
+                return true;
+            }
+
+            std::unordered_set<std::string> visiting;
+            ExprParser                      ep(expr, macros, visiting);
+            bool                            v = ep.parse() != 0;
+            c.active                          = v;
+            if (v) {
+                c.taken = true;
+            }
+            return true;
+        }
+
+        if (cmd == "else") {
+            if (cond.empty()) {
+                throw std::runtime_error("#else without #if");
+            }
+
+            Cond & c = cond.back();
+            if (!c.parent_active) {
+                c.active = false;
+                return true;
+            }
+            if (c.taken) {
+                c.active = false;
+            } else {
+                c.active = true;
+                c.taken  = true;
+            }
+            return true;
+        }
+
+        if (cmd == "endif") {
+            if (cond.empty()) {
+                throw std::runtime_error("#endif without #if");
+            }
+            cond.pop_back();
+            return true;
+        }
+
+        // Unknown directive
+        throw std::runtime_error("Unknown directive: #" + cmd);
+    }
+};
+
+}  // namespace pre_wgsl
+
+#endif  // PRE_WGSL_HPP
diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl
new file mode 100644 (file)
index 0000000..de7c132
--- /dev/null
@@ -0,0 +1,591 @@
+diagnostic(off, chromium.subgroup_matrix_uniformity);
+diagnostic(off, subgroup_uniformity);
+enable f16;
+enable subgroups;
+enable chromium_experimental_subgroup_matrix;
+
+#ifdef KV_F32
+#define KV_TYPE f32
+#else
+#define KV_TYPE f16
+#endif
+
+// Default values
+#define HEAD_DIM_QK 64
+#define HEAD_DIM_V 64
+
+// The number of rows/columns/k in a subgroup matrix. MxK * KxN = MxN
+// Note that the "K" here does not correspond to the K in attention's Q/K/V, it's just the common dimension.
+#define SG_MAT_M 8
+#define SG_MAT_N 8
+#define SG_MAT_K 8
+
+// Each workgroup processes one subgroup matrix of Q rows
+#define Q_TILE SG_MAT_M
+#define KV_TILE 16
+#define WG_SIZE 64
+
+// Number of subgroup-matrix-width blocks that span the KV tile. SG_MAT_N must divide KV_TILE.
+#define KV_BLOCKS (KV_TILE / SG_MAT_N)
+
+// Quantization constants/helpers
+#define BLOCK_SIZE 32
+#define BLOCKS_K ((HEAD_DIM_QK + BLOCK_SIZE - 1) / BLOCK_SIZE)
+#define BLOCKS_V ((HEAD_DIM_V + BLOCK_SIZE - 1) / BLOCK_SIZE)
+// number of quantized elements processed per thread
+#if defined(KV_Q4_0)
+#define NQ 16
+// Q4_0 has 32 elements, 1 f16 for scale, 8 f16 for 4-bit weights
+#define F16_PER_BLOCK 9
+#define WEIGHTS_PER_F16 4
+#elif defined(KV_Q8_0)
+#define NQ 8
+// Q8_0 has 32 elements, 1 f16 for scale, 16 f16 for 8-bit weights
+#define F16_PER_BLOCK 17
+#define WEIGHTS_PER_F16 2
+#endif
+#define F16_PER_THREAD (NQ / WEIGHTS_PER_F16)
+
+// Ok not to put these in a define block, compiler will remove if unused
+fn get_byte(value: u32, index: u32) -> u32 {
+    return (value >> (index * 8)) & 0xFF;
+}
+
+fn get_byte_i32(value: u32, index: u32) -> i32 {
+    return bitcast<i32>(((value >> (index * 8)) & 0xFF) << 24) >> 24;
+}
+
+struct Params {
+    offset_q: u32,
+    offset_k: u32,
+    offset_v: u32,
+    offset_mask: u32,
+    offset_sinks: u32,
+    offset_dst: u32,
+
+    // shapes of Q/K/V
+    n_heads: u32,
+    seq_len_q: u32,
+    seq_len_kv: u32,
+
+    // strides (in elements)
+    stride_q1: u32,
+    stride_q2: u32,
+    stride_q3: u32,
+    stride_k1: u32,
+    stride_k2: u32,
+    stride_k3: u32,
+    stride_v1: u32,
+    stride_v2: u32,
+    stride_v3: u32,
+    stride_mask3: u32,
+
+    // repeat factors for K/V, e.g., MHA vs. MQA vs. GQA
+    q_per_kv: u32,
+
+    // softmax params
+    scale: f32,
+    max_bias: f32,
+    logit_softcap: f32,
+    n_head_log2: f32,
+    m0: f32,
+    m1: f32,
+};
+
+@group(0) @binding(0) var<storage, read_write> Q: array<f32>;
+@group(0) @binding(1) var<storage, read_write> K: array<KV_TYPE>;
+@group(0) @binding(2) var<storage, read_write> V: array<KV_TYPE>;
+
+#if defined(MASK) && defined(SINKS)
+@group(0) @binding(3) var<storage, read_write> mask: array<f16>;
+@group(0) @binding(4) var<storage, read_write> sinks: array<f32>;
+#define DST_BINDING 5
+#define PARAMS_BINDING 6
+#elif defined(MASK)
+@group(0) @binding(3) var<storage, read_write> mask: array<f16>;
+#define DST_BINDING 4
+#define PARAMS_BINDING 5
+#elif defined(SINKS)
+@group(0) @binding(3) var<storage, read_write> sinks: array<f32>;
+#define DST_BINDING 4
+#define PARAMS_BINDING 5
+#else
+#define DST_BINDING 3
+#define PARAMS_BINDING 4
+#endif
+
+@group(0) @binding(DST_BINDING) var<storage, read_write> dst: array<f32>;
+@group(0) @binding(PARAMS_BINDING) var<uniform> params: Params;
+
+// Just a very small float value.
+const FLOAT_MIN: f32 = -1.0e9;
+
+// The number of Q rows processed per workgroup
+var<workgroup> q_shmem: array<f16, Q_TILE * HEAD_DIM_QK>;
+
+#ifndef KV_DIRECT
+const kv_shmem_size = KV_TILE * max(HEAD_DIM_QK, HEAD_DIM_V);
+// we can reuse the same shmem for K and V since we only need one at a time
+var<workgroup> kv_shmem: array<f16, kv_shmem_size>;
+#endif
+
+var<workgroup> o_shmem: array<f16, Q_TILE * HEAD_DIM_V>; // output shmem
+
+#ifdef MASK
+// storage for mask values
+var<workgroup> mask_shmem: array<f16, Q_TILE * KV_TILE>;
+#endif
+
+// storage for output of Q*K^T scores for online softmax (S matrix from paper)
+// also storage for diagonal matrix during online softmax (P matrix from paper)
+// note that we reuse the same storage for both since we only need one at a time
+var<workgroup> inter_shmem: array<f16, Q_TILE * KV_TILE>;
+
+// Storage for row max and exp sum during online softmax
+var<workgroup> row_max_shmem: array<f32, Q_TILE>;
+var<workgroup> exp_sum_shmem: array<f32, Q_TILE>;
+
+fn calc_softmax_term(kv_idx: u32, q_tile_row: u32, slope: f32) -> f32 {
+    var v = select(FLOAT_MIN,
+                   f32(inter_shmem[kv_idx + q_tile_row * KV_TILE]) * params.scale,
+                   kv_idx < KV_TILE);
+#ifdef LOGIT_SOFTCAP
+    v = params.logit_softcap * tanh(v);
+#endif
+#ifdef MASK
+    let mask_val = select(0.0, f32(mask_shmem[q_tile_row * KV_TILE + kv_idx]), kv_idx < KV_TILE);
+    let mask_term = slope * mask_val;
+    v += mask_term;
+#endif
+    return v;
+}
+
+
+@compute @workgroup_size(WG_SIZE)
+fn main(@builtin(workgroup_id) wg_id: vec3<u32>,
+        @builtin(local_invocation_id) local_id: vec3<u32>,
+        @builtin(subgroup_id) subgroup_id: u32,
+        @builtin(subgroup_size) subgroup_size: u32,
+        @builtin(num_subgroups) num_subgroups: u32,
+        @builtin(subgroup_invocation_id) sg_inv_id: u32) {
+
+    // initialize row max for online softmax
+    for (var i = local_id.x; i < Q_TILE; i += WG_SIZE) {
+        row_max_shmem[i] = FLOAT_MIN;
+        exp_sum_shmem[i] = 0.0;
+    }
+
+    for (var i = local_id.x; i < Q_TILE * HEAD_DIM_V; i += WG_SIZE) {
+        o_shmem[i] = 0.0;
+    }
+
+    // workgroups per head/batch
+    let wg_per_head = (params.seq_len_q + Q_TILE - 1u) / Q_TILE;
+    let wg_per_batch = wg_per_head * params.n_heads;
+
+    let dst2_stride = HEAD_DIM_V * params.n_heads;
+    let dst3_stride = dst2_stride * params.seq_len_q;
+
+    // batch index
+    let batch_idx = wg_id.x / wg_per_batch;
+    let q_batch_offset = params.offset_q + batch_idx * params.stride_q3;
+    let k_batch_offset = params.offset_k + batch_idx * params.stride_k3;
+    let v_batch_offset = params.offset_v + batch_idx * params.stride_v3;
+    let dst_batch_offset = params.offset_dst + batch_idx * dst3_stride;
+    let wg_in_batch = wg_id.x % wg_per_batch;
+
+    // head index
+    let head_idx = wg_in_batch / wg_per_head;
+    let q_head_offset = q_batch_offset + head_idx * params.stride_q2;
+    let k_head_idx = head_idx / params.q_per_kv;
+    let v_head_idx = k_head_idx;
+    let k_head_offset = k_batch_offset + k_head_idx * params.stride_k2;
+    let v_head_offset = v_batch_offset + v_head_idx * params.stride_v2;
+
+    // starting Q row for this workgroup
+    let wg_in_head = wg_in_batch % wg_per_head;
+    let q_row_start = wg_in_head * Q_TILE;
+
+#ifdef MASK
+    // mask offset
+    let mask_global_offset = params.offset_mask + batch_idx * params.stride_mask3 + q_row_start * params.seq_len_kv;
+#endif
+
+    // note that the output is permuted, the layout is [head_dim_v, n_heads, seq_len_q, batch_size]
+    let dst_global_offset = dst_batch_offset + q_row_start * dst2_stride + head_idx * HEAD_DIM_V;
+
+    let head = f32(head_idx);
+    let slope = select(1.0, select(pow(params.m1, 2.0 * (head - params.n_head_log2) + 1.0), pow(params.m0, head + 1.0), head < params.n_head_log2), params.max_bias > 0);
+
+    // load q tile into shared memory
+    for (var elem_idx = local_id.x; elem_idx < Q_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) {
+        let q_row = elem_idx / HEAD_DIM_QK;
+        let q_col = elem_idx % HEAD_DIM_QK;
+        let head_q_row = q_row_start + q_row;
+        let global_q_row_offset = q_head_offset + head_q_row * params.stride_q1;
+        q_shmem[elem_idx] = f16(select(
+            0.0,
+            Q[global_q_row_offset + q_col],
+            head_q_row < params.seq_len_q && q_col < HEAD_DIM_QK));
+    }
+
+    for (var kv_tile = 0u; kv_tile < params.seq_len_kv; kv_tile += KV_TILE) {
+      // clear inter_shmem to ensure zero-initialized accumulators
+      for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) {
+          inter_shmem[elem_idx] = 0.0;
+      }
+
+      // load k tile into shared memory
+#if defined(KV_Q4_0)
+      for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) {
+          let blck_idx = elem_idx / BLOCK_SIZE;
+          let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
+          let k_row = blck_idx / BLOCKS_K;
+          let global_k_row = kv_tile + k_row;
+          let block_k = blck_idx % BLOCKS_K;
+          let row_offset = k_row * HEAD_DIM_QK;
+
+          if (global_k_row < params.seq_len_kv) {
+              let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k;
+              let base_idx = global_block_idx * F16_PER_BLOCK;
+              let d = K[base_idx]; // scale
+              for (var j = 0u; j < F16_PER_THREAD; j += 2) {
+                  let q_0 = K[base_idx + 1u + block_offset + j];
+                  let q_1 = K[base_idx + 1u + block_offset + j + 1];
+                  let q_packed = bitcast<u32>(vec2(q_0, q_1));
+                  for (var k = 0u; k < 4u; k++) {
+                      let q_byte = get_byte(q_packed, k);
+                      let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
+                      let q_lo = (f16(q_byte & 0xF) - 8.0) * d;
+                      let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
+                      kv_shmem[row_offset + idx] = q_lo;
+                      kv_shmem[row_offset + idx + 16u] = q_hi;
+                  }
+              }
+          }
+      }
+#elif defined(KV_Q8_0)
+      for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE * NQ) {
+          let blck_idx = elem_idx / BLOCK_SIZE;
+          let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
+          let k_row = blck_idx / BLOCKS_K;
+          let global_k_row = kv_tile + k_row;
+          let block_k = blck_idx % BLOCKS_K;
+          let row_offset = k_row * HEAD_DIM_QK;
+
+          if (global_k_row < params.seq_len_kv) {
+              let global_block_idx = k_head_offset + global_k_row * params.stride_k1 + block_k;
+              let base_idx = global_block_idx * F16_PER_BLOCK;
+              let d = K[base_idx]; // scale
+              for (var j = 0u; j < F16_PER_THREAD; j += 2) {
+                  let q_0 = K[base_idx + 1u + block_offset + j];
+                  let q_1 = K[base_idx + 1u + block_offset + j + 1];
+                  let q_packed = bitcast<u32>(vec2(q_0, q_1));
+                  for (var k = 0u; k < 4u; k++) {
+                      let q_byte = get_byte_i32(q_packed, k);
+                      let q_val = f16(q_byte) * d;
+                      let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
+                      kv_shmem[row_offset + idx] = q_val;
+                  }
+              }
+          }
+      }
+#elif defined(KV_DIRECT)
+      // Direct global loads for KV
+#else
+      for (var elem_idx = local_id.x; elem_idx < KV_TILE * HEAD_DIM_QK; elem_idx += WG_SIZE) {
+          let k_row = elem_idx / HEAD_DIM_QK;
+          let k_col = elem_idx % HEAD_DIM_QK;
+          let global_k_row = kv_tile + k_row;
+          let global_k_row_offset = k_head_offset + global_k_row * params.stride_k1;
+          kv_shmem[elem_idx] = f16(select(
+              0.0,
+              K[global_k_row_offset + k_col],
+              global_k_row < params.seq_len_kv && k_col < HEAD_DIM_QK));
+      }
+#endif
+
+      workgroupBarrier();
+
+      // accumulate q block * k block into registers across the entire KV tile
+      // TODO: this loop seems to be the current largest bottleneck
+      for (var kv_block = subgroup_id; kv_block < KV_BLOCKS; kv_block += num_subgroups) {
+          let inter_offset = kv_block * SG_MAT_N;
+          var acc: subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N> = subgroupMatrixLoad<
+              subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N>>(&inter_shmem, inter_offset, false, KV_TILE);
+#ifdef KV_DIRECT
+          let k_block_row = kv_tile + kv_block * SG_MAT_N;
+          let k_global_offset = k_head_offset + k_block_row * params.stride_k1;
+#else
+          let k_block_offset = kv_block * SG_MAT_N * HEAD_DIM_QK;
+#endif
+          for (var head_dim_block = 0u; head_dim_block < HEAD_DIM_QK; head_dim_block += SG_MAT_K) {
+              // load q submatrix from shared memory
+              var q_sg_mat: subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K> = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(
+                  &q_shmem,
+                  head_dim_block,
+                  false,
+                  HEAD_DIM_QK
+              );
+
+              // load k submatrix from device or shared memory
+#ifdef KV_DIRECT
+              var k_sg_mat: subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(
+                  &K,
+                  k_global_offset + head_dim_block,
+                  true,
+                  params.stride_k1
+              );
+#else
+              var k_sg_mat: subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(
+                  &kv_shmem,
+                  k_block_offset + head_dim_block,
+                  true,
+                  HEAD_DIM_QK
+              );
+#endif
+              acc = subgroupMatrixMultiplyAccumulate(q_sg_mat, k_sg_mat, acc);
+          }
+
+          // store acc to shared memory for softmax (S matrix from paper)
+          subgroupMatrixStore(&inter_shmem, inter_offset, acc, false, KV_TILE);
+      }
+
+#ifdef MASK
+      // load mask tile into shared memory for this KV block
+      // TODO: optimize and skip if mask is -INF for the entire tile
+      for (var elem_idx = local_id.x; elem_idx < Q_TILE * KV_TILE; elem_idx += WG_SIZE) {
+          let mask_row = elem_idx / KV_TILE;
+          let mask_col = elem_idx % KV_TILE;
+          let global_q_row = q_row_start + mask_row;
+          let global_k_col = kv_tile + mask_col;
+          let mask_in_bounds = global_q_row < params.seq_len_q && global_k_col < params.seq_len_kv;
+          let mask_idx = mask_global_offset + mask_row * params.seq_len_kv + global_k_col;
+          mask_shmem[elem_idx] = select(0.0, mask[mask_idx], mask_in_bounds);
+      }
+#endif
+
+      workgroupBarrier();
+
+      // online softmax
+      for (var q_tile_row = subgroup_id; q_tile_row < Q_TILE; q_tile_row += num_subgroups) {
+          let global_q_row = q_row_start + q_tile_row;
+          if (global_q_row >= params.seq_len_q) {
+              break;
+          }
+
+          // initialize running max for this row
+          var prev_max = row_max_shmem[q_tile_row];
+          var final_max = prev_max;
+          // pass 1: compute final max across the full KV tile in chunks
+          for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) {
+              let kv_idx = kv_offset + sg_inv_id;
+              let softmax_term = calc_softmax_term(kv_idx, q_tile_row, slope);
+              final_max = subgroupMax(max(final_max, softmax_term));
+          }
+
+          var total_exp_term: f32 = 0.0;
+          // pass 2: compute exp sum and write P using final_max
+          for (var kv_offset = 0u; kv_offset < KV_TILE; kv_offset += subgroup_size) {
+              let kv_idx = kv_offset + sg_inv_id;
+              let softmax_term = calc_softmax_term(kv_idx, q_tile_row, slope);
+              let cur_p = select(0.0,
+                                 exp(softmax_term - final_max),
+                                 kv_tile + kv_idx < params.seq_len_kv && kv_idx < KV_TILE);
+              total_exp_term += subgroupAdd(cur_p);
+              if (kv_idx < KV_TILE) {
+                  inter_shmem[kv_idx + q_tile_row * KV_TILE] = f16(cur_p);
+              }
+          }
+
+          let cur_exp = exp(prev_max - final_max);
+
+          if (sg_inv_id == 0) {
+              row_max_shmem[q_tile_row] = final_max;
+              exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * cur_exp + total_exp_term;
+          }
+
+          for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) {
+              let idx = q_tile_row * HEAD_DIM_V + elem_idx;
+              o_shmem[idx] = f16(f32(o_shmem[idx]) * cur_exp);
+          }
+      }
+
+      // load v tile into shared memory
+#if defined(KV_Q4_0)
+      for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) {
+          let blck_idx = elem_idx / BLOCK_SIZE;
+          let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
+          let v_row = blck_idx / BLOCKS_V;
+          let global_v_row = kv_tile + v_row;
+          let block_k = blck_idx % BLOCKS_V;
+          let row_offset = v_row * HEAD_DIM_V;
+
+          if (global_v_row < params.seq_len_kv) {
+              let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k;
+              let base_idx = global_block_idx * F16_PER_BLOCK;
+              let d = V[base_idx]; // scale
+              for (var j = 0u; j < F16_PER_THREAD; j += 2) {
+                  let q_0 = V[base_idx + 1u + block_offset + j];
+                  let q_1 = V[base_idx + 1u + block_offset + j + 1];
+                  let q_packed = bitcast<u32>(vec2(q_0, q_1));
+                  for (var k = 0u; k < 4u; k++) {
+                      let q_byte = get_byte(q_packed, k);
+                      let q_hi = (f16((q_byte >> 4) & 0xF) - 8.0) * d;
+                      let q_lo = (f16(q_byte & 0xF) - 8.0) * d;
+                      let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
+                      kv_shmem[row_offset + idx] = q_lo;
+                      kv_shmem[row_offset + idx + 16u] = q_hi;
+                  }
+              }
+          }
+      }
+#elif defined(KV_Q8_0)
+      for (var elem_idx = local_id.x * NQ; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE * NQ) {
+          let blck_idx = elem_idx / BLOCK_SIZE;
+          let block_offset = (elem_idx % BLOCK_SIZE) / WEIGHTS_PER_F16;
+          let v_row = blck_idx / BLOCKS_V;
+          let global_v_row = kv_tile + v_row;
+          let block_k = blck_idx % BLOCKS_V;
+          let row_offset = v_row * HEAD_DIM_V;
+
+          if (global_v_row < params.seq_len_kv) {
+              let global_block_idx = v_head_offset + global_v_row * params.stride_v1 + block_k;
+              let base_idx = global_block_idx * F16_PER_BLOCK;
+              let d = V[base_idx]; // scale
+              for (var j = 0u; j < F16_PER_THREAD; j += 2) {
+                  let q_0 = V[base_idx + 1u + block_offset + j];
+                  let q_1 = V[base_idx + 1u + block_offset + j + 1];
+                  let q_packed = bitcast<u32>(vec2(q_0, q_1));
+                  for (var k = 0u; k < 4u; k++) {
+                      let q_byte = get_byte_i32(q_packed, k);
+                      let q_val = f16(q_byte) * d;
+                      let idx = block_k * BLOCK_SIZE + block_offset * 2u + j * 2u + k;
+                      kv_shmem[row_offset + idx] = q_val;
+                  }
+              }
+          }
+      }
+#elif defined(KV_DIRECT)
+      // Direct global loads for KV
+#else
+      for (var elem_idx = local_id.x; elem_idx < KV_TILE * HEAD_DIM_V; elem_idx += WG_SIZE) {
+          let v_row = elem_idx / HEAD_DIM_V;
+          let v_col = elem_idx % HEAD_DIM_V;
+          let global_v_row = kv_tile + v_row;
+          let global_v_row_offset = v_head_offset + global_v_row * params.stride_v1;
+          kv_shmem[elem_idx] = f16(select(
+              0.0,
+              V[global_v_row_offset + v_col],
+              global_v_row < params.seq_len_kv && v_col < HEAD_DIM_V));
+      }
+#endif
+
+      workgroupBarrier();
+
+      // we have P (Q_TILE x KV_TILE) in inter_shmem and V (KV_TILE x head_dim_v) in kv_shmem
+      // we want to compute O += P * V across the full KV tile
+      for (var head_dim_block = subgroup_id * SG_MAT_N;
+           head_dim_block < HEAD_DIM_V;
+           head_dim_block += num_subgroups * SG_MAT_N) {
+              // load O submatrix from shared memory
+              var o_sg_mat: subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_result<f16, SG_MAT_M, SG_MAT_N>>(
+                  &o_shmem,
+                  head_dim_block,
+                  false,
+                  HEAD_DIM_V
+              );
+
+              for (var kv_block = 0u; kv_block < KV_BLOCKS; kv_block++) {
+                  let p_offset = kv_block * SG_MAT_N;
+                  var p_sg_mat: subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K> = subgroupMatrixLoad<subgroup_matrix_left<f16, SG_MAT_M, SG_MAT_K>>(
+                      &inter_shmem,
+                      p_offset,
+                      false,
+                      KV_TILE
+                  );
+
+                  // load V submatrix from global or shared memory
+#ifdef KV_DIRECT
+                  let v_block_row = kv_tile + kv_block * SG_MAT_N;
+                  let v_global_offset = v_head_offset + v_block_row * params.stride_v1 + head_dim_block;
+                  var v_sg_mat: subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(
+                      &V,
+                      v_global_offset,
+                      false,
+                      params.stride_v1
+                  );
+#else
+                  let v_block_offset = kv_block * SG_MAT_N * HEAD_DIM_V;
+                  var v_sg_mat: subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N> = subgroupMatrixLoad<subgroup_matrix_right<f16, SG_MAT_K, SG_MAT_N>>(
+                      &kv_shmem,
+                      v_block_offset + head_dim_block,
+                      false,
+                      HEAD_DIM_V
+                  );
+#endif
+                  // O += P * V
+                  o_sg_mat = subgroupMatrixMultiplyAccumulate(p_sg_mat, v_sg_mat, o_sg_mat);
+              }
+
+              // store O back to shared memory
+              subgroupMatrixStore(&o_shmem, head_dim_block, o_sg_mat, false, HEAD_DIM_V);
+      }
+
+      workgroupBarrier();
+    }
+
+#ifdef SINKS
+    // add sinks (applied once after processing all KV tiles)
+    for (var q_tile_row = subgroup_id;
+         q_tile_row < Q_TILE;
+         q_tile_row += num_subgroups) {
+            // no need to process rows beyond seq_len_q
+            let global_q_row = q_row_start + q_tile_row;
+            if (global_q_row >= params.seq_len_q) {
+                break;
+            }
+
+            var prev_max = row_max_shmem[q_tile_row];
+
+            // for non-sink threads, exp(FLOAT_MIN) effectively zeroes out their contribution to the sum
+            let sink_val = select(FLOAT_MIN, sinks[params.offset_sinks + head_idx], sg_inv_id == 0);
+            let new_max = subgroupMax(max(prev_max, sink_val));
+            let max_exp = exp(prev_max - new_max);
+            let sink_exp = exp(sink_val - new_max);
+
+            let sink_exp_sum = subgroupAdd(sink_exp);
+
+            if (sg_inv_id == 0) {
+                exp_sum_shmem[q_tile_row] = exp_sum_shmem[q_tile_row] * max_exp + sink_exp_sum;
+            }
+
+            for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) {
+                let idx = q_tile_row * HEAD_DIM_V + elem_idx;
+                let val = f32(o_shmem[idx]) * max_exp;
+                o_shmem[idx] = f16(val);
+            }
+    }
+
+    workgroupBarrier();
+#endif
+
+    // write output back to global memory
+    for (var q_tile_row = subgroup_id;
+         q_tile_row < Q_TILE;
+         q_tile_row += num_subgroups) {
+            let global_q_row = q_row_start + q_tile_row;
+            if (global_q_row >= params.seq_len_q) {
+                break;
+            }
+
+            let exp_sum = exp_sum_shmem[q_tile_row];
+            let scale = select(0.0, 1.0 / exp_sum, exp_sum != 0);
+
+            for (var elem_idx = sg_inv_id; elem_idx < HEAD_DIM_V; elem_idx += subgroup_size) {
+                let o_val = o_shmem[q_tile_row * HEAD_DIM_V + elem_idx];
+                let scaled = f32(o_val) * scale;
+                dst[dst_global_offset + q_tile_row * dst2_stride + elem_idx] = scaled;
+            }
+    }
+}