]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commit
ggml webgpu: initial flashattention implementation (llama/18610)
authorReese Levine <redacted>
Thu, 8 Jan 2026 16:23:39 +0000 (08:23 -0800)
committerGeorgi Gerganov <redacted>
Wed, 14 Jan 2026 07:11:59 +0000 (09:11 +0200)
commit1bb903f5991a4bf7a9bc6ad181dbe813ebd924f3
tree8d351707a64924e2225ca63d941cc1d7a5d3a59b
parent0bc0e5616eee463d74dae7633b7ca512bc69e2bc
ggml webgpu: initial flashattention implementation (llama/18610)

* FlashAttention (llama/13)

* Add inplace softmax

* Move rms_norm to split row approach

* Update debug for supports_op

* clean up debug statements

* neg f16xf32xip builds and runs, havent actually ran a model that uses neg kernel yet though

* neg passes backend test

* unary operators pass ggml tests

* rms_norm double declaration bug atoned

* abides by editor-config

* removed vestigial files

* fixed autoconfig

* All operators (inlcluding xielu) working

* removed unnecesarry checking if node->src[1] exists for unary operators

* responded and dealt with PR comments

* implemented REPL_Template support and removed bug in unary operators kernel

* formatted embed wgsl and ggml-webgpu.cpp

* Faster tensors (llama/8)

Add fast matrix and matrix/vector multiplication.

* Use map for shader replacements instead of pair of strings

* Wasm (llama/9)

* webgpu : fix build on emscripten

* more debugging stuff

* test-backend-ops: force single thread on wasm

* fix single-thread case for init_tensor_uniform

* use jspi

* add pthread

* test: remember to set n_thread for cpu backend

* Add buffer label and enable dawn-specific toggles to turn off some checks

* Intermediate state

* Fast working f16/f32 vec4

* Working float fast mul mat

* Clean up naming of mul_mat to match logical model, start work on q mul_mat

* Setup for subgroup matrix mat mul

* Basic working subgroup matrix

* Working subgroup matrix tiling

* Handle weirder sg matrix sizes (but still % sg matrix size)

* Working start to gemv

* working f16 accumulation with shared memory staging

* Print out available subgroup matrix configurations

* Vectorize dst stores for sg matrix shader

* Gemv working scalar

* Minor set_rows optimization (llama/4)

* updated optimization, fixed errors

* non vectorized version now dispatches one thread per element

* Simplify

* Change logic for set_rows pipelines

---------

Co-authored-by: Neha Abbas <redacted>
Co-authored-by: Neha Abbas <redacted>
Co-authored-by: Reese Levine <redacted>
* Comment on dawn toggles

* Working subgroup matrix code for (semi)generic sizes

* Remove some comments

* Cleanup code

* Update dawn version and move to portable subgroup size

* Try to fix new dawn release

* Update subgroup size comment

* Only check for subgroup matrix configs if they are supported

* Add toggles for subgroup matrix/f16 support on nvidia+vulkan

* Make row/col naming consistent

* Refactor shared memory loading

* Move sg matrix stores to correct file

* Working q4_0

* Formatting

* Work with emscripten builds

* Fix test-backend-ops emscripten for f16/quantized types

* Use emscripten memory64 to support get_memory

* Add build flags and try ci

---------

Co-authored-by: Xuan Son Nguyen <redacted>
* Remove extra whitespace

* Move wasm single-thread logic out of test-backend-ops for cpu backend

* Disable multiple threads for emscripten single-thread builds in ggml_graph_plan

* Refactored pipelines and workgroup calculations (llama/10)

* refactored pipelines

* refactored workgroup calculation

* removed commented out block of prior maps

* Clean up ceiling division pattern

---------

Co-authored-by: Neha Abbas <redacted>
Co-authored-by: Reese Levine <redacted>
* Start work on flash attention

* Shader structure set up (many bugs still)

* debugging

* Working first test

* Working with head grouping, head sizes to 128, logit softcap, mask/sinks enabled, f32

* Generalize softmax to work with multiple subgroups, f16 accumulation, mask shared memory tiling

* Start work on integrating pre-wgsl

* Separate structs/initial shader compilation library into separate files

* Work on compilation choices for flashattention

* Work on subgroup matrix/tile size portability

* subgroup size agnostic online softmax

* Cleanups, quantization types

* more cleanup

* fix wasm build

* Refactor flashattention to increase parallelism, use direct loads for KV in somce cases

* Checkpoint

* formatting

* Update to account for default kv cache padding

* formatting shader

* Add workflow for ggml-ci webgpu

* Try passing absolute path to dawn in ggml-ci

* Avoid error on device destruction, add todos for proper cleanup

* Fix unused warning

* Forgot one parameter unused

* Move some flashattn computation to f32 for correctness
ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp [new file with mode: 0644]
ggml/src/ggml-webgpu/ggml-webgpu.cpp
ggml/src/ggml-webgpu/pre_wgsl.hpp [new file with mode: 0644]
ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl [new file with mode: 0644]