]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
vulkan: move common FA code to flash_attn_base.comp (llama/13556)
authorJeff Bolz <redacted>
Sat, 17 May 2025 07:14:55 +0000 (16:14 +0900)
committerGeorgi Gerganov <redacted>
Mon, 19 May 2025 11:58:39 +0000 (14:58 +0300)
* vulkan: move common FA code to flash_attn_base.comp

* vulkan: move common FA index/stride setup code to flash_attn_base.comp

* build fix

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp [new file with mode: 0644]
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp

index 1683557681439ae26c9b0ce52bce5f711dc6014d..ce230a8f7d91038ed84f2ba4841052f110e5fbf7 100644 (file)
@@ -9,60 +9,13 @@
 #extension GL_KHR_shader_subgroup_shuffle : enable
 
 #include "types.comp"
+#include "flash_attn_base.comp"
 
-layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
-
-layout (constant_id = 0) const uint32_t WorkGroupSize = 128;
-layout (constant_id = 1) const uint32_t Br = 1;
-layout (constant_id = 2) const uint32_t Bc = 32;
-layout (constant_id = 3) const uint32_t D = 32;
-
-layout (constant_id = 5) const uint32_t D_split = 16;
 const uint32_t D_per_thread = D / D_split;
 
 const uint32_t cols_per_iter = WorkGroupSize / D_split;
 const uint32_t cols_per_thread = Bc / cols_per_iter;
 
-layout (push_constant) uniform parameter {
-    uint32_t N;
-    uint32_t KV;
-
-    uint32_t ne1;
-    uint32_t ne2;
-    uint32_t ne3;
-
-    uint32_t neq2;
-    uint32_t neq3;
-    uint32_t nek2;
-    uint32_t nek3;
-    uint32_t nev2;
-    uint32_t nev3;
-    uint32_t nem1;
-
-    uint32_t nb01;
-    uint32_t nb02;
-    uint32_t nb03;
-    uint32_t nb11;
-    uint32_t nb12;
-    uint32_t nb13;
-    uint32_t nb21;
-    uint32_t nb22;
-    uint32_t nb23;
-    uint32_t nb31;
-
-    float scale;
-    float max_bias;
-    float logit_softcap;
-
-    uint32_t mask;
-    uint32_t n_head_log2;
-    float m0;
-    float m1;
-
-    uint32_t gqa_ratio;
-    uint32_t split_kv;
-    uint32_t k_num;
-} p;
 
 layout (binding = 0) readonly buffer Q {float data_q[];};
 layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];};
@@ -71,39 +24,6 @@ layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];};
 layout (binding = 2) readonly buffer V {float16_t data_v[];};
 layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};
 layout (binding = 3) readonly buffer M {float16_t data_m[];};
-layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
-
-#if defined(A_TYPE_PACKED16)
-#define BINDING_IDX_K 0
-#define BINDING_IDX_V 1
-layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2];
-#endif
-
-#if defined(DATA_A_Q4_0)
-#define BLOCK_BYTE_SIZE 18
-
-vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
-    uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
-    uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
-    uint shift = (iqs & 0x10) >> 2;
-    vui_lo >>= shift;
-    vui_hi >>= shift;
-
-    return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
-}
-#endif
-
-#if defined(DATA_A_Q8_0)
-#define BLOCK_BYTE_SIZE 34
-vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
-    const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
-    const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
-
-    return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
-}
-#endif
-
-#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
 
 // Store the output when doing grouped query attention.
 // Rows index by Q's dimension 2, and the first N rows are valid.
@@ -114,27 +34,6 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY
     return elem;
 }
 
-// Store column zero. This is used to save per-row m and L values for split_k.
-ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
-{
-    if (r < N && c == 0) {
-        uint32_t offset = iq2 + r;
-        data_o[o_offset + offset] = D_TYPE(elem);
-    }
-    return elem;
-}
-
-// Load the slope matrix, indexed by Q's dimension 2.
-ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
-{
-    const uint32_t h = iq2 + (r % p.gqa_ratio);
-
-    const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
-    const int      exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
-
-    return ACC_TYPE(pow(base, ACC_TYPE(exph)));
-}
-
 shared FLOAT_TYPE tmpsh[WorkGroupSize];
 shared vec4 tmpshv4[WorkGroupSize];
 
@@ -146,58 +45,12 @@ void main() {
     init_iq_shmem(gl_WorkGroupSize);
 #endif
 
-    const uint32_t tid = gl_LocalInvocationIndex;
-    const uint32_t N = p.N;
-    const uint32_t KV = p.KV;
+    init_indices();
 
+    const uint32_t tid = gl_LocalInvocationIndex;
     const uint32_t d_tid = gl_LocalInvocationIndex % D_split;
     const uint32_t col_tid = gl_LocalInvocationIndex / D_split;
 
-    uint32_t i = gl_WorkGroupID.x;
-    uint32_t split_k_index = 0;
-
-    if (p.k_num > 1) {
-        i = 0;
-        split_k_index = gl_WorkGroupID.x;
-    }
-
-    const uint32_t Tr = CEIL_DIV(N, Br);
-
-    const uint32_t start_j = split_k_index * p.split_kv / Bc;
-    const uint32_t end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc);
-
-    // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
-    // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
-    const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio;
-    const uint32_t iq3 = gl_WorkGroupID.z;
-
-    // broadcast factors
-    const uint32_t rk2 = p.neq2/p.nek2;
-    const uint32_t rk3 = p.neq3/p.nek3;
-
-    const uint32_t rv2 = p.neq2/p.nev2;
-    const uint32_t rv3 = p.neq3/p.nev3;
-
-    // k indices
-    const uint32_t ik3 = iq3 / rk3;
-    const uint32_t ik2 = iq2 / rk2;
-
-    // v indices
-    const uint32_t iv3 = iq3 / rv3;
-    const uint32_t iv2 = iq2 / rv2;
-
-    // nb?1 are already divided by the type size and are in units of elements.
-    // When using grouped query attention, Q is indexed by iq2, so the stride
-    // should be nb02 (which is in bytes).
-    uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
-    uint32_t k_stride = p.nb11;
-    uint32_t v_stride = p.nb21;
-    // When using grouped query attention, all rows use the same mask (stride 0).
-    // "p.gqa_ratio >> 16" is just a roundabout way of writing zero
-    // that prevents the compiler from folding the "&" through the select
-    // and breaking the alignment detection.
-    uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
-
     uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
 
     [[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) {
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp
new file mode 100644 (file)
index 0000000..61d90e2
--- /dev/null
@@ -0,0 +1,162 @@
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+layout (constant_id = 0) const uint32_t WorkGroupSize = 128;
+layout (constant_id = 1) const uint32_t Br = 1;
+layout (constant_id = 2) const uint32_t Bc = 32;
+layout (constant_id = 3) const uint32_t D = 32;
+layout (constant_id = 4) const uint32_t Clamp = 0;
+layout (constant_id = 5) const uint32_t D_split = 16;
+
+
+layout (push_constant) uniform parameter {
+    uint32_t N;
+    uint32_t KV;
+
+    uint32_t ne1;
+    uint32_t ne2;
+    uint32_t ne3;
+
+    uint32_t neq2;
+    uint32_t neq3;
+    uint32_t nek2;
+    uint32_t nek3;
+    uint32_t nev2;
+    uint32_t nev3;
+    uint32_t nem1;
+
+    uint32_t nb01;
+    uint32_t nb02;
+    uint32_t nb03;
+    uint32_t nb11;
+    uint32_t nb12;
+    uint32_t nb13;
+    uint32_t nb21;
+    uint32_t nb22;
+    uint32_t nb23;
+    uint32_t nb31;
+
+    float scale;
+    float max_bias;
+    float logit_softcap;
+
+    uint32_t mask;
+    uint32_t n_head_log2;
+    float m0;
+    float m1;
+
+    uint32_t gqa_ratio;
+    uint32_t split_kv;
+    uint32_t k_num;
+} p;
+
+layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
+
+#if defined(A_TYPE_PACKED16)
+#define BINDING_IDX_K 0
+#define BINDING_IDX_V 1
+layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2];
+#endif
+
+#if defined(DATA_A_Q4_0)
+#define BLOCK_BYTE_SIZE 18
+
+vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
+    uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
+    uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
+    uint shift = (iqs & 0x10) >> 2;
+    vui_lo >>= shift;
+    vui_hi >>= shift;
+
+    return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
+}
+#endif
+
+#if defined(DATA_A_Q8_0)
+#define BLOCK_BYTE_SIZE 34
+vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
+    const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
+    const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
+
+    return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
+}
+#endif
+
+#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
+
+
+// Store column zero. This is used to save per-row m and L values for split_k.
+ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
+{
+    if (r < N && c == 0) {
+        uint32_t offset = iq2 + r;
+        data_o[o_offset + offset] = D_TYPE(elem);
+    }
+    return elem;
+}
+
+// Load the slope matrix, indexed by Q's dimension 2.
+ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
+{
+    const uint32_t h = iq2 + (r % p.gqa_ratio);
+
+    const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
+    const int      exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
+
+    return ACC_TYPE(pow(base, ACC_TYPE(exph)));
+}
+
+uint32_t i, N, KV, split_k_index, Tr, start_j, end_j,
+         iq2, iq3, rk2, rk3, rv2, rv3, ik2, ik3, iv2, iv3,
+         q_stride, k_stride, v_stride, m_stride;
+
+void init_indices()
+{
+    N = p.N;
+    KV = p.KV;
+
+    i = gl_WorkGroupID.x;
+    split_k_index = 0;
+
+    if (p.k_num > 1) {
+        i = 0;
+        split_k_index = gl_WorkGroupID.x;
+    }
+
+    Tr = CEIL_DIV(N, Br);
+
+    start_j = split_k_index * p.split_kv / Bc;
+    end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc);
+
+    // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
+    // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
+    iq2 = gl_WorkGroupID.y * p.gqa_ratio;
+    iq3 = gl_WorkGroupID.z;
+
+    // broadcast factors
+    rk2 = p.neq2/p.nek2;
+    rk3 = p.neq3/p.nek3;
+
+    rv2 = p.neq2/p.nev2;
+    rv3 = p.neq3/p.nev3;
+
+    // k indices
+    ik3 = iq3 / rk3;
+    ik2 = iq2 / rk2;
+
+    // v indices
+    iv3 = iq3 / rv3;
+    iv2 = iq2 / rv2;
+
+    // nb?1 are already divided by the type size and are in units of elements.
+    // When using grouped query attention, Q is indexed by iq2, so the stride
+    // should be nb02 (which is in bytes).
+    q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
+    k_stride = p.nb11;
+    v_stride = p.nb21;
+    // When using grouped query attention, all rows use the same mask (stride 0).
+    // "p.gqa_ratio >> 16" is just a roundabout way of writing zero
+    // that prevents the compiler from folding the "&" through the select
+    // and breaking the alignment detection.
+    m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
+}
index 8b86b623bd94d621787731c6a6f75fd39a0ea65c..da478be24fb6e7298ddff54438bf00313075a9c2 100644 (file)
 #extension GL_KHR_cooperative_matrix : enable
 
 #include "types.comp"
-
-layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
-
-layout (constant_id = 1) const uint32_t Br = 1;
-layout (constant_id = 2) const uint32_t Bc = 32;
-layout (constant_id = 3) const uint32_t D = 32;
-
-layout (constant_id = 5) const uint32_t D_split = 16;
+#include "flash_attn_base.comp"
 
 const uint32_t D_per_thread = D / D_split;
 const uint32_t row_split = 4;
@@ -26,46 +19,6 @@ const uint32_t rows_per_thread = Br / row_split;
 const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split / row_split;
 const uint32_t cols_per_thread = Bc / cols_per_iter;
 
-layout (push_constant) uniform parameter {
-    uint32_t N;
-    uint32_t KV;
-
-    uint32_t ne1;
-    uint32_t ne2;
-    uint32_t ne3;
-
-    uint32_t neq2;
-    uint32_t neq3;
-    uint32_t nek2;
-    uint32_t nek3;
-    uint32_t nev2;
-    uint32_t nev3;
-    uint32_t nem1;
-
-    uint32_t nb01;
-    uint32_t nb02;
-    uint32_t nb03;
-    uint32_t nb11;
-    uint32_t nb12;
-    uint32_t nb13;
-    uint32_t nb21;
-    uint32_t nb22;
-    uint32_t nb23;
-    uint32_t nb31;
-
-    float scale;
-    float max_bias;
-    float logit_softcap;
-
-    uint32_t mask;
-    uint32_t n_head_log2;
-    float m0;
-    float m1;
-
-    uint32_t gqa_ratio;
-    uint32_t split_kv;
-    uint32_t k_num;
-} p;
 
 layout (binding = 0) readonly buffer Q {float data_q[];};
 layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];};
@@ -74,39 +27,6 @@ layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];};
 layout (binding = 2) readonly buffer V {float16_t data_v[];};
 layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};
 layout (binding = 3) readonly buffer M {float16_t data_m[];};
-layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
-
-#if defined(A_TYPE_PACKED16)
-#define BINDING_IDX_K 0
-#define BINDING_IDX_V 1
-layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2];
-#endif
-
-#if defined(DATA_A_Q4_0)
-#define BLOCK_BYTE_SIZE 18
-
-vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
-    uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
-    uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
-    uint shift = (iqs & 0x10) >> 2;
-    vui_lo >>= shift;
-    vui_hi >>= shift;
-
-    return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
-}
-#endif
-
-#if defined(DATA_A_Q8_0)
-#define BLOCK_BYTE_SIZE 34
-vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
-    const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
-    const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
-
-    return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
-}
-#endif
-
-#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
 
 // Store the output when doing grouped query attention.
 // Rows index by Q's dimension 2, and the first N rows are valid.
@@ -117,27 +37,6 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY
     return elem;
 }
 
-// Store column zero. This is used to save per-row m and L values for split_k.
-ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
-{
-    if (r < N && c == 0) {
-        uint32_t offset = iq2 + r;
-        data_o[o_offset + offset] = D_TYPE(elem);
-    }
-    return elem;
-}
-
-// Load the slope matrix, indexed by Q's dimension 2.
-ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
-{
-    const uint32_t h = iq2 + (r % p.gqa_ratio);
-
-    const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
-    const int      exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
-
-    return ACC_TYPE(pow(base, ACC_TYPE(exph)));
-}
-
 // These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd
 const uint32_t MatBr = 16;
 const uint32_t MatBc = 16;
@@ -162,9 +61,9 @@ void main() {
     init_iq_shmem(gl_WorkGroupSize);
 #endif
 
+    init_indices();
+
     const uint32_t tid = gl_LocalInvocationIndex;
-    const uint32_t N = p.N;
-    const uint32_t KV = p.KV;
 
     const uint32_t threads_per_rowgroup = gl_WorkGroupSize.x / row_split;
     const uint32_t row_tid = gl_LocalInvocationIndex / threads_per_rowgroup;
@@ -173,51 +72,6 @@ void main() {
 
 #define tile_row(r) (row_tid * rows_per_thread + (r))
 
-    uint32_t i = gl_WorkGroupID.x;
-    uint32_t split_k_index = 0;
-
-    if (p.k_num > 1) {
-        i = 0;
-        split_k_index = gl_WorkGroupID.x;
-    }
-
-    const uint32_t Tr = CEIL_DIV(N, Br);
-
-    const uint32_t start_j = split_k_index * p.split_kv / Bc;
-    const uint32_t end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc);
-
-    // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
-    // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
-    const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio;
-    const uint32_t iq3 = gl_WorkGroupID.z;
-
-    // broadcast factors
-    const uint32_t rk2 = p.neq2/p.nek2;
-    const uint32_t rk3 = p.neq3/p.nek3;
-
-    const uint32_t rv2 = p.neq2/p.nev2;
-    const uint32_t rv3 = p.neq3/p.nev3;
-
-    // k indices
-    const uint32_t ik3 = iq3 / rk3;
-    const uint32_t ik2 = iq2 / rk2;
-
-    // v indices
-    const uint32_t iv3 = iq3 / rv3;
-    const uint32_t iv2 = iq2 / rv2;
-
-    // nb?1 are already divided by the type size and are in units of elements.
-    // When using grouped query attention, Q is indexed by iq2, so the stride
-    // should be nb02 (which is in bytes).
-    uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
-    uint32_t k_stride = p.nb11;
-    uint32_t v_stride = p.nb21;
-    // When using grouped query attention, all rows use the same mask (stride 0).
-    // "p.gqa_ratio >> 16" is just a roundabout way of writing zero
-    // that prevents the compiler from folding the "&" through the select
-    // and breaking the alignment detection.
-    uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
-
     uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
 
     [[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) {
index b926a578aded6df20c6b40442cdd06c2648b7d71..6acf67a03a46351fbfd1d28540f951f02ef5648e 100644 (file)
 
 #include "types.comp"
 #include "dequant_funcs_cm2.comp"
-
-layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
-
-layout (constant_id = 1) const uint32_t Br = 32;
-layout (constant_id = 2) const uint32_t Bc = 32;
-layout (constant_id = 3) const uint32_t D = 32;
-layout (constant_id = 4) const uint32_t Clamp = gl_CooperativeMatrixClampModeConstantNV;
-
-layout (push_constant) uniform parameter {
-    uint32_t N;
-    uint32_t KV;
-
-    uint32_t ne1;
-    uint32_t ne2;
-    uint32_t ne3;
-
-    uint32_t neq2;
-    uint32_t neq3;
-    uint32_t nek2;
-    uint32_t nek3;
-    uint32_t nev2;
-    uint32_t nev3;
-    uint32_t nem1;
-
-    uint32_t nb01;
-    uint32_t nb02;
-    uint32_t nb03;
-    uint32_t nb11;
-    uint32_t nb12;
-    uint32_t nb13;
-    uint32_t nb21;
-    uint32_t nb22;
-    uint32_t nb23;
-    uint32_t nb31;
-
-    float scale;
-    float max_bias;
-    float logit_softcap;
-
-    uint32_t mask;
-    uint32_t n_head_log2;
-    float m0;
-    float m1;
-
-    uint32_t gqa_ratio;
-    uint32_t split_kv;
-    uint32_t k_num;
-} p;
+#include "flash_attn_base.comp"
 
 layout (binding = 0) readonly buffer Q {uint8_t data_q[];};
 layout (binding = 1) readonly buffer K {uint8_t data_k[];};
 layout (binding = 2) readonly buffer V {uint8_t data_v[];};
 layout (binding = 3) readonly buffer M {uint8_t data_m[];};
-layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
-
-#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
 
 ACC_TYPE maxReduce(const in ACC_TYPE x, const in ACC_TYPE y) {
     return max(x, y);
@@ -118,67 +68,12 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY
     return elem;
 }
 
-// Store column zero. This is used to save per-row m and L values for split_k.
-ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
-{
-    if (r < N && c == 0) {
-        uint32_t offset = iq2 + r;
-        data_o[o_offset + offset] = D_TYPE(elem);
-    }
-    return elem;
-}
-
-// Load the slope matrix, indexed by Q's dimension 2.
-ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
-{
-    const uint32_t h = iq2 + (r % p.gqa_ratio);
-
-    const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
-    const int      exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
-
-    return ACC_TYPE(pow(base, ACC_TYPE(exph)));
-}
-
 void main() {
 #ifdef NEEDS_INIT_IQ_SHMEM
     init_iq_shmem(gl_WorkGroupSize);
 #endif
 
-    const uint32_t N = p.N;
-    const uint32_t KV = p.KV;
-
-    uint32_t i = gl_WorkGroupID.x;
-    uint32_t split_k_index = 0;
-
-    if (p.k_num > 1) {
-        i = 0;
-        split_k_index = gl_WorkGroupID.x;
-    }
-
-    const uint32_t Tr = CEIL_DIV(N, Br);
-
-    const uint32_t start_j = split_k_index * p.split_kv / Bc;
-    const uint32_t end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc);
-
-    // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
-    // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
-    const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio;
-    const uint32_t iq3 = gl_WorkGroupID.z;
-
-    // broadcast factors
-    const uint32_t rk2 = p.neq2/p.nek2;
-    const uint32_t rk3 = p.neq3/p.nek3;
-
-    const uint32_t rv2 = p.neq2/p.nev2;
-    const uint32_t rv3 = p.neq3/p.nev3;
-
-    // k indices
-    const uint32_t ik3 = iq3 / rk3;
-    const uint32_t ik2 = iq2 / rk2;
-
-    // v indices
-    const uint32_t iv3 = iq3 / rv3;
-    const uint32_t iv2 = iq2 / rv2;
+    init_indices();
 
     tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutQ = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
     tensorLayoutNV<2, Clamp> tensorLayoutK = createTensorLayoutNV(2, Clamp);
@@ -195,17 +90,6 @@ void main() {
     tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D);
     tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D);
 
-    // nb?1 are already divided by the type size and are in units of elements.
-    // When using grouped query attention, Q is indexed by iq2, so the stride
-    // should be nb02 (which is in bytes).
-    uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
-    uint32_t k_stride = p.nb11;
-    uint32_t v_stride = p.nb21;
-    // When using grouped query attention, all rows use the same mask (stride 0).
-    // "p.gqa_ratio >> 16" is just a roundabout way of writing zero
-    // that prevents the compiler from folding the "&" through the select
-    // and breaking the alignment detection.
-    uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
     // hint to the compiler that strides are aligned for the aligned variant of the shader
     if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
     {