]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
vulkan: scalar flash attention implementation (#13324)
authorJeff Bolz <redacted>
Sat, 10 May 2025 06:07:07 +0000 (23:07 -0700)
committerGitHub <redacted>
Sat, 10 May 2025 06:07:07 +0000 (08:07 +0200)
* vulkan: scalar flash attention implementation

* vulkan: always use fp32 for scalar flash attention

* vulkan: use vector loads in scalar flash attention shader

* vulkan: remove PV matrix, helps with register usage

* vulkan: reduce register usage in scalar FA, but perf may be slightly worse

* vulkan: load each Q value once. optimize O reduction. more tuning

* vulkan: support q4_0/q8_0 KV in scalar FA

* CI: increase timeout to accommodate newly-supported tests

* vulkan: for scalar FA, select between 1 and 8 rows

* vulkan: avoid using Float16 capability in scalar FA

.github/workflows/build.yml
ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp [new file with mode: 0644]
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

index ae066697d9884e96cbd10659219a376963a6f423..b62720f308dd7d4947d4a1c82597d8ac142622c0 100644 (file)
@@ -307,7 +307,7 @@ jobs:
         run: |
           cd build
           # This is using llvmpipe and runs slower than other backends
-          ctest -L main --verbose --timeout 2700
+          ctest -L main --verbose --timeout 3600
 
   ubuntu-22-cmake-hip:
     runs-on: ubuntu-22.04
index 2dc2883a70ced8ef7de9c2df68cccfed5245c443..e2b357fdc1531eb2f309729155849f9db34e2dae 100644 (file)
@@ -275,6 +275,7 @@ struct vk_device_struct {
     bool prefer_host_memory;
     bool float_controls_rte_fp16;
     bool subgroup_add;
+    bool subgroup_shuffle;
 
     bool integer_dot_product;
 
@@ -402,12 +403,20 @@ struct vk_device_struct {
     vk_pipeline pipeline_conv2d_dw_cwhn_f32;
 
     // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
+    vk_pipeline pipeline_flash_attn_f32_f16_D64_cm2[GGML_TYPE_COUNT][2][2][2];
+    vk_pipeline pipeline_flash_attn_f32_f16_D80_cm2[GGML_TYPE_COUNT][2][2][2];
+    vk_pipeline pipeline_flash_attn_f32_f16_D96_cm2[GGML_TYPE_COUNT][2][2][2];
+    vk_pipeline pipeline_flash_attn_f32_f16_D112_cm2[GGML_TYPE_COUNT][2][2][2];
+    vk_pipeline pipeline_flash_attn_f32_f16_D128_cm2[GGML_TYPE_COUNT][2][2][2];
+    vk_pipeline pipeline_flash_attn_f32_f16_D256_cm2[GGML_TYPE_COUNT][2][2][2];
+
     vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];
     vk_pipeline pipeline_flash_attn_f32_f16_D80[GGML_TYPE_COUNT][2][2][2];
     vk_pipeline pipeline_flash_attn_f32_f16_D96[GGML_TYPE_COUNT][2][2][2];
     vk_pipeline pipeline_flash_attn_f32_f16_D112[GGML_TYPE_COUNT][2][2][2];
     vk_pipeline pipeline_flash_attn_f32_f16_D128[GGML_TYPE_COUNT][2][2][2];
     vk_pipeline pipeline_flash_attn_f32_f16_D256[GGML_TYPE_COUNT][2][2][2];
+
     vk_pipeline pipeline_flash_attn_split_k_reduce;
 
     std::unordered_map<std::string, vk_pipeline_ref> pipelines;
@@ -1581,13 +1590,29 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events
 
 // number of rows/cols for flash attention shader
 static constexpr uint32_t flash_attention_num_small_rows = 32;
-static std::array<uint32_t, 2> fa_rows_cols(uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) {
+static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
+static constexpr uint32_t scalar_flash_attention_num_large_rows = 8;
+
+static uint32_t get_fa_num_small_rows(bool scalar) {
+    return scalar ? scalar_flash_attention_num_small_rows : flash_attention_num_small_rows;
+}
+
+static std::array<uint32_t, 2> fa_rows_cols(bool scalar, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) {
     GGML_UNUSED(clamp);
 
+    if (scalar) {
+        if (small_rows) {
+            return {scalar_flash_attention_num_small_rows, 64};
+        } else {
+            return {scalar_flash_attention_num_large_rows, 32};
+        }
+    }
+
     // small rows, large cols
     if (small_rows) {
-        return {flash_attention_num_small_rows, 64};
+        return {get_fa_num_small_rows(scalar), 32};
     }
+
     // small cols to reduce register count
     if (ggml_is_quantized(type) || D == 256) {
         return {64, 32};
@@ -1882,65 +1907,66 @@ static void ggml_vk_load_shaders(vk_device& device) {
                                       parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
     };
 
-#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
-    if (device->coopmat2) {
-
-        auto const &fa_wg_denoms = [&](uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
-            return {fa_rows_cols(D, clamp, type, small_rows)[0], 1, 1};
-        };
+    auto const &fa_wg_denoms = [&](bool scalar, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
+        return {fa_rows_cols(scalar, D, clamp, type, small_rows)[0], 1, 1};
+    };
 
-        auto const &fa_spec_constants = [&](uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
-            // For large number of rows, 128 invocations seems to work best.
-            // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
-            // can't use 256 for D==80.
-            uint32_t wg_size = (small_rows && (D % 32) == 0) ? 256 : 128;
-            auto rows_cols = fa_rows_cols(D, clamp, type, small_rows);
-            // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
-            GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0);
-            return {wg_size, rows_cols[0], rows_cols[1], (D), clamp};
-        };
+    auto const &fa_spec_constants = [&](bool scalar, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
+        // For large number of rows, 128 invocations seems to work best.
+        // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
+        // can't use 256 for D==80.
+        // For scalar, use 128 (arbitrary)
+        uint32_t wg_size = scalar ? 128 : ((small_rows && (D % 32) == 0) ? 256 : 128);
+        auto rows_cols = fa_rows_cols(scalar, D, clamp, type, small_rows);
+
+        // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
+        // D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader.
+        const uint32_t D_lsb = D ^ (D & (D-1));
+        uint32_t D_split = std::min(std::min(device->subgroup_size, 8u), D_lsb / 4);
+
+        // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
+        GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0);
+        return {wg_size, rows_cols[0], rows_cols[1], (D), clamp, D_split};
+    };
 
-#define CREATE_FA2(TYPE, NAMELC, D) \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc"         #NAMELC,           flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data,  "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,false), fa_spec_constants(D,1,TYPE,false), 1);     \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC,           flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data,  "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,false), fa_spec_constants(D,0,TYPE,false), fa_rows_cols(D,0,TYPE,false)[1]);     \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc"         #NAMELC,           flash_attn_f32_f16_ ## NAMELC ## _cm2_len,         flash_attn_f32_f16_ ## NAMELC ## _cm2_data,         "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,false), fa_spec_constants(D,1,TYPE,false), 1);     \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC,           flash_attn_f32_f16_ ## NAMELC ## _cm2_len,         flash_attn_f32_f16_ ## NAMELC ## _cm2_data,         "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,false), fa_spec_constants(D,0,TYPE,false), fa_rows_cols(D,0,TYPE,false)[1]);     \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows"         #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data,  "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,true), fa_spec_constants(D,1,TYPE,true), 1);     \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data,  "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,true), fa_spec_constants(D,0,TYPE,true), fa_rows_cols(D,0,TYPE,true)[1]);     \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows"         #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len,         flash_attn_f32_f16_ ## NAMELC ## _cm2_data,         "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,true), fa_spec_constants(D,1,TYPE,true), 1);     \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len,         flash_attn_f32_f16_ ## NAMELC ## _cm2_data,         "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,true), fa_spec_constants(D,0,TYPE,true), fa_rows_cols(D,0,TYPE,true)[1]);     \
-
-#define CREATE_FA(TYPE, NAMELC) \
-        CREATE_FA2(TYPE, NAMELC, 64) \
-        CREATE_FA2(TYPE, NAMELC, 80) \
-        CREATE_FA2(TYPE, NAMELC, 96) \
-        CREATE_FA2(TYPE, NAMELC, 112) \
-        CREATE_FA2(TYPE, NAMELC, 128) \
-        CREATE_FA2(TYPE, NAMELC, 256)
-
-        CREATE_FA(GGML_TYPE_F16, f16)
-        CREATE_FA(GGML_TYPE_Q4_0, q4_0)
-        CREATE_FA(GGML_TYPE_Q4_1, q4_1)
-        CREATE_FA(GGML_TYPE_Q5_0, q5_0)
-        CREATE_FA(GGML_TYPE_Q5_1, q5_1)
-        CREATE_FA(GGML_TYPE_Q8_0, q8_0)
-        // K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently
-        //CREATE_FA(GGML_TYPE_Q2_K, q2_k)
-        //CREATE_FA(GGML_TYPE_Q3_K, q3_k)
-        //CREATE_FA(GGML_TYPE_Q4_K, q4_k)
-        //CREATE_FA(GGML_TYPE_Q5_K, q5_k)
-        //CREATE_FA(GGML_TYPE_Q6_K, q6_k)
-        //CREATE_FA(GGML_TYPE_IQ1_S, iq1_s)
-        //CREATE_FA(GGML_TYPE_IQ1_M, iq1_m)
-        //CREATE_FA(GGML_TYPE_IQ2_XXS, iq2_xxs)
-        //CREATE_FA(GGML_TYPE_IQ2_XS, iq2_xs)
-        //CREATE_FA(GGML_TYPE_IQ2_S, iq2_s)
-        //CREATE_FA(GGML_TYPE_IQ3_XXS, iq3_xxs)
-        //CREATE_FA(GGML_TYPE_IQ3_S, iq3_s)
-        //CREATE_FA(GGML_TYPE_IQ4_XS, iq4_xs)
-        CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl)
+#define CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, D) \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc"         #NAMELC #SUFFIX,           flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data,  "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,false), fa_spec_constants(SCALAR, D,1,TYPE,false), 1, true);     \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC #SUFFIX,           flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data,  "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,false), fa_spec_constants(SCALAR, D,0,TYPE,false), fa_rows_cols(SCALAR,D,0,TYPE,false)[1], true);     \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc"         #NAMELC #SUFFIX,           flash_attn_f32_f16_ ## NAMELC ##     SUFFIX ## _len,         flash_attn_f32_f16_ ## NAMELC ##     SUFFIX ## _data,         "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,false), fa_spec_constants(SCALAR, D,1,TYPE,false), 1, true);     \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC #SUFFIX,           flash_attn_f32_f16_ ## NAMELC ##     SUFFIX ## _len,         flash_attn_f32_f16_ ## NAMELC ##     SUFFIX ## _data,         "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,false), fa_spec_constants(SCALAR, D,0,TYPE,false), fa_rows_cols(SCALAR,D,0,TYPE,false)[1], true);     \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows"         #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data,  "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,true), fa_spec_constants(SCALAR, D,1,TYPE,true), 1, true);     \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data,  "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,true), fa_spec_constants(SCALAR, D,0,TYPE,true), fa_rows_cols(SCALAR,D,0,TYPE,true)[1], true);     \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows"         #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ##     SUFFIX ## _len,         flash_attn_f32_f16_ ## NAMELC ##     SUFFIX ## _data,         "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,1,TYPE,true), fa_spec_constants(SCALAR, D,1,TYPE,true), 1, true);     \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ##     SUFFIX ## _len,         flash_attn_f32_f16_ ## NAMELC ##     SUFFIX ## _data,         "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(SCALAR, D,0,TYPE,true), fa_spec_constants(SCALAR, D,0,TYPE,true), fa_rows_cols(SCALAR,D,0,TYPE,true)[1], true);     \
+
+#define CREATE_FA(TYPE, NAMELC, SCALAR, SUFFIX) \
+        CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 64) \
+        CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 80) \
+        CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 96) \
+        CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 112) \
+        CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 128) \
+        CREATE_FA2(TYPE, NAMELC, SCALAR, SUFFIX, 256)
+
+    CREATE_FA(GGML_TYPE_F16, f16, true, )
+    CREATE_FA(GGML_TYPE_Q4_0, q4_0, true, )
+    CREATE_FA(GGML_TYPE_Q8_0, q8_0, true, )
+#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
+    if (device->coopmat2) {
+        CREATE_FA(GGML_TYPE_F16, f16, false, _cm2)
+        CREATE_FA(GGML_TYPE_Q4_0, q4_0, false, _cm2)
+        CREATE_FA(GGML_TYPE_Q4_1, q4_1, false, _cm2)
+        CREATE_FA(GGML_TYPE_Q5_0, q5_0, false, _cm2)
+        CREATE_FA(GGML_TYPE_Q5_1, q5_1, false, _cm2)
+        CREATE_FA(GGML_TYPE_Q8_0, q8_0, false, _cm2)
+        CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, false, _cm2)
+    }
+#endif
+#undef CREATE_FA2
 #undef CREATE_FA
 
+#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
+    if (device->coopmat2) {
+
         // Create 6 variants, {s,m,l}x{unaligned,aligned}
 #define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
         ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1);   \
@@ -2837,6 +2863,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
         device->subgroup_add = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
                                (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eArithmetic);
 
+        device->subgroup_shuffle = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
+                                   (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eShuffle);
+
         const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr;
 
         device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
@@ -5709,20 +5738,57 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     assert(q->type == GGML_TYPE_F32);
     assert(k->type == v->type);
 
+    bool scalar = !ctx->device->coopmat2;
+
+    uint32_t gqa_ratio = 1;
+    uint32_t qk_ratio = neq2 / nek2;
+    uint32_t workgroups_x = (uint32_t)neq1;
+    uint32_t workgroups_y = (uint32_t)neq2;
+    uint32_t workgroups_z = (uint32_t)neq3;
+
+    // For scalar FA, we can use the "large" size to accommodate qga.
+    // For coopmat FA, we always use the small size (which is still pretty large for gqa).
+    const uint32_t max_gqa = scalar ? scalar_flash_attention_num_large_rows : get_fa_num_small_rows(false);
+
+    if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa &&
+        qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) {
+        // grouped query attention - make the N dimension equal to gqa_ratio, reduce
+        // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
+        // and change addressing calculations to index Q's dimension 2.
+        gqa_ratio = qk_ratio;
+        N = gqa_ratio;
+        workgroups_y /= N;
+    }
+
     vk_pipeline *pipelines;
     // XXX TODO other backends may be changing accumulator precision to default to f32 soon
-    bool f32acc = dst->op_params[3] == GGML_PREC_F32;
-    bool small_rows = N <= flash_attention_num_small_rows;
-    switch (D) {
-    case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break;
-    case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break;
-    case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96[k->type][f32acc][small_rows][0]; break;
-    case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112[k->type][f32acc][small_rows][0]; break;
-    case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128[k->type][f32acc][small_rows][0]; break;
-    case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256[k->type][f32acc][small_rows][0]; break;
-    default:
-        assert(!"unsupported D value");
-        return;
+    bool f32acc = scalar || dst->op_params[3] == GGML_PREC_F32;
+    bool small_rows = N <= get_fa_num_small_rows(scalar);
+
+    if (scalar) {
+        switch (D) {
+        case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break;
+        case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break;
+        case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96[k->type][f32acc][small_rows][0]; break;
+        case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112[k->type][f32acc][small_rows][0]; break;
+        case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128[k->type][f32acc][small_rows][0]; break;
+        case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256[k->type][f32acc][small_rows][0]; break;
+        default:
+            GGML_ASSERT(!"unsupported D value");
+            return;
+        }
+    } else {
+        switch (D) {
+        case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm2[k->type][f32acc][small_rows][0]; break;
+        case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm2[k->type][f32acc][small_rows][0]; break;
+        case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96_cm2[k->type][f32acc][small_rows][0]; break;
+        case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm2[k->type][f32acc][small_rows][0]; break;
+        case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm2[k->type][f32acc][small_rows][0]; break;
+        case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm2[k->type][f32acc][small_rows][0]; break;
+        default:
+            GGML_ASSERT(!"unsupported D value");
+            return;
+        }
     }
     assert(pipelines);
 
@@ -5740,27 +5806,14 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     vk_pipeline pipeline = pipelines[aligned];
     assert(pipeline);
 
-    uint32_t gqa_ratio = 1;
-    uint32_t qk_ratio = neq2 / nek2;
-    uint32_t workgroups_x = (uint32_t)neq1;
-    uint32_t workgroups_y = (uint32_t)neq2;
-    uint32_t workgroups_z = (uint32_t)neq3;
-
-    if (N == 1 && qk_ratio > 1 && gqa_ratio <= flash_attention_num_small_rows &&
-        qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) {
-        // grouped query attention - make the N dimension equal to gqa_ratio, reduce
-        // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
-        // and change addressing calculations to index Q's dimension 2.
-        gqa_ratio = qk_ratio;
-        N = gqa_ratio;
-        workgroups_y /= N;
-    }
-
     uint32_t split_kv = KV;
     uint32_t split_k = 1;
 
+    // Use a placeholder core count if one isn't available. split_k is a big help for perf.
+    const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16;
+
     // Try to use split_k when KV is large enough to be worth the overhead
-    if (workgroups_x == 1 && ctx->device->shader_core_count > 0 && KV >= 512) {
+    if (workgroups_x == 1 && shader_core_count > 0 && KV >= 512) {
         // Try to run two workgroups per SM.
         split_k = ctx->device->shader_core_count * 2 / workgroups_y;
         if (split_k > 1) {
@@ -9530,9 +9583,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_FLASH_ATTN_EXT:
             {
                 ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
-                if (!ggml_vk_get_device(ctx->device)->coopmat2) {
-                    return false;
-                }
+                auto device = ggml_vk_get_device(ctx->device);
+                bool coopmat2 = device->coopmat2;
                 switch (op->src[0]->ne[0]) {
                 case 64:
                 case 80:
@@ -9540,7 +9592,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                 case 112:
                 case 128:
                 case 256:
-                case 575: // DeepSeek MLA
                     break;
                 default:
                     return false;
@@ -9566,10 +9617,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                 switch (op->src[1]->type) {
                 case GGML_TYPE_F16:
                 case GGML_TYPE_Q4_0:
+                case GGML_TYPE_Q8_0:
+                    // supported in scalar and coopmat2 paths
+                    break;
                 case GGML_TYPE_Q4_1:
                 case GGML_TYPE_Q5_0:
                 case GGML_TYPE_Q5_1:
-                case GGML_TYPE_Q8_0:
                 // K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently
                 //case GGML_TYPE_Q2_K:
                 //case GGML_TYPE_Q3_K:
@@ -9585,10 +9638,18 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                 //case GGML_TYPE_IQ3_S:
                 //case GGML_TYPE_IQ4_XS:
                 case GGML_TYPE_IQ4_NL:
+                    // currently supported only in coopmat2 path
+                    if (!coopmat2) {
+                        return false;
+                    }
                     break;
                 default:
                     return false;
                 }
+                if (!coopmat2 && !device->subgroup_shuffle) {
+                    // scalar FA uses subgroupShuffle
+                    return false;
+                }
                 return true;
             }
         case GGML_OP_GET_ROWS:
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
new file mode 100644 (file)
index 0000000..e654516
--- /dev/null
@@ -0,0 +1,483 @@
+#version 450
+
+#extension GL_EXT_control_flow_attributes : enable
+#extension GL_EXT_shader_16bit_storage : require
+
+#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
+#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
+
+#extension GL_KHR_shader_subgroup_shuffle : enable
+
+#include "types.comp"
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+layout (constant_id = 1) const uint32_t Br = 1;
+layout (constant_id = 2) const uint32_t Bc = 32;
+layout (constant_id = 3) const uint32_t D = 32;
+
+layout (constant_id = 5) const uint32_t D_split = 16;
+const uint32_t D_per_thread = D / D_split;
+
+const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split;
+const uint32_t cols_per_thread = Bc / cols_per_iter;
+
+layout (push_constant) uniform parameter {
+    uint32_t N;
+    uint32_t KV;
+
+    uint32_t ne1;
+    uint32_t ne2;
+    uint32_t ne3;
+
+    uint32_t neq2;
+    uint32_t neq3;
+    uint32_t nek2;
+    uint32_t nek3;
+    uint32_t nev2;
+    uint32_t nev3;
+    uint32_t nem1;
+
+    uint32_t nb01;
+    uint32_t nb02;
+    uint32_t nb03;
+    uint32_t nb11;
+    uint32_t nb12;
+    uint32_t nb13;
+    uint32_t nb21;
+    uint32_t nb22;
+    uint32_t nb23;
+    uint32_t nb31;
+
+    float scale;
+    float max_bias;
+    float logit_softcap;
+
+    uint32_t mask;
+    uint32_t n_head_log2;
+    float m0;
+    float m1;
+
+    uint32_t gqa_ratio;
+    uint32_t split_kv;
+    uint32_t k_num;
+} p;
+
+layout (binding = 0) readonly buffer Q {float data_q[];};
+layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];};
+layout (binding = 1) readonly buffer K {float16_t data_k[];};
+layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];};
+layout (binding = 2) readonly buffer V {float16_t data_v[];};
+layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};
+layout (binding = 3) readonly buffer M {float16_t data_m[];};
+layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
+
+#if defined(A_TYPE_PACKED16)
+#define BINDING_IDX_K 0
+#define BINDING_IDX_V 1
+layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2];
+#endif
+
+#if defined(DATA_A_Q4_0)
+#define BLOCK_BYTE_SIZE 18
+
+vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
+    uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
+    uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
+    uint shift = (iqs & 0x10) >> 2;
+    vui_lo >>= shift;
+    vui_hi >>= shift;
+
+    return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
+}
+#endif
+
+#if defined(DATA_A_Q8_0)
+#define BLOCK_BYTE_SIZE 34
+vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
+    const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
+    const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
+
+    return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
+}
+#endif
+
+#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
+
+// Store the output when doing grouped query attention.
+// Rows index by Q's dimension 2, and the first N rows are valid.
+D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
+{
+    uint32_t offset = (iq2 + r) * D + c;
+    data_o[o_offset + offset] = D_TYPE(elem);
+    return elem;
+}
+
+// Store column zero. This is used to save per-row m and L values for split_k.
+ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
+{
+    if (r < N && c == 0) {
+        uint32_t offset = iq2 + r;
+        data_o[o_offset + offset] = D_TYPE(elem);
+    }
+    return elem;
+}
+
+// Load the slope matrix, indexed by Q's dimension 2.
+ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
+{
+    const uint32_t h = iq2 + (r % p.gqa_ratio);
+
+    const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
+    const int      exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
+
+    return ACC_TYPE(pow(base, ACC_TYPE(exph)));
+}
+
+shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x];
+shared vec4 tmpshv4[gl_WorkGroupSize.x];
+
+shared float masksh[Bc][Br];
+shared vec4 Qf[Br][D / 4];
+
+void main() {
+#ifdef NEEDS_INIT_IQ_SHMEM
+    init_iq_shmem(gl_WorkGroupSize);
+#endif
+
+    const uint32_t tid = gl_LocalInvocationIndex;
+    const uint32_t N = p.N;
+    const uint32_t KV = p.KV;
+
+    const uint32_t d_tid = gl_LocalInvocationIndex % D_split;
+    const uint32_t col_tid = gl_LocalInvocationIndex / D_split;
+
+    uint32_t i = gl_WorkGroupID.x;
+    uint32_t split_k_index = 0;
+
+    if (p.k_num > 1) {
+        i = 0;
+        split_k_index = gl_WorkGroupID.x;
+    }
+
+    const uint32_t Tr = CEIL_DIV(N, Br);
+
+    const uint32_t start_j = split_k_index * p.split_kv / Bc;
+    const uint32_t end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc);
+
+    // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
+    // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
+    const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio;
+    const uint32_t iq3 = gl_WorkGroupID.z;
+
+    // broadcast factors
+    const uint32_t rk2 = p.neq2/p.nek2;
+    const uint32_t rk3 = p.neq3/p.nek3;
+
+    const uint32_t rv2 = p.neq2/p.nev2;
+    const uint32_t rv3 = p.neq3/p.nev3;
+
+    // k indices
+    const uint32_t ik3 = iq3 / rk3;
+    const uint32_t ik2 = iq2 / rk2;
+
+    // v indices
+    const uint32_t iv3 = iq3 / rv3;
+    const uint32_t iv2 = iq2 / rv2;
+
+    // nb?1 are already divided by the type size and are in units of elements.
+    // When using grouped query attention, Q is indexed by iq2, so the stride
+    // should be nb02 (which is in bytes).
+    uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
+    uint32_t k_stride = p.nb11;
+    uint32_t v_stride = p.nb21;
+    // When using grouped query attention, all rows use the same mask (stride 0).
+    // "p.gqa_ratio >> 16" is just a roundabout way of writing zero
+    // that prevents the compiler from folding the "&" through the select
+    // and breaking the alignment detection.
+    uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
+
+    uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4;
+
+    [[unroll]] for (uint32_t idx = 0; idx < Br * D / 4; idx += gl_WorkGroupSize.x) {
+        uint32_t d = (idx + tid) % (D / 4);
+        uint32_t r = (idx + tid) / (D / 4);
+        if (r < Br && d < D / 4 &&
+            i * Br + r < N) {
+            Qf[r][d] = vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d]) * p.scale;
+        }
+    }
+    barrier();
+
+    vec4 Of[Br][D_per_thread / 4];
+    [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+            Of[r][d] = vec4(0.0);
+        }
+    }
+
+    float Lf[Br], Mf[Br];
+
+    // Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M.
+    const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF);
+
+    [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+        Lf[r] = 0;
+        Mf[r] = NEG_FLT_MAX_OVER_2;
+    }
+
+    float slope[Br];
+    [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+        slope[r] = 1.0;
+    }
+
+    // ALiBi
+    if (p.max_bias > 0.0f) {
+        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+            slope[r] = perElemOpComputeSlope(r, col_tid, ACC_TYPE(0), iq2);
+        }
+    }
+
+#if BLOCK_SIZE > 1
+    uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE;
+    uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE;
+#else
+    uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2;
+    uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2;
+#endif
+
+    [[dont_unroll]]
+    for (uint32_t j = start_j; j < end_j; ++j) {
+
+        float Sf[Br][cols_per_thread];
+        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+            [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
+                Sf[r][c] = 0.0;
+            }
+        }
+
+
+        [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
+            [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+#if BLOCK_SIZE > 1
+                uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
+                uint ib = coord / BLOCK_SIZE;
+                uint iqs = (coord % BLOCK_SIZE);
+                vec4 K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
+#else
+                vec4 K_Tf = vec4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
+#endif
+                [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+                    Sf[r][c] += dot(Qf[r][d * D_split + d_tid], K_Tf);
+                }
+            }
+        }
+
+        [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
+            // Compute sum across the D_split
+            [[unroll]] for (uint s = D_split / 2; s > 0; s >>= 1) {
+                [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+                    Sf[r][c] += subgroupShuffleXor(Sf[r][c], s);
+                }
+            }
+        }
+
+        if (p.logit_softcap != 0.0f) {
+            [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+                [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
+                    Sf[r][c] = p.logit_softcap * tanh(Sf[r][c]);
+                }
+            }
+        }
+
+        if (p.mask != 0) {
+
+            [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
+                uint32_t c = (idx + tid) % Bc;
+                uint32_t r = (idx + tid) / Bc;
+                if (idx + tid < Bc * Br) {
+                    masksh[c][r] = float(data_m[(i * Br + r) * m_stride + (j * Bc + c)]);
+                }
+            }
+            barrier();
+
+            [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
+                [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+                    float mvf = masksh[c * cols_per_iter + col_tid][r];
+
+                    Sf[r][c] += slope[r]*mvf;
+                }
+            }
+            barrier();
+        }
+
+        float rowmaxf[Br], Pf[Br][cols_per_thread], rowsumf[Br], eMf[Br], Moldf[Br];
+        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+            rowmaxf[r] = Sf[r][0];
+            [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
+                rowmaxf[r] = max(rowmaxf[r], Sf[r][c]);
+            }
+            Moldf[r] = Mf[r];
+
+            // M = max(rowmax, Mold)
+            // P = e^(S - M)
+            // eM = e^(Mold - M)
+            Mf[r] = max(rowmaxf[r], Moldf[r]);
+            [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
+                Pf[r][c] = exp(Sf[r][c] - Mf[r]);
+            }
+            eMf[r] = exp(Moldf[r] - Mf[r]);
+
+            // Compute sum across row of P
+            rowsumf[r] = 0.0;
+            [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
+                rowsumf[r] += Pf[r][c];
+            }
+
+            Lf[r] = eMf[r]*Lf[r] + rowsumf[r];
+        }
+
+        [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+            [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+                Of[r][d] = eMf[r] * Of[r][d];
+            }
+        }
+
+        [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
+            [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+#if BLOCK_SIZE > 1
+                uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
+                uint ib = coord / BLOCK_SIZE;
+                uint iqs = (coord % BLOCK_SIZE);
+                vec4 Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
+#else
+                vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
+#endif
+                [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+                    Of[r][d] += Pf[r][c] * Vf;
+                }
+            }
+        }
+
+        barrier();
+    }
+
+    // reduce across threads
+
+    [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+        float rowmaxf, eMf;
+
+        tmpsh[tid] = Mf[r];
+        // Compute max across the row
+        barrier();
+        [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
+            if (tid < s) {
+                tmpsh[tid] = max(tmpsh[tid], tmpsh[tid + s]);
+            }
+            barrier();
+        }
+        rowmaxf = tmpsh[d_tid];
+        barrier();
+
+        float Moldf = Mf[r];
+
+        // M = max(rowmax, Mold)
+        // eM = e^(Mold - M)
+        Mf[r] = max(rowmaxf, Moldf);
+        eMf = exp(Moldf - Mf[r]);
+
+        Lf[r] = eMf*Lf[r];
+
+        tmpsh[tid] = Lf[r];
+
+        // Compute sum across the row
+        barrier();
+        [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
+            if (tid < s) {
+                tmpsh[tid] = tmpsh[tid] + tmpsh[tid + s];
+            }
+            barrier();
+        }
+        Lf[r] = tmpsh[d_tid];
+        barrier();
+
+        [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+
+            Of[r][d] = eMf * Of[r][d];
+            tmpshv4[tid] = Of[r][d];
+
+            barrier();
+            [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
+                if (tid < s) {
+                    Of[r][d] += tmpshv4[tid + s];
+                    tmpshv4[tid] = Of[r][d];
+                }
+                barrier();
+            }
+            Of[r][d] = tmpshv4[d_tid];
+            barrier();
+        }
+    }
+
+
+    // If there is split_k, then the split_k resolve shader does the final
+    // division by L. Store the intermediate O value and per-row m and L values.
+    if (p.k_num > 1) {
+        uint32_t o_offset = D * p.ne1 * split_k_index;
+
+        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+            if (r < N) {
+                [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+                    [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
+                        perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
+                    }
+                }
+            }
+        }
+
+        o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2;
+        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+            if (r < N) {
+                perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
+                perElemOpStoreCol0(r, 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);
+            }
+        }
+
+        return;
+    }
+
+    float Lfrcp[Br];
+    [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+        Lfrcp[r] = 1.0 / Lf[r];
+    }
+
+    [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+            Of[r][d] *= Lfrcp[r];
+        }
+    }
+
+    uint32_t o_offset = iq3*p.ne2*p.ne1;
+
+    if (p.gqa_ratio > 1) {
+        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+            if (r < N) {
+                [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+                    [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
+                        perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
+                    }
+                }
+            }
+        }
+    } else {
+        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+            if (i * Br + r < N) {
+                [[unroll]] for (uint32_t d = 0; d < D_per_thread / 4; ++d) {
+                    [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
+                        data_o[o_offset + iq2 * D + (i * Br + r) * p.ne1 * D + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
+                    }
+                }
+            }
+        }
+    }
+}
index 382f190f287bb6e3ed30bc2cb97c13f61741de2b..d196137eb292ed6486c638b7aab8b2717d15f248 100644 (file)
@@ -421,7 +421,6 @@ void process_shaders() {
 #endif
     }
 
-#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
     // flash attention
     for (const auto& f16acc : {false, true}) {
         std::string acctype = f16acc ? "float16_t" : "float";
@@ -432,6 +431,7 @@ void process_shaders() {
             }
             if (tname == "bf16") continue;
 
+#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
             if (tname == "f16") {
                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
                     merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, true, f16acc);
@@ -440,9 +440,17 @@ void process_shaders() {
                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
                     merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc);
             }
+#endif
+            if (tname == "f16") {
+                string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
+                    merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, false, f16acc);
+            } else if (tname == "q4_0" || tname == "q8_0") {
+                std::string data_a_key = "DATA_A_" + to_uppercase(tname);
+                string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
+                    merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc);
+            }
         }
     }
-#endif
 
     for (const auto& tname : type_names) {
         // mul mat vec