]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Vulkan Scalar Flash Attention Refactor (#19625)
authorRuben Ortlam <redacted>
Tue, 24 Feb 2026 07:35:48 +0000 (08:35 +0100)
committerGitHub <redacted>
Tue, 24 Feb 2026 07:35:48 +0000 (08:35 +0100)
* vulkan: allow using fp16 in scalar flash attention shader

* split rows inside of subgroups for faster synchronization

* use row_split when Br >= 4, change reductions to use shared memory if row_split == 1

* use f32 scalar FA if f16 is not supported by device

* fix amd workgroup size issue

* optimize masksh use

* add medium rows FA shader Br size

* fixes

* add padding to mask shmem buffer

* cache q values into registers for KQ

* fuse lf accumulation, pf and v accumulation into a loop

* stage K loads through shmem

* stage V loads through shmem

* only stage through shmem on Nvidia

* default to Bc 32

* also stage V through shmem when this is done for K

* dynamic subgroups for intel

* use vectorized stores

* use float_type for dequantize4 functions

* use smaller scalar rows size for smaller rows count

* relax flash attention split_k condition to allow non-gqa use

* use minimal subgroup size on Intel

* fix shmem support function

* fix rebase issues

* fixes

* Bc 4 for scalar FA is not a valid configuration

* Use wave32 on AMD RDNA for scalar FA

* add Intel shader core count lookup-table

* fix regressions

* device tuning

* tmpsh size fix

* fix editorconfig

* refactor fa tuning logic into a single place

* fix gqa opt logic

* fix block_rows with small n_rows

* amd tuning

* fix hsk=72/80 issue

* tuning

* allow condition skipping for column check

* use float16 for Of if available

* address feedback

* fix bad RDNA performance on head size <= 128 by limiting occupancy

* allow printing pipeline stats

* cleanup and fixes

* limit occupancy for GCN for small batch FA with large HSK

* disable f16 FA for GCN AMD GPUs on the proprietary driver

ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

index 88b3e4e58eb16f10e1f78fa6ed67c1b88ad4b5d4..8a9cfaf16547afe90d01bdf44f6a40ec2fbbdebe 100644 (file)
@@ -403,19 +403,20 @@ enum FaCodePath {
 };
 
 struct vk_fa_pipeline_state {
-    vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, bool small_cache, FaCodePath path, bool aligned, bool f32acc, uint32_t flags)
-        : HSK(HSK), HSV(HSV), small_rows(small_rows), small_cache(small_cache), path(path), aligned(aligned), f32acc(f32acc), flags(flags) {}
-
     uint32_t HSK, HSV;
-    bool small_rows, small_cache;
+    uint32_t Br, Bc;
+    uint32_t D_split, row_split;
+    bool shmem_staging;
     FaCodePath path;
+    uint32_t workgroup_size, subgroup_size;
     bool aligned;
     bool f32acc;
     uint32_t flags;
+    uint32_t limit_occupancy_shmem;
 
     bool operator<(const vk_fa_pipeline_state &b) const {
-        return std::tie(HSK, HSV, small_rows, small_cache, path, aligned, f32acc, flags) <
-               std::tie(b.HSK, b.HSV, b.small_rows, b.small_cache, b.path, b.aligned, b.f32acc, b.flags);
+        return std::tie(HSK, HSV, Br, Bc, D_split, row_split, shmem_staging, path, workgroup_size, subgroup_size, aligned, f32acc, flags, limit_occupancy_shmem) <
+               std::tie(b.HSK, b.HSV, b.Br, b.Bc, b.D_split, b.row_split, b.shmem_staging, b.path, b.workgroup_size, b.subgroup_size, b.aligned, b.f32acc, b.flags, b.limit_occupancy_shmem);
     }
 };
 
@@ -623,6 +624,8 @@ struct vk_device_struct {
     // floor(log2(maxComputeWorkGroupInvocations))
     uint32_t max_workgroup_size_log2 {};
 
+    bool flash_attention_fp16;
+
     bool coopmat_support;
     bool coopmat_acc_f32_support {};
     bool coopmat_acc_f16_support {};
@@ -1656,6 +1659,7 @@ static bool vk_perf_logger_concurrent = false;
 static bool vk_enable_sync_logger = false;
 // number of calls between perf logger prints
 static uint32_t vk_perf_logger_frequency = 1;
+static std::string vk_pipeline_stats_filter;
 
 class vk_perf_logger {
   public:
@@ -2172,7 +2176,32 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
         executableInfo.pipeline = pipeline->pipeline;
 
         auto statistics = device->device.getPipelineExecutableStatisticsKHR(executableInfo);
+
+        bool print_stats = !vk_pipeline_stats_filter.empty() &&
+                           pipeline->name.find(vk_pipeline_stats_filter) != std::string::npos;
+        if (print_stats) {
+            std::cerr << "ggml_vulkan: pipeline stats for " << pipeline->name << ":" << std::endl;
+        }
+
         for (auto & s : statistics) {
+            if (print_stats) {
+                std::cerr << "ggml_vulkan:   " << s.name.data() << ": ";
+                switch (s.format) {
+                    case vk::PipelineExecutableStatisticFormatKHR::eBool32:
+                        std::cerr << (s.value.b32 ? "true" : "false");
+                        break;
+                    case vk::PipelineExecutableStatisticFormatKHR::eInt64:
+                        std::cerr << s.value.i64;
+                        break;
+                    case vk::PipelineExecutableStatisticFormatKHR::eUint64:
+                        std::cerr << s.value.u64;
+                        break;
+                    case vk::PipelineExecutableStatisticFormatKHR::eFloat64:
+                        std::cerr << s.value.f64;
+                        break;
+                }
+                std::cerr << std::endl;
+            }
             // "Register Count" is reported by NVIDIA drivers.
             if (strcmp(s.name, "Register Count") == 0) {
                 VK_LOG_DEBUG(pipeline->name << " " << s.name << ": " << s.value.u64 << " registers");
@@ -2755,78 +2784,214 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events
     );
 }
 
-// number of rows/cols for flash attention shader
-static constexpr uint32_t flash_attention_num_small_rows = 32;
-static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
+struct vk_fa_tuning_params {
+    FaCodePath path;
+    uint32_t workgroup_size;
+    uint32_t subgroup_size;
+    uint32_t block_rows;
+    uint32_t block_cols;
+    uint32_t d_split;
+    uint32_t row_split;
+    bool shmem_staging;
+    bool disable_subgroups;
+    uint32_t limit_occupancy_shmem;
+
+    void print() const {
+        std::cerr << "path=" << path << " workgroup_size=" << workgroup_size << " subgroup_size=" << subgroup_size <<
+                     " block_rows=" << block_rows << " block_cols=" << block_cols << " d_split=" << d_split <<
+                     " row_split=" << row_split << " shmem_staging=" << shmem_staging << " disable_subgroups=" << disable_subgroups <<
+                     " limit_occupancy_shmem=" << limit_occupancy_shmem << std::endl;
+    }
+};
+
+static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc);
+static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc);
+
+static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) {
+    GGML_UNUSED(kv_type);
+
+    vk_fa_tuning_params result{};
+    result.path = FA_SCALAR;
 
-static uint32_t get_fa_scalar_num_large_rows(uint32_t hsk, uint32_t hsv, bool small_cache) {
-    if (hsv >= 192) {
-        return 2;
-    } else if ((hsv | hsk) & 8 || small_cache) {
-        return 4;
+    if (device->vendor_id == VK_VENDOR_ID_INTEL) {
+        // Disable subgroup use due to performance issues when enforcing subgroup sizes
+        result.subgroup_size = 32;
+        result.disable_subgroups = true;
+    } else if (device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != AMD_GCN) {
+        result.subgroup_size = n_rows < 4 ? 32 : device->subgroup_size;
     } else {
-        return 8;
+        result.subgroup_size = device->subgroup_size;
     }
-}
 
-// The FA coopmat1 shader assumes 16x16x16 matrix multiply support.
-// 128 threads split into four subgroups, each subgroup does 1/4
-// of the Bc dimension.
-static constexpr uint32_t coopmat1_flash_attention_num_large_rows = 16;
-static constexpr uint32_t scalar_flash_attention_Bc = 64;
-static constexpr uint32_t scalar_flash_attention_workgroup_size = 128;
+    // Row split splits the workgroup so that synchronization only has to happen within subgroups, which avoids barriers
+    uint32_t row_split_max_hsk = 64;
+    if (device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != AMD_GCN && !device->uma) {
+        row_split_max_hsk = n_rows <= 8 ? 64 : 128;
+    }
+    result.row_split = (n_rows < 4 || hsk <= row_split_max_hsk) ? 1 : 4;
 
-static uint32_t get_fa_num_small_rows(FaCodePath path) {
-    if (path == FA_COOPMAT2) {
-        return flash_attention_num_small_rows;
+    if (result.subgroup_size > 32 && (n_rows < 4 || hsk < (result.row_split == 1 ? 128 : 64))) {
+        result.workgroup_size = result.subgroup_size * 2;
     } else {
-        return scalar_flash_attention_num_small_rows;
+        result.workgroup_size = result.subgroup_size * 4;
     }
-}
 
-static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) {
-    GGML_UNUSED(clamp);
+    const uint32_t D = hsk | hsv;
+
+    const bool reduce_block_rows = D & 8 || n_kv < 1024 || device->vendor_id == VK_VENDOR_ID_INTEL;
 
-    if (path == FA_SCALAR) {
-        if (small_rows) {
-            return {scalar_flash_attention_num_small_rows, 64};
+    if (n_rows == 1) {
+        result.block_rows = 1;
+        result.block_cols = 64;
+    } else {
+        // row_split 1 means higher register use per row, so block size has to be adjusted
+        if (result.row_split == 1) {
+            result.block_rows = n_rows == 2 ? 2 : ((n_rows <= 4 || reduce_block_rows) ? 4 : 8);
         } else {
-            if ((hsv | hsk) & 8) {
-                // HSV/HSK not being a multiple of 16 makes D_split smaller, which makes cols_per_iter
-                // larger, and Bc needs to be >= cols_per_thread. 64 is large enough, 32 is not.
-                return {get_fa_scalar_num_large_rows(hsk, hsv, small_cache), 64};
-            } else {
-                return {get_fa_scalar_num_large_rows(hsk, hsv, small_cache), 32};
-            }
+            result.block_rows = n_rows <= 4 ? 4 : ((n_rows <= 8 || reduce_block_rows) ? 8 : 16);
         }
+
+        result.block_cols = (D & 8) ? 64 : 32;
     }
 
-    if (path == FA_COOPMAT1) {
-        if (small_rows) {
-            return {scalar_flash_attention_num_small_rows, scalar_flash_attention_Bc};
-        } else {
-            return {coopmat1_flash_attention_num_large_rows, scalar_flash_attention_Bc};
+    const uint32_t D_lsb = D ^ (D & (D-1));  // extract lowest set bit
+
+    result.d_split = std::min(std::min(result.subgroup_size, 8u), D_lsb / 4);
+
+    result.shmem_staging = (device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256) ? 1 : 0;
+
+    if (!reduce_block_rows && !ggml_vk_flash_attn_scalar_shmem_support(device, result, hsk, hsv, f32acc)) {
+        result.block_rows /= 2;
+    }
+
+    // On AMD RDNA, for small head sizes and big batch size the shader uses few registers, so too many subgroups get scheduled
+    // at once and end up thrashing the cache. Fix this by setting a large (unused) shmem buffer that reduces occupancy.
+    // This targets an occupancy of 4 subgroups per SIMD.
+    if (device->vendor_id == VK_VENDOR_ID_AMD && device->properties.limits.maxComputeSharedMemorySize == 65536) {
+        if (device->architecture != AMD_GCN && n_rows >= 64 && hsk <= 128) {
+            // 30kb target for hsk > 64, 26kb for <= 64 due to smaller workgroup size
+            // Values are guessed, tested on RDNA2
+            result.limit_occupancy_shmem = (hsk <= 64 ? 26 : 30) * 1024 / 4 / 4;
+        } else if (device->architecture == AMD_GCN && n_rows <= 8 && hsk >= 256) {
+            // Same thing for GCN, with an occupancy target of 2 subgroups per SIMD.
+            // Here low-batch FA with large head size is affected.
+            // n_rows < 4 switch because workgroup size switches from 128 to 256 there.
+            result.limit_occupancy_shmem = (n_rows < 4 ? 14 : 26) * 1024 / 4 / 4;
         }
     }
 
-    // small rows, large cols
+    return result;
+}
+
+static vk_fa_tuning_params get_fa_tuning_params_coopmat1(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) {
+    GGML_UNUSED(n_rows);
+    GGML_UNUSED(n_kv);
+    GGML_UNUSED(kv_type);
+    GGML_UNUSED(f32acc);
+
+    vk_fa_tuning_params result{};
+    result.path = FA_COOPMAT1;
+
+    const uint32_t D = hsk | hsv;
+
+    const uint32_t coopmat_block_rows = 16;
+    const uint32_t coopmat_block_cols = 16;
+
+    const uint32_t num_subgroups = 4;
+
+    result.block_rows = coopmat_block_rows;
+    result.block_cols = coopmat_block_cols * num_subgroups;
+    result.row_split = num_subgroups;
+    result.subgroup_size = device->subgroup_size;
+    result.workgroup_size = num_subgroups * result.subgroup_size;
+
+    const uint32_t D_lsb = D ^ (D & (D-1));  // extract lowest set bit
+    result.d_split = std::min(std::min(result.subgroup_size, 8u), D_lsb / 4);
+
+    result.shmem_staging = (device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256) ? 1 : 0;
+
+    return result;
+}
+
+static vk_fa_tuning_params get_fa_tuning_params_coopmat2(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) {
+    GGML_UNUSED(n_kv);
+    GGML_UNUSED(f32acc);
+
+    vk_fa_tuning_params result{};
+    result.path = FA_COOPMAT2;
+
+    const uint32_t D = hsk | hsv;
+
+    const bool small_rows = n_rows < 32;
+
     if (small_rows) {
-        return {get_fa_num_small_rows(FA_COOPMAT2), 32};
+        result.block_rows = 32;
+        result.block_cols = 32;
+    } else if (ggml_is_quantized(kv_type) || hsk >= 256 || hsv >= 256) {
+        result.block_rows = (hsk >= 512 || hsv >= 512) ? 32 : 64;
+        result.block_cols = 32;
+    } else {
+        result.block_rows = 64;
+        result.block_cols = 64;
     }
 
-    // small cols to reduce register count
-    if (ggml_is_quantized(type) || hsk >= 256 || hsv >= 256) {
-        if (hsk >= 512 || hsv >= 512) {
-            return {32, 32};
-        } else {
-            return {64, 32};
+    result.subgroup_size = device->subgroup_size;
+    result.workgroup_size = (small_rows && (D % 32) == 0) ? 256 : 128;
+
+    return result;
+}
+
+static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) {
+    FaCodePath path = device->coopmat2 ? FA_COOPMAT2 :
+                      device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR;
+
+    if (path == FA_COOPMAT1 && device->architecture == vk_device_architecture::NVIDIA_TURING) {
+        // Nvidia compiler bug, see https://github.com/ggml-org/llama.cpp/pull/19075#issuecomment-3820716090
+        path = FA_SCALAR;
+    }
+
+    if (path == FA_COOPMAT1) {
+        bool shape_ok = (f32acc && device->coopmat_support_16x16x16_f32acc) ||
+                        (!f32acc && device->coopmat_support_16x16x16_f16acc);
+        const vk_fa_tuning_params params = get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc);
+        bool shmem_ok = ggml_vk_flash_attn_coopmat_shmem_support(device, params, hsk, hsv, f32acc);
+
+        if (!shape_ok || !shmem_ok) {
+            path = FA_SCALAR;
         }
     }
-    return {64, 64};
+
+    // scalar is faster than coopmat when N==1
+    if (n_rows == 1 && (path == FA_COOPMAT1 || path == FA_COOPMAT2)) {
+        path = FA_SCALAR;
+    }
+
+    switch (path) {
+    case FA_SCALAR:
+        return get_fa_tuning_params_scalar(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc);
+    case FA_COOPMAT1:
+        return get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc);
+    case FA_COOPMAT2:
+        return get_fa_tuning_params_coopmat2(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc);
+    default:
+        throw std::runtime_error("unsupported FaCodePath");
+    }
+}
+
+static vk_fa_pipeline_state get_fa_pipeline_state(const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool aligned, bool f32acc,
+                                                  bool use_mask, bool use_mask_opt, bool use_logit_softcap) {
+    uint32_t flags = (use_mask_opt      ? 1 : 0) |
+                     (use_mask          ? 2 : 0) |
+                     (use_logit_softcap ? 4 : 0);
+
+    const uint32_t subgroup_size = params.disable_subgroups ? 0 : params.subgroup_size;
+
+    return vk_fa_pipeline_state{hsk, hsv, params.block_rows, params.block_cols, params.d_split, params.row_split, params.shmem_staging, params.path, params.workgroup_size, subgroup_size, aligned, f32acc, flags, params.limit_occupancy_shmem};
 }
 
-static uint32_t fa_align(FaCodePath path, uint32_t hsk, uint32_t hsv, ggml_type type, bool small_rows, bool small_cache) {
-    return fa_rows_cols(path, hsk, hsv, 0, type, small_rows, small_cache)[1];
+static std::vector<uint32_t> get_fa_spec_constants(const vk_fa_pipeline_state& state) {
+    return {state.workgroup_size, state.Br, state.Bc, state.HSK, state.HSV, !state.aligned, state.D_split,
+            state.row_split, state.subgroup_size, state.shmem_staging ? 1u : 0u, state.flags, state.limit_occupancy_shmem};
 }
 
 static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id, ggml_type src0_type) {
@@ -3193,76 +3358,43 @@ static void ggml_vk_load_shaders(vk_device& device) {
                                        align, disable_robustness, require_full_subgroups, required_subgroup_size);
     };
 
-    auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) -> std::array<uint32_t, 3> {
-        return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache)[0], 1, 1};
-    };
-
-    auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache, uint32_t flags) -> std::vector<uint32_t> {
-        // For large number of rows, 128 invocations seems to work best.
-        // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
-        // can't use 256 for D==80.
-        // For scalar, use 128 (arbitrary)
-        // The same D_split value is used for both HSK and HSV, so just base it on the union of the LSBs.
-        const uint32_t D = (hsk|hsv);
-        auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache);
-
-        uint32_t wg_size;
-        switch (path) {
-        case FA_COOPMAT2:
-            wg_size = ((small_rows && (D % 32) == 0) ? 256 : 128);
-            break;
-        case FA_COOPMAT1:
-            wg_size = (rows_cols[1] / 16) * device->subgroup_size; // enough subgroups for Bc/MatBc
-            break;
-        default:
-            wg_size = scalar_flash_attention_workgroup_size;
-            break;
-        }
-
-        // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
-        // D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader.
-        const uint32_t D_lsb = D ^ (D & (D-1));
-        uint32_t D_split = std::min(std::min(device->subgroup_size, 8u), D_lsb / 4);
-
-        // Nvidia prefers shared memory use to load large tiles of K.
-        // Switch to loading from global memory when it would use too much shared memory.
-        // AMD prefers loading K directly from global memory
-        const uint32_t k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 ? 1 : 0;
-
-        return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, device->subgroup_size, k_load_shmem, flags};
-    };
-
 #define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
         for (auto &fa : device->pipeline_flash_attn_f32_f16[TYPE]) { \
-            uint32_t HSK = fa.first.HSK; \
-            uint32_t HSV = fa.first.HSV; \
-            bool small_rows = fa.first.small_rows; \
-            bool small_cache = fa.first.small_cache; \
             FaCodePath path = fa.first.path; \
+            uint32_t Br = fa.first.Br; \
+            uint32_t Bc = fa.first.Bc; \
             bool aligned = fa.first.aligned; \
             bool f32acc = fa.first.f32acc; \
-            uint32_t flags = fa.first.flags; \
+            uint32_t fa_sgs = fa.first.subgroup_size; \
+            bool fa_ds = fa.first.subgroup_size == 0; \
             if (path == FAPATH) { \
                 if (aligned) { \
                     if (f32acc) { \
-                        ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ##            SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ##            SUFFIX ## _data,  "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0));     \
+                        ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ##            SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ##            SUFFIX ## _data,  "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0));     \
                     } else { \
-                        ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data,  "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0));     \
+                        ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data,  "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0));     \
                     } \
                 } else { \
                     if (f32acc) { \
-                        ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc"         #NAMELC, flash_attn_f32_f16_ ## NAMELC ##            SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ##            SUFFIX ## _data,  "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,flags), 1,                                        true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0));     \
+                        ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc"         #NAMELC, flash_attn_f32_f16_ ## NAMELC ##            SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ##            SUFFIX ## _data,  "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1,  true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0));     \
                     } else { \
-                        ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc"         #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data,  "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,flags), 1,                                        true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0));     \
+                        ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc"         #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data,  "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1,  true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0));     \
                     } \
                 } \
             } \
         }
 
-    CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, )
-    CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
-    CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
-    CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
+    if (device->flash_attention_fp16) {
+        CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, )
+        CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
+        CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
+        CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
+    } else {
+        CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, _fp32)
+        CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, _fp32)
+        CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32)
+        CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32)
+    }
 #if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
     if (device->coopmat1_fa_support) {
         CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT1, _cm1)
@@ -4535,6 +4667,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
 }
 
 static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch);
+static uint32_t ggml_vk_intel_shader_core_count(const vk::PhysicalDevice& vkdev);
 
 static vk_device ggml_vk_get_device(size_t idx) {
     VK_LOG_DEBUG("ggml_vk_get_device(" << idx << ")");
@@ -4751,6 +4884,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
             device->shader_core_count = sm_props.shaderSMCount;
         } else if (amd_shader_core_properties2) {
             device->shader_core_count = amd_shader_core_properties2_props.activeComputeUnitCount;
+        } else if (device->vendor_id == VK_VENDOR_ID_INTEL) {
+            device->shader_core_count = ggml_vk_intel_shader_core_count(device->physical_device);
         } else {
             device->shader_core_count = 0;
         }
@@ -4970,11 +5105,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
 
 #if defined(VK_KHR_cooperative_matrix)
         device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix;
-
-        // coopmat1 fa shader currently assumes 32 invocations per subgroup
-        device->coopmat1_fa_support = device->coopmat_support && device->subgroup_require_full_support &&
-                                      device->subgroup_size_control && device->subgroup_min_size <= 32 &&
-                                      device->subgroup_max_size >= 32;
+        device->coopmat1_fa_support = device->coopmat_support && device->subgroup_require_full_support;
 #endif
 
         if (coopmat2_support) {
@@ -5292,6 +5423,10 @@ static vk_device ggml_vk_get_device(size_t idx) {
             device->mmvq_mode = 1;
         }
 
+        // Driver issues with older AMD GPUs on Windows, see https://github.com/ggml-org/llama.cpp/pull/19625#issuecomment-3940840613
+        const bool is_amd_proprietary_gcn = device->vendor_id == VK_VENDOR_ID_AMD && device->architecture == AMD_GCN && device->driver_id == vk::DriverId::eAmdProprietary;
+        device->flash_attention_fp16 = device->fp16 && !is_amd_proprietary_gcn;
+
         return device;
     }
 
@@ -5542,6 +5677,10 @@ static void ggml_vk_instance_init() {
     vk_perf_logger_concurrent = getenv("GGML_VK_PERF_LOGGER_CONCURRENT") != nullptr;
     vk_enable_sync_logger = getenv("GGML_VK_SYNC_LOGGER") != nullptr;
     vk_memory_logger_enabled = getenv("GGML_VK_MEMORY_LOGGER") != nullptr;
+    const char* GGML_VK_PIPELINE_STATS = getenv("GGML_VK_PIPELINE_STATS");
+    if (GGML_VK_PIPELINE_STATS != nullptr) {
+        vk_pipeline_stats_filter = GGML_VK_PIPELINE_STATS;
+    }
     const char* GGML_VK_PERF_LOGGER_FREQUENCY = getenv("GGML_VK_PERF_LOGGER_FREQUENCY");
 
     if (GGML_VK_PERF_LOGGER_FREQUENCY != nullptr) {
@@ -8421,21 +8560,27 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
     }
 }
 
-static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool small_cache) {
+static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc) {
+    GGML_UNUSED(f32acc);
     // Needs to be kept up to date on shader changes
-    GGML_UNUSED(hsv);
-    const uint32_t wg_size = scalar_flash_attention_workgroup_size;
-    const uint32_t Br = get_fa_scalar_num_large_rows(hsk, hsv, small_cache);
-    const uint32_t Bc = scalar_flash_attention_Bc;
+    const uint32_t wg_size = params.workgroup_size;
+    const uint32_t Br = params.block_rows;
+    const uint32_t Bc = params.block_cols;
 
+    const uint32_t float_type_size = device->flash_attention_fp16 ? sizeof(ggml_fp16_t) : sizeof(float);
+
+    // tmpsh is overestimated slightly
     const uint32_t tmpsh = wg_size * sizeof(float);
-    const uint32_t tmpshv4 = wg_size * 4 * sizeof(float);
+    const uint32_t tmpshv4 = wg_size * 4 * float_type_size;
+
+    const uint32_t masksh = Bc * (Br + 1) * float_type_size;
 
-    const uint32_t masksh = Bc * Br * sizeof(float);
+    const uint32_t Qf = Br * (hsk / 4 + 1) * 4 * float_type_size;
 
-    const uint32_t Qf = Br * (hsk / 4 + 2) * 4 * sizeof(float);
+    const uint32_t D = std::max(hsk, hsv);
+    const uint32_t kvsh = params.shmem_staging ? Bc * (D / 4 + 1) * 4 * float_type_size : 4 * float_type_size;
 
-    const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf;
+    const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf + kvsh;
     const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
 
     VK_LOG_DEBUG("ggml_vk_flash_attn_scalar_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported);
@@ -8443,18 +8588,17 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con
     return supported;
 }
 
-static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type kv_type) {
+static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc) {
     // Needs to be kept up to date on shader changes
-    GGML_UNUSED(hsv);
-    const auto rows_cols = fa_rows_cols(FA_COOPMAT1, hsk, hsv, 0, kv_type, false, false);
-    const uint32_t Br = rows_cols[0];
-    const uint32_t Bc = rows_cols[1];
+    const uint32_t Br = params.block_rows;
+    const uint32_t Bc = params.block_cols;
 
     const uint32_t MatBr = 16, MatBc = 16;
 
     const uint32_t row_split = Bc / MatBc;
 
     const uint32_t hsk_pad = ROUNDUP_POW2(hsk, 16);
+    const uint32_t hsv_pad = ROUNDUP_POW2(hsv, 16);
 
     const uint32_t acctype = f32acc ? 4 : 2;
     const uint32_t f16vec4 = 8;
@@ -8470,17 +8614,19 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co
     const uint32_t sfshstride = (hsk <= 128) ? (Br + 8) : Br;
     const uint32_t sfsh = Bc * sfshstride * acctype;
 
-    const bool k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256;
-    const uint32_t kshstride = (k_load_shmem ? hsk_pad : MatBr) / 4 + 2;
+    const uint32_t kvshstride = (params.shmem_staging ? std::max(hsk_pad, hsv_pad) : MatBr) / 4 + 2;
     const uint32_t vsh_stride = MatBc / 4 * row_split;
-    const uint32_t ksh = ((kshstride >= vsh_stride) ? (Bc * kshstride) : (Bc * vsh_stride)) * f16vec4;
+    const uint32_t ksh = ((kvshstride >= vsh_stride) ? (Bc * kvshstride) : (Bc * vsh_stride)) * f16vec4;
+
+    const uint32_t osh_stride = params.row_split * MatBr / 4;
+    const uint32_t pvsh = MatBc * osh_stride * f16vec4;
 
     const uint32_t slope = Br * acctype;
 
-    const uint32_t total_size = tmpsh + Qf + Psh + sfsh + ksh + slope;
+    const uint32_t total_size = tmpsh + Qf + Psh + sfsh + ksh + pvsh + slope;
     const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
 
-    VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", kv_type=" << kv_type << ", total_size=" << total_size << ", supported=" << supported);
+    VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported);
 
     return supported;
 }
@@ -8538,48 +8684,18 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     assert(q->type == GGML_TYPE_F32);
     assert(k->type == v->type);
 
-    FaCodePath path = ctx->device->coopmat2 ? FA_COOPMAT2 :
-                      ctx->device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR;
-
-    if (path == FA_COOPMAT1 && ctx->device->architecture == vk_device_architecture::NVIDIA_TURING) {
-        // Nvidia compiler bug, see https://github.com/ggml-org/llama.cpp/pull/19075#issuecomment-3820716090
-        path = FA_SCALAR;
-    }
-
-    if (path == FA_COOPMAT1) {
-        const bool coopmat_shape_supported = (dst->op_params[3] == GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f32acc) ||
-                                             (dst->op_params[3] != GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f16acc);
-
-        const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, HSK, HSV, dst->op_params[3] == GGML_PREC_F32, k->type);
-
-        if (!coopmat_shape_supported || !coopmat_shmem_supported) {
-            path = FA_SCALAR;
-        }
-    }
-
     uint32_t gqa_ratio = 1;
     uint32_t qk_ratio = neq2 / nek2;
     uint32_t workgroups_x = (uint32_t)neq1;
     uint32_t workgroups_y = (uint32_t)neq2;
     uint32_t workgroups_z = (uint32_t)neq3;
 
-    const bool small_cache = nek1 < 1024;
+    const bool f32acc = !ctx->device->flash_attention_fp16 || dst->op_params[3] == GGML_PREC_F32;
 
     // For scalar/coopmat1 FA, we can use the "large" size to accommodate qga.
     // For coopmat2 FA, we always use the small size (which is still pretty large for gqa).
-    uint32_t max_gqa;
-    switch (path) {
-    case FA_SCALAR:
-    case FA_COOPMAT1:
-        // We may switch from coopmat1 to scalar, so use the scalar limit for both
-        max_gqa = get_fa_scalar_num_large_rows(HSK, HSV, small_cache);
-        break;
-    case FA_COOPMAT2:
-        max_gqa = get_fa_num_small_rows(FA_COOPMAT2);
-        break;
-    default:
-        GGML_ASSERT(0);
-    }
+    vk_fa_tuning_params tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, 512, KV, k->type, f32acc);
+    const uint32_t max_gqa = std::min(tuning_params.block_rows, 32u);
 
     if (N <= 8 && qk_ratio > 1 && qk_ratio <= max_gqa &&
         qk_ratio * nek2 == neq2 && nek2 == nev2 && nem2 <= 1) {
@@ -8591,24 +8707,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         workgroups_y /= gqa_ratio;
     }
 
-    bool small_rows = N <= get_fa_num_small_rows(path);
-
-    // coopmat1 does not actually support "small rows" (it needs 16 rows).
-    // So use scalar instead.
-    if (small_rows && path == FA_COOPMAT1) {
-        path = FA_SCALAR;
-    }
-
-    // scalar is faster than coopmat2 when N==1
-    if (N == 1 && path == FA_COOPMAT2) {
-        path = FA_SCALAR;
-    }
-
-    // with large hsk/hsv, scalar path may need to use small_rows to fit in shared memory
-    if (path == FA_SCALAR &&
-        !ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV, small_cache)) {
-        small_rows = true;
-    }
+    tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, N, KV, k->type, f32acc);
 
     const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type));
     uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type));
@@ -8622,18 +8721,16 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         v_stride /= 4;
     }
 
-    uint32_t alignment = fa_align(path, HSK, HSV, k->type, small_rows, small_cache);
+    const uint32_t alignment = tuning_params.block_cols;
     bool aligned = (KV % alignment) == 0 &&
                    // the "aligned" shader variant will forcibly align strides, for performance
                    (q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0;
 
     // Need to use the coopmat2 variant that clamps loads when HSK/HSV aren't sufficiently aligned.
-    if (((HSK | HSV) % 16) != 0 && path == FA_COOPMAT2) {
+    if (((HSK | HSV) % 16) != 0 && tuning_params.path == FA_COOPMAT2) {
         aligned = false;
     }
 
-    bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
-
     float scale         = 1.0f;
     float max_bias      = 0.0f;
     float logit_softcap = 0.0f;
@@ -8648,12 +8745,8 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
 
     // Only use mask opt when the mask is fairly large. This hasn't been tuned extensively.
     bool use_mask_opt = mask && nem1 >= 32 && nem0 * nem1 > 32768;
-
-    uint32_t flags = (use_mask_opt       ? 1 : 0) |
-                     (mask != nullptr    ? 2 : 0) |
-                     (logit_softcap != 0 ? 4 : 0);
-
-    vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, small_cache, path, aligned, f32acc, flags);
+    vk_fa_pipeline_state fa_pipeline_state = get_fa_pipeline_state(tuning_params, HSK, HSV, aligned, f32acc,
+                                                                   mask != nullptr, use_mask_opt, logit_softcap != 0);
 
     vk_pipeline pipeline = nullptr;
 
@@ -8675,22 +8768,35 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     uint32_t split_kv = KV;
     uint32_t split_k = 1;
 
+    // Intel Alchemist prefers more workgroups
+    const uint32_t shader_core_count_multiplier = (ctx->device->vendor_id == VK_VENDOR_ID_INTEL && ctx->device->architecture != INTEL_XE2) ? 2 : 1;
+
     // Use a placeholder core count if one isn't available. split_k is a big help for perf.
-    const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16;
+    const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count * shader_core_count_multiplier : 16;
+
+    const uint32_t Br = fa_pipeline_state.Br;
+    const uint32_t Bc = fa_pipeline_state.Bc;
+
+    GGML_ASSERT(Br == pipeline->wg_denoms[0]);
+    const uint32_t Tr = CEIL_DIV(N, Br);
 
     // Try to use split_k when KV is large enough to be worth the overhead.
-    // Must either be a single batch or be using gqa, we can't mix the two.
-    if (workgroups_x <= pipeline->wg_denoms[0] && (workgroups_x == 1 || gqa_ratio > 1)) {
-        // Try to run two workgroups per SM.
+    if (gqa_ratio > 1 && workgroups_x <= Br) {
         split_k = shader_core_count * 2 / (workgroups_x * workgroups_y * workgroups_z);
-        if (split_k > 1) {
-            // Try to evenly split KV into split_k chunks, but it needs to be a multiple
-            // of "align", so recompute split_k based on that.
-            split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), alignment);
-            split_k = CEIL_DIV(KV, split_kv);
+    } else if (gqa_ratio <= 1) {
+        uint32_t total_wgs_no_split = Tr * workgroups_y * workgroups_z;
+        if (total_wgs_no_split < shader_core_count * 2) {
+            split_k = shader_core_count * 2 / total_wgs_no_split;
         }
     }
 
+    if (split_k > 1) {
+        // Try to evenly split KV into split_k chunks, but it needs to be a multiple
+        // of "align", so recompute split_k based on that.
+        split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), alignment);
+        split_k = CEIL_DIV(KV, split_kv);
+    }
+
     // Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (D x ne1)
     // and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows.
     // For matrices, the order is (inner to outer) [HSV, ne1, k, ne2, ne3].
@@ -8704,10 +8810,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         ggml_vk_preallocate_buffers(ctx, subctx);
     }
 
-    auto rows_cols = fa_rows_cols(path, HSK, HSV, !aligned, k->type, small_rows, small_cache);
-    const uint32_t Br = rows_cols[0];
-    const uint32_t Bc = rows_cols[1];
-
     const uint32_t mask_opt_num_dwords = CEIL_DIV(nem0, 16 * Bc);
     const uint64_t mask_opt_size = sizeof(uint32_t) * mask_opt_num_dwords * CEIL_DIV(nem1, Br) * nem2 * nem3;
 
@@ -8787,15 +8889,21 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         if (ctx->prealloc_split_k_need_sync) {
             ggml_vk_sync_buffers(ctx, subctx);
         }
-        workgroups_x *= pipeline->wg_denoms[0];
+
+        // We reuse workgroups_x to mean the number of splits, so we need to
+        // cancel out the divide by wg_denoms[0].
+        uint32_t dispatch_x;
+        if (gqa_ratio > 1) {
+            workgroups_x *= pipeline->wg_denoms[0];
+            dispatch_x = split_k * workgroups_x;
+        } else {
+            dispatch_x = Tr * split_k * pipeline->wg_denoms[0];
+        }
+
         vk_subbuffer split_k_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0);
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
                                     {q_buf, k_buf, v_buf, mask_buf, sinks_buf, split_k_buf, mask_opt_buf},
-                                    // We only use split_k when group query attention is enabled, which means
-                                    // there's no more than one tile of rows (i.e. workgroups_x would have been
-                                    // one). We reuse workgroups_x to mean the number of splits, so we need to
-                                    // cancel out the divide by wg_denoms[0].
-                                    pc, { split_k * workgroups_x, workgroups_y, workgroups_z });
+                                    pc, { dispatch_x, workgroups_y, workgroups_z });
 
         ggml_vk_sync_buffers(ctx, subctx);
         const vk_op_flash_attn_split_k_reduce_push_constants pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, split_k, (sinks != nullptr) };
@@ -15420,6 +15528,46 @@ static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDevicePrope
     }
 }
 
+static uint32_t ggml_vk_intel_shader_core_count(const vk::PhysicalDevice& vkdev) {
+    VkPhysicalDeviceProperties2 props = vkdev.getProperties2();
+
+    if (props.properties.vendorID != VK_VENDOR_ID_INTEL) {
+        return 0;
+    }
+
+    const uint32_t device_id = props.properties.deviceID;
+
+    switch (device_id) {
+    case 0x56A6:  // A310
+        return 6;
+    case 0x5693:  // A370M
+    case 0x56A5:  // A380
+    case 0x56B1:  // Pro A40/A50
+        return 8;
+    case 0x5697:  // A530M
+        return 12;
+    case 0x5692:  // A550M
+    case 0x56B3:  // Pro A60
+        return 16;
+    case 0x56A2:  // A580
+        return 24;
+    case 0x5691:  // A730M
+    case 0x56A1:  // A750
+        return 28;
+    case 0x56A0:  // A770
+    case 0x5690:  // A770M
+        return 32;
+    case 0xE212:  // Pro B50
+        return 16;
+    case 0xE20C:  // B570
+        return 18;
+    case 0xE20B:  // B580
+        return 20;
+    default:
+        return 0;
+    }
+}
+
 // checks
 
 #ifdef GGML_VULKAN_CHECK_RESULTS
@@ -16096,7 +16244,7 @@ static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph *
         ggml_vk_print_graph_origin(tensor, done);
     }
 
-    if (avg_err > 0.5 || std::isnan(avg_err)) {
+    if (avg_err > 0.01 || std::isnan(avg_err)) {
         std::cerr << "ERROR: avg_err=" << avg_err << " in " << ggml_op_name(tensor->op) << " (check " << check_counter << ")" << std::endl;
         std::cerr << "tensor=" << tensor << " tensor->name=" << tensor->name << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << " offset=" << tensor->view_offs << std::endl;
         if (src0 != nullptr) {
index 0735f678549aaa1f1eeb54e9822055bc3195071a..135ab1ad625538b990ba729c014b70a20df59a40 100644 (file)
@@ -3,9 +3,13 @@
 #extension GL_EXT_control_flow_attributes : enable
 #extension GL_EXT_shader_16bit_storage : require
 
-#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
 #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
 
+#ifdef FLOAT16
+#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
+#extension GL_EXT_shader_subgroup_extended_types_float16 : require
+#endif
+
 #extension GL_KHR_shader_subgroup_shuffle : enable
 #extension GL_KHR_shader_subgroup_vote : enable
 
 const uint32_t HSK_per_thread = HSK / D_split;
 const uint32_t HSV_per_thread = HSV / D_split;
 
-const uint32_t cols_per_iter = WorkGroupSize / D_split;
+const uint32_t rows_per_thread = Br / row_split;
+const uint32_t cols_per_iter = WorkGroupSize / D_split / row_split;
 const uint32_t cols_per_thread = Bc / cols_per_iter;
+const uint32_t num_subgroups = SubGroupSize == 0 ? 0 : WorkGroupSize / SubGroupSize;
 
 
 layout (binding = 0) readonly buffer Q {float data_q[];};
@@ -27,20 +33,22 @@ layout (binding = 2) readonly buffer V {float16_t data_v[];};
 layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};
 layout (binding = 3) readonly buffer M {float16_t data_m[];};
 
-// Store the output when doing grouped query attention.
-// Rows index by Q's dimension 2, and the first N rows are valid.
-D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
-{
-    uint32_t offset = (iq2 + r) * HSV + c;
-    data_o[o_offset + offset] = D_TYPE(elem);
-    return elem;
-}
+// If SubGroupSize is set to 0 then only use shmem reductions
+const uint32_t tmpsh_size = (SubGroupSize > 0) ? (row_split == 1 ? num_subgroups * D_split : num_subgroups) : WorkGroupSize;
+shared float tmpsh[tmpsh_size];
+shared FLOAT_TYPEV4 tmpshv4[tmpsh_size];
 
-shared FLOAT_TYPE tmpsh[WorkGroupSize];
-shared vec4 tmpshv4[WorkGroupSize];
+const uint32_t masksh_stride = Br + 1;
+shared FLOAT_TYPE masksh[Bc * masksh_stride];
 
-shared float masksh[Bc][Br];
-shared vec4 Qf[Br][HSK / 4];
+const uint32_t qf_stride = HSK / 4 + 1;
+shared FLOAT_TYPEV4 Qf[Br * qf_stride];
+
+const uint32_t D = HSK > HSV ? HSK : HSV;
+const uint32_t kvsh_stride = D / 4 + 1;
+shared FLOAT_TYPEV4 kvsh[SHMEM_STAGING != 0 ? Bc * kvsh_stride : 1];
+
+shared vec4 occupancy_limiter[LIMIT_OCCUPANCY_SHMEM > 0 ? LIMIT_OCCUPANCY_SHMEM : 1];
 
 void main() {
 #ifdef NEEDS_INIT_IQ_SHMEM
@@ -50,8 +58,24 @@ void main() {
     init_indices();
 
     const uint32_t tid = gl_LocalInvocationIndex;
+    const uint32_t threads_per_rowgroup = gl_WorkGroupSize.x / row_split;
+    const uint32_t row_tid = gl_LocalInvocationIndex / threads_per_rowgroup;
+    const uint32_t rowgroup_tid = gl_LocalInvocationIndex % threads_per_rowgroup;
     const uint32_t d_tid = gl_LocalInvocationIndex % D_split;
-    const uint32_t col_tid = gl_LocalInvocationIndex / D_split;
+    const uint32_t col_tid = (gl_LocalInvocationIndex % threads_per_rowgroup) / D_split;
+
+    if (LIMIT_OCCUPANCY_SHMEM > 0) {
+        // This just exists to avoid the occupancy_limiter array getting optimized out
+        occupancy_limiter[tid] = vec4(tid);
+
+        barrier();
+
+        if (occupancy_limiter[tid] == vec4(99999.0)) {
+            data_ov4[0] = D_TYPEV4(occupancy_limiter[tid]);
+        }
+    }
+
+#define tile_row(r) (row_tid * rows_per_thread + (r))
 
     uint32_t q_offset = gqa_iq1*p.nb01 + (iq2*p.nb02 + iq3*p.nb03) / 4;
 
@@ -60,37 +84,37 @@ void main() {
         uint32_t r = (idx + tid) / (HSK / 4);
         if (r < Br && d < HSK / 4 &&
             i * Br + r < N) {
-            Qf[r][d] = vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d]) * p.scale;
+            Qf[r * qf_stride + d] = FLOAT_TYPEV4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale);
         }
     }
     barrier();
 
-    vec4 Of[Br][HSV_per_thread / 4];
+    FLOAT_TYPEV4 Of[rows_per_thread][HSV_per_thread / 4];
     [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
-        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
-            Of[r][d] = vec4(0.0);
+        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+            Of[r][d] = FLOAT_TYPEV4(0.0);
         }
     }
 
-    float Lf[Br], Mf[Br];
+    float Lf[rows_per_thread], Mf[rows_per_thread];
 
     // Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M.
     const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF);
 
-    [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+    [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
         Lf[r] = 0;
         Mf[r] = NEG_FLT_MAX_OVER_2;
     }
 
-    float slope[Br];
-    [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
-        slope[r] = 1.0;
+    ACC_TYPE slope[rows_per_thread];
+    [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+        slope[r] = ACC_TYPE(1.0);
     }
 
     // ALiBi
     if (p.max_bias > 0.0f) {
-        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
-            slope[r] = perElemOpComputeSlope(r, col_tid, ACC_TYPE(0), iq2);
+        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+            slope[r] = perElemOpComputeSlope(tile_row(r), col_tid, ACC_TYPE(0), iq2);
         }
     }
 
@@ -113,75 +137,141 @@ void main() {
 
     uint32_t mask_opt = 0;
     uint32_t mask_opt_idx = ~0;
+    uint32_t mask_opt_bits = 0;
 
     [[dont_unroll]]
     for (uint32_t j = start_j; j < end_j; ++j) {
+        if (MASK_ENABLE) {
+            if (USE_MASK_OPT && mask_opt_idx != j / 16) {
+                mask_opt_idx = j / 16;
+                mask_opt = data_mask_opt[mo_offset + mask_opt_idx];
+            }
+            mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3;
+            if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) {
+                // skip this block
+                continue;
+            }
+            // Only load if the block is not all zeros
+            if (mask_opt_bits != MASK_OPT_ALL_ZERO) {
+                bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
 
-        if (USE_MASK_OPT && mask_opt_idx != j / 16) {
-            mask_opt_idx = j / 16;
-            mask_opt = data_mask_opt[mo_offset + mask_opt_idx];
-        }
-        uint32_t mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3;
-        if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) {
-            // skip this block
-            continue;
-        }
-        // Only load if the block is not all zeros
-        if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) {
-            bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
-
-            float max_mask = NEG_FLT_MAX_OVER_2;
-            [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
-                uint32_t c = (idx + tid) % Bc;
-                uint32_t r = (idx + tid) / Bc;
-                if (idx + tid < Bc * Br) {
-                    if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
-                        float m = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
-                        masksh[c][r] = m;
-                        max_mask = max(max_mask, m);
-                    } else {
-                        masksh[c][r] = float(0);
+                float max_mask = NEG_FLT_MAX_OVER_2;
+                barrier();
+                [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
+                    uint32_t c = (idx + tid) % Bc;
+                    uint32_t r = (idx + tid) / Bc;
+                    if (idx + tid < Bc * Br) {
+                        if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
+                            FLOAT_TYPE m = FLOAT_TYPE(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
+                            masksh[c * masksh_stride + r] = m;
+                            max_mask = max(max_mask, float(m));
+                        } else {
+                            masksh[c * masksh_stride + r] = FLOAT_TYPE(0);
+                        }
                     }
                 }
-            }
-            // skip the block if the mask is entirely -inf
-            bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
-            barrier();
-            if (gl_SubgroupInvocationID == 0) {
-                tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
-            }
-            barrier();
-            [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
-                max_mask = max(max_mask, tmpsh[s]);
-            }
-            if (max_mask <= NEG_FLT_MAX_OVER_2) {
-                continue;
+                // skip the block if the mask is entirely -inf
+                bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
+                barrier();
+                if (gl_SubgroupInvocationID == 0) {
+                    tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
+                }
+                barrier();
+                [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
+                    max_mask = max(max_mask, tmpsh[s]);
+                }
+                if (max_mask <= NEG_FLT_MAX_OVER_2) {
+                    continue;
+                }
             }
         }
 
-        float Sf[Br][cols_per_thread];
-        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+        ACC_TYPE Sf[rows_per_thread][cols_per_thread];
+        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
             [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
-                Sf[r][c] = 0.0;
+                Sf[r][c] = ACC_TYPE(0.0);
             }
         }
 
+        if (SHMEM_STAGING != 0) {
+            barrier();
+            [[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
+                uint32_t d = (idx + tid) % (HSK / 4);
+                uint32_t c = (idx + tid) / (HSK / 4);
+                if (idx + gl_WorkGroupSize.x <= Bc * HSK / 4 || c < Bc) {
+                    FLOAT_TYPEV4 K_Tf = FLOAT_TYPEV4(0);
+                    if (!KV_bounds_check || j * Bc + c < KV) {
+#if BLOCK_SIZE > 1
+                        uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
+                        uint ib = coord / BLOCK_SIZE;
+                        uint iqs = (coord % BLOCK_SIZE);
+                        K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
+#else
+                        K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
+#endif
+                    }
 
-        [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
-            if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
-                continue;
+                    kvsh[c * kvsh_stride + d] = K_Tf;
+                }
             }
+            barrier();
+        }
+
+        // More d iterations means Q register caching becomes relevant
+        // Few iterations means the additional registers needed are worse than the speed-up from caching
+        if (HSK_per_thread / 4 > 4) {
             [[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) {
+                FLOAT_TYPEV4 Q_cache[rows_per_thread];
+                [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+                    Q_cache[r] = Qf[tile_row(r) * qf_stride + d * D_split + d_tid];
+                }
+
+                [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
+                    if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
+                        continue;
+                    }
+
+                    FLOAT_TYPEV4 K_Tf;
+                    if (SHMEM_STAGING != 0) {
+                        K_Tf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)];
+                    } else {
 #if BLOCK_SIZE > 1
-                uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
-                uint ib = coord / BLOCK_SIZE;
-                uint iqs = (coord % BLOCK_SIZE);
-                vec4 K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
+                        uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
+                        uint ib = coord / BLOCK_SIZE;
+                        uint iqs = (coord % BLOCK_SIZE);
+                        K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
 #else
-                vec4 K_Tf = vec4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
+                        K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
 #endif
-                [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
-                    Sf[r][c] += dot(Qf[r][d * D_split + d_tid], K_Tf);
+                    }
+                    [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+                        Sf[r][c] += ACC_TYPE(dot(Q_cache[r], K_Tf));
+                    }
+                }
+            }
+        } else {
+            [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
+                if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
+                    continue;
+                }
+
+                [[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) {
+                    FLOAT_TYPEV4 K_Tf;
+                    if (SHMEM_STAGING != 0) {
+                        K_Tf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)];
+                    } else {
+#if BLOCK_SIZE > 1
+                        uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
+                        uint ib = coord / BLOCK_SIZE;
+                        uint iqs = (coord % BLOCK_SIZE);
+                        K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
+#else
+                        K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]);
+#endif
+                    }
+                    [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+                        Sf[r][c] += ACC_TYPE(dot(Qf[tile_row(r) * qf_stride + d * D_split + d_tid], K_Tf));
+                    }
                 }
             }
         }
@@ -189,89 +279,109 @@ void main() {
         [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
             // Compute sum across the D_split
             [[unroll]] for (uint s = D_split / 2; s > 0; s >>= 1) {
-                [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+                [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
                     Sf[r][c] += subgroupShuffleXor(Sf[r][c], s);
                 }
             }
         }
 
         if (LOGIT_SOFTCAP) {
-            [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+            [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
                 [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
-                    Sf[r][c] = p.logit_softcap * tanh(Sf[r][c]);
+                    Sf[r][c] = ACC_TYPE(p.logit_softcap * tanh(Sf[r][c]));
                 }
             }
         }
 
         if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) {
             [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
-                [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
-                    float mvf = masksh[c * cols_per_iter + col_tid][r];
+                [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+                    FLOAT_TYPE mvf = masksh[(c * cols_per_iter + col_tid) * masksh_stride + tile_row(r)];
 
                     Sf[r][c] += slope[r]*mvf;
                 }
             }
-            barrier();
         }
 
-        float rowmaxf[Br], Pf[Br][cols_per_thread], rowsumf[Br], eMf[Br], Moldf[Br];
-        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
-            rowmaxf[r] = NEG_FLT_MAX_OVER_2;
+        float eMf[rows_per_thread];
+        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+            float rowmaxf = NEG_FLT_MAX_OVER_2;
             [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
                 if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
                     continue;
                 }
-                rowmaxf[r] = max(rowmaxf[r], Sf[r][c]);
+                rowmaxf = max(rowmaxf, float(Sf[r][c]));
             }
-            Moldf[r] = Mf[r];
+            float Moldf = Mf[r];
 
             // M = max(rowmax, Mold)
             // P = e^(S - M)
             // eM = e^(Mold - M)
-            Mf[r] = max(rowmaxf[r], Moldf[r]);
-            [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
-                Pf[r][c] = exp(Sf[r][c] - Mf[r]);
-            }
-            eMf[r] = exp(Moldf[r] - Mf[r]);
+            Mf[r] = max(rowmaxf, Moldf);
+            eMf[r] = exp(Moldf - Mf[r]);
+            Lf[r] = eMf[r]*Lf[r];
+        }
 
-            // Compute sum across row of P
-            rowsumf[r] = 0.0;
-            [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
-                if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
-                    continue;
-                }
-                rowsumf[r] += Pf[r][c];
+        [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
+            [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+                Of[r][d] = FLOAT_TYPE(eMf[r]) * Of[r][d];
             }
-
-            Lf[r] = eMf[r]*Lf[r] + rowsumf[r];
         }
 
-        [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
-            [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
-                Of[r][d] = eMf[r] * Of[r][d];
+        if (SHMEM_STAGING != 0) {
+            barrier();
+            [[unroll]] for (uint32_t idx = 0; idx < Bc * HSV / 4; idx += gl_WorkGroupSize.x) {
+                uint32_t d = (idx + tid) % (HSV / 4);
+                uint32_t c = (idx + tid) / (HSV / 4);
+                if (idx + gl_WorkGroupSize.x <= Bc * HSV / 4 || c < Bc) {
+                    FLOAT_TYPEV4 V_Tf = FLOAT_TYPEV4(0);
+                    if (!KV_bounds_check || j * Bc + c < KV) {
+#if BLOCK_SIZE > 1
+                        uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE + 4 * d;
+                        uint ib = coord / BLOCK_SIZE;
+                        uint iqs = (coord % BLOCK_SIZE);
+                        V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
+#else
+                        V_Tf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]);
+#endif
+                    }
+
+                    kvsh[c * kvsh_stride + d] = V_Tf;
+                }
             }
+            barrier();
         }
 
         [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
             if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
                 continue;
             }
+
+            FLOAT_TYPE Pf[rows_per_thread];
+            [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+                Pf[r] = FLOAT_TYPE(exp(float(Sf[r][c]) - Mf[r]));
+                Lf[r] += Pf[r];
+            }
+
             [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
+                FLOAT_TYPEV4 Vf;
+                if (SHMEM_STAGING != 0) {
+                    Vf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)];
+                } else {
 #if BLOCK_SIZE > 1
-                uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
-                uint ib = coord / BLOCK_SIZE;
-                uint iqs = (coord % BLOCK_SIZE);
-                vec4 Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
+                    uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
+                    uint ib = coord / BLOCK_SIZE;
+                    uint iqs = (coord % BLOCK_SIZE);
+                    Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
 #else
-                vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
+                    Vf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
 #endif
-                [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
-                    Of[r][d] += Pf[r][c] * Vf;
+                }
+                [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+                    Of[r][d] += FLOAT_TYPEV4(Pf[r] * Vf);
                 }
             }
         }
-
-        barrier();
     }
 
     // prevent race on tmpsh
@@ -279,58 +389,108 @@ void main() {
 
     // reduce across threads
 
-    [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
-        float rowmaxf, eMf;
+    [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+        float rowmaxf = Mf[r];
 
-        tmpsh[tid] = Mf[r];
         // Compute max across the row
-        barrier();
-        [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
-            if (tid < s) {
-                tmpsh[tid] = max(tmpsh[tid], tmpsh[tid + s]);
+        if (SubGroupSize > 0) {
+            [[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) {
+                rowmaxf = max(rowmaxf, subgroupShuffleXor(rowmaxf, s));
             }
+            if (row_split == 1) {
+                // Reduce inside workgroup with shmem
+                barrier();
+                if (gl_SubgroupInvocationID == d_tid) {
+                    tmpsh[gl_SubgroupID * D_split + d_tid] = rowmaxf;
+                }
+                barrier();
+                rowmaxf = tmpsh[d_tid];
+                [[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) {
+                    rowmaxf = max(rowmaxf, tmpsh[s * D_split + d_tid]);
+                }
+            }
+        } else {
+            barrier();
+            tmpsh[tid] = rowmaxf;
             barrier();
+            [[unroll]] for (int s = int(threads_per_rowgroup) / 2; s >= D_split; s >>= 1) {
+                if (rowgroup_tid < s) {
+                    tmpsh[tid] = max(tmpsh[tid], tmpsh[tid ^ s]);
+                }
+                barrier();
+            }
+            rowmaxf = tmpsh[row_tid * threads_per_rowgroup + d_tid];
         }
-        rowmaxf = tmpsh[d_tid];
-        barrier();
 
         float Moldf = Mf[r];
 
         // M = max(rowmax, Mold)
         // eM = e^(Mold - M)
         Mf[r] = max(rowmaxf, Moldf);
-        eMf = exp(Moldf - Mf[r]);
+        float eMf = exp(Moldf - Mf[r]);
 
         Lf[r] = eMf*Lf[r];
 
-        tmpsh[tid] = Lf[r];
-
         // Compute sum across the row
-        barrier();
-        [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
-            if (tid < s) {
-                tmpsh[tid] = tmpsh[tid] + tmpsh[tid + s];
+        if (SubGroupSize > 0) {
+            [[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) {
+                Lf[r] += subgroupShuffleXor(Lf[r], s);
             }
+            if (row_split == 1) {
+                barrier();
+                if (gl_SubgroupInvocationID == d_tid) {
+                    tmpsh[gl_SubgroupID * D_split + d_tid] = Lf[r];
+                }
+                barrier();
+                Lf[r] = tmpsh[d_tid];
+                [[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) {
+                    Lf[r] += tmpsh[s * D_split + d_tid];
+                }
+            }
+        } else {
             barrier();
+            tmpsh[tid] = Lf[r];
+            barrier();
+            [[unroll]] for (int s = int(threads_per_rowgroup) / 2; s >= D_split; s >>= 1) {
+                if (rowgroup_tid < s) {
+                    tmpsh[tid] = tmpsh[tid] + tmpsh[tid ^ s];
+                }
+                barrier();
+            }
+            Lf[r] = tmpsh[row_tid * threads_per_rowgroup + d_tid];
         }
-        Lf[r] = tmpsh[d_tid];
-        barrier();
 
         [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
+            Of[r][d] = FLOAT_TYPE(eMf) * Of[r][d];
 
-            Of[r][d] = eMf * Of[r][d];
-            tmpshv4[tid] = Of[r][d];
-
-            barrier();
-            [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) {
-                if (tid < s) {
-                    Of[r][d] += tmpshv4[tid + s];
-                    tmpshv4[tid] = Of[r][d];
+            if (SubGroupSize > 0) {
+                [[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) {
+                    Of[r][d] += subgroupShuffleXor(Of[r][d], s);
+                }
+                if (row_split == 1) {
+                    barrier();
+                    if (gl_SubgroupInvocationID == d_tid) {
+                        tmpshv4[gl_SubgroupID * D_split + d_tid] = Of[r][d];
+                    }
+                    barrier();
+                    Of[r][d] = tmpshv4[d_tid];
+                    [[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) {
+                        Of[r][d] += tmpshv4[s * D_split + d_tid];
+                    }
                 }
+            } else {
+                barrier();
+                tmpshv4[tid] = Of[r][d];
                 barrier();
+                [[unroll]] for (int s = int(threads_per_rowgroup) / 2; s >= D_split; s >>= 1) {
+                    if (rowgroup_tid < s) {
+                        Of[r][d] += tmpshv4[tid ^ s];
+                        tmpshv4[tid] = Of[r][d];
+                    }
+                    barrier();
+                }
+                Of[r][d] = tmpshv4[row_tid * threads_per_rowgroup + d_tid];
             }
-            Of[r][d] = tmpshv4[d_tid];
-            barrier();
         }
     }
 
@@ -338,33 +498,53 @@ void main() {
     // If there is split_k, then the split_k resolve shader does the final
     // division by L. Store the intermediate O value and per-row m and L values.
     if (p.k_num > 1) {
-        // note: O and Q have swapped coord 1,2.
-        uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
-
-        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
-            if (r < N) {
-                [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
-                    [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
-                        perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
+        if (p.gqa_ratio > 1) {
+            // note: O and Q have swapped coord 1,2.
+            uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)) / 4;
+
+            [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+                const uint row = tile_row(r);
+                if (row < N) {
+                    [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
+                        gqaStore(row, d * D_split + d_tid, Of[r][d], o_offset, iq2, N);
                     }
                 }
             }
-        }
 
-        o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
-        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
-            if (r < N) {
-                perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
-                perElemOpStoreCol0(r, 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);
+            o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
+            [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+                const uint row = tile_row(r);
+                if (row < N) {
+                    perElemOpStoreCol0(row, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
+                    perElemOpStoreCol0(row, 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);
+                }
             }
-        }
+        } else {
+            [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+                const uint row = tile_row(r);
+                const uint global_row = i * Br + row;
+
+                if (global_row < N) {
+                    uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3)) / 4;
 
+                    [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
+                        data_ov4[o_offset + iq2 * HSV/4 + d * D_split + d_tid] = D_TYPEV4(Of[r][d]);
+                    }
+                }
+
+                if (global_row < N && d_tid == 0 && col_tid == 0) {
+                    uint32_t lm_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3));
+                    data_o[lm_offset + iq2] = D_TYPE(Lf[r]);
+                    data_o[lm_offset + p.ne1 + iq2] = D_TYPE(Mf[r]);
+                }
+            }
+        }
         return;
     }
 
     if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {
-        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
-            float sink = perElemOpGetSink(r, 0u, ACC_TYPE(0), iq2);
+        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+            float sink = perElemOpGetSink(tile_row(r), 0u, ACC_TYPE(0), iq2);
 
             float ms = 1.0f;
             float vs = 1.0f;
@@ -373,7 +553,7 @@ void main() {
                 ms = exp(Mf[r] - sink);
 
                 [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
-                    Of[r][d] *= ms;
+                    Of[r][d] *= FLOAT_TYPE(ms);
                 }
             } else {
                 vs = exp(sink - Mf[r]);
@@ -383,39 +563,37 @@ void main() {
         }
     }
 
-    float Lfrcp[Br];
-    [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
+    float Lfrcp[rows_per_thread];
+    [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
         Lfrcp[r] = (Lf[r] == 0.0) ? 0.0 : (1.0 / Lf[r]);
     }
 
     [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
-        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
-            Of[r][d] *= Lfrcp[r];
-#if defined(ACC_TYPE_MAX)
-            Of[r][d] = clamp(Of[r][d], -vec4(ACC_TYPE_MAX), vec4(ACC_TYPE_MAX));
+        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+            Of[r][d] *= FLOAT_TYPE(Lfrcp[r]);
+#if defined(FLOAT_TYPE_MAX)
+            Of[r][d] = clamp(Of[r][d], -FLOAT_TYPE_MAX, FLOAT_TYPE_MAX);
 #endif
         }
     }
 
-    uint32_t o_offset = gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV;
+    uint32_t o_offset = (gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV) / 4;
 
     if (p.gqa_ratio > 1) {
-        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
-            if (r < N) {
+        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+            const uint row = tile_row(r);
+            if (row < N) {
                 [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
-                    [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
-                        perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N);
-                    }
+                    gqaStore(row, d * D_split + d_tid, Of[r][d], o_offset, iq2, N);
                 }
             }
         }
     } else {
-        [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
-            if (i * Br + r < N) {
+        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+            const uint row = tile_row(r);
+            if (i * Br + row < N) {
                 [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
-                    [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
-                        data_o[o_offset + iq2 * HSV + (i * Br + r) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
-                    }
+                    data_ov4[o_offset + (iq2 * HSV + (i * Br + row) * p.ne1 * HSV) / 4 + d * D_split + d_tid] = D_TYPEV4(Of[r][d]);
                 }
             }
         }
index 4142c1e6eaac4cb0edccfce1d9c46306316d00d0..d444542b5336fb6923985a69c98f10fae31e4d3a 100644 (file)
@@ -1,16 +1,18 @@
 
 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
-layout (constant_id = 0) const uint32_t WorkGroupSize = 128;
-layout (constant_id = 1) const uint32_t Br = 1;
-layout (constant_id = 2) const uint32_t Bc = 32;
-layout (constant_id = 3) const uint32_t HSK = 32;
-layout (constant_id = 4) const uint32_t HSV = 32;
-layout (constant_id = 5) const uint32_t Clamp = 0;
-layout (constant_id = 6) const uint32_t D_split = 16;
-layout (constant_id = 7) const uint32_t SubGroupSize = 32;
-layout (constant_id = 8) const uint32_t K_LOAD_SHMEM = 0;
-layout (constant_id = 9) const uint32_t Flags = 0;
+layout (constant_id =  0) const uint32_t WorkGroupSize = 128;
+layout (constant_id =  1) const uint32_t Br = 1;
+layout (constant_id =  2) const uint32_t Bc = 32;
+layout (constant_id =  3) const uint32_t HSK = 32;
+layout (constant_id =  4) const uint32_t HSV = 32;
+layout (constant_id =  5) const uint32_t Clamp = 0;
+layout (constant_id =  6) const uint32_t D_split = 16;
+layout (constant_id =  7) const uint32_t row_split = 1;
+layout (constant_id =  8) const uint32_t SubGroupSize = 32;
+layout (constant_id =  9) const uint32_t SHMEM_STAGING = 0;
+layout (constant_id = 10) const uint32_t Flags = 0;
+layout (constant_id = 11) const uint32_t LIMIT_OCCUPANCY_SHMEM = 0;
 
 const bool USE_MASK_OPT  = (Flags & 1) != 0;
 const bool MASK_ENABLE   = (Flags & 2) != 0;
@@ -69,6 +71,7 @@ layout (push_constant) uniform parameter {
 layout (binding = 4) readonly buffer S {float data_s[];};
 
 layout (binding = 5) writeonly buffer O {D_TYPE data_o[];};
+layout (binding = 5) writeonly buffer OV4 {D_TYPEV4 data_ov4[];};
 
 layout (binding = 6) readonly buffer MO {uint32_t data_mask_opt[];};
 
@@ -94,12 +97,12 @@ layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16
 #define BLOCK_SIZE 4
 #define BLOCK_BYTE_SIZE 16
 
-vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
+FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
     // iqs is currently always zero in the flash attention shaders
     if (binding_idx == BINDING_IDX_K) {
-        return k_packed.k_data_packed[a_offset + ib];
+        return FLOAT_TYPEV4(k_packed.k_data_packed[a_offset + ib]);
     } else {
-        return v_packed.v_data_packed[a_offset + ib];
+        return FLOAT_TYPEV4(v_packed.v_data_packed[a_offset + ib]);
     }
 }
 #endif
@@ -107,7 +110,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
 #if defined(DATA_A_Q4_0)
 #define BLOCK_BYTE_SIZE 18
 
-vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
+FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
     if (binding_idx == BINDING_IDX_K) {
         uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
         uint vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
@@ -115,7 +118,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
         vui_lo >>= shift;
         vui_hi >>= shift;
 
-        return float(k_packed.k_data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
+        return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * (FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - FLOAT_TYPE(8.0f));
     } else {
         uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
         uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
@@ -123,24 +126,24 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
         vui_lo >>= shift;
         vui_hi >>= shift;
 
-        return float(v_packed.v_data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
+        return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * (FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - FLOAT_TYPE(8.0f));
     }
 }
 #endif
 
 #if defined(DATA_A_Q8_0)
 #define BLOCK_BYTE_SIZE 34
-vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
+FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
     if (binding_idx == BINDING_IDX_K) {
         const i8vec2 v0 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
         const i8vec2 v1 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
 
-        return float(k_packed.k_data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
+        return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4(v0.x, v0.y, v1.x, v1.y);
     } else {
         const i8vec2 v0 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
         const i8vec2 v1 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
 
-        return float(v_packed.v_data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y);
+        return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4(v0.x, v0.y, v1.x, v1.y);
     }
 }
 #endif
@@ -189,10 +192,16 @@ void init_indices()
     KV = p.KV;
 
     if (p.k_num > 1) {
-        i = 0;
-        // batch and split_k share gl_WorkGroupID.x
-        gqa_iq1 = gl_WorkGroupID.x / p.k_num;
-        split_k_index = gl_WorkGroupID.x % p.k_num;
+        if (p.gqa_ratio > 1) {
+            i = 0;
+            // batch and split_k share gl_WorkGroupID.x
+            gqa_iq1 = gl_WorkGroupID.x / p.k_num;
+            split_k_index = gl_WorkGroupID.x % p.k_num;
+        } else {
+            gqa_iq1 = 0;
+            split_k_index = gl_WorkGroupID.x % p.k_num;
+            i = gl_WorkGroupID.x / p.k_num;
+        }
     } else if (p.gqa_ratio > 1) {
         i = 0;
         gqa_iq1 = gl_WorkGroupID.x;
@@ -244,3 +253,11 @@ void init_indices()
 // Bias applied to softmax to stay in fp16 range.
 // Based on ggml-cuda issue https://github.com/ggml-org/llama.cpp/issues/18606
 const float FATTN_KQ_MAX_OFFSET = 3.0f*0.6931f;
+
+// Store the output when doing grouped query attention.
+// Rows index by Q's dimension 2, and the first N rows are valid.
+void gqaStore(const in uint32_t r, const in uint32_t c, const in FLOAT_TYPEV4 elems, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
+{
+    uint32_t offset = (iq2 + r) * HSV / 4 + c;
+    data_ov4[o_offset + offset] = D_TYPEV4(elems);
+}
index 19630972dafb56f47d76d8084f131ed4deb42611..526e8da384e1bd2899cb14e3bc1cc21a12060ad0 100644 (file)
@@ -19,7 +19,6 @@
 const uint32_t MatBr = 16;
 const uint32_t MatBc = 16;
 
-const uint32_t row_split = Bc / MatBc;
 const uint32_t rows_per_thread = Br / row_split;
 const uint32_t cols_per_iter = gl_WorkGroupSize.x / row_split;
 const uint32_t cols_per_thread = Bc / cols_per_iter;
@@ -33,15 +32,6 @@ layout (binding = 2) readonly buffer V {float16_t data_v[];};
 layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];};
 layout (binding = 3) readonly buffer M {float16_t data_m[];};
 
-// Store the output when doing grouped query attention.
-// Rows index by Q's dimension 2, and the first N rows are valid.
-D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
-{
-    uint32_t offset = (iq2 + r) * HSV + c;
-    data_o[o_offset + offset] = D_TYPE(elem);
-    return elem;
-}
-
 shared float tmpsh[row_split];
 
 const uint32_t qstride = HSK_pad / 4 + 2; // in units of f16vec4
@@ -54,10 +44,14 @@ shared f16vec4 Psh[Bc * psh_stride];
 const uint32_t sfshstride = (HSK <= 128) ? (Br / 4 + 2) : Br / 4;
 shared ACC_TYPEV4 sfsh[Bc * sfshstride];
 
-const uint32_t kshstride = (K_LOAD_SHMEM != 0 ? HSK_pad : MatBr) / 4 + 2; // in units of f16vec4
+const uint32_t D_pad = HSK_pad > HSV_pad ? HSK_pad : HSV_pad;
+const uint32_t kvsh_stride = (SHMEM_STAGING != 0 ? D_pad : MatBr) / 4 + 2; // in units of f16vec4
 const uint v_cols = MatBc / 4 * row_split; // total cols, 4 vec4s per MatBc * number of subgroups
 const uint vsh_stride = v_cols;
-shared f16vec4 ksh[(kshstride >= vsh_stride) ? (Bc * kshstride) : (Bc * vsh_stride)];
+shared f16vec4 kvsh[(kvsh_stride >= vsh_stride) ? (Bc * kvsh_stride) : (Bc * vsh_stride)];
+
+const uint32_t osh_stride = row_split * MatBr / 4;
+shared f16vec4 pvsh[MatBc * osh_stride];
 
 shared ACC_TYPE slope[Br];
 
@@ -84,11 +78,6 @@ void main() {
                 Qf[i + tid] = f16vec4(0);
             }
         }
-        [[unroll]] for (uint i = 0; i < Bc * kshstride; i += gl_WorkGroupSize.x) {
-            if (i + tid < Bc * kshstride) {
-                ksh[i + tid] = f16vec4(0);
-            }
-        }
         barrier();
     }
 
@@ -104,10 +93,10 @@ void main() {
     }
     barrier();
 
-    ACC_TYPEV4 Of[rows_per_thread][d_per_thread];
+    f16vec4 Of[rows_per_thread][d_per_thread];
     [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
         [[unroll]] for (uint32_t d = 0; d < d_per_thread; ++d) {
-            Of[r][d] = ACC_TYPEV4(0.0);
+            Of[r][d] = f16vec4(0.0);
         }
     }
 
@@ -153,22 +142,22 @@ void main() {
 
     uint32_t mask_opt = 0;
     uint32_t mask_opt_idx = ~0;
+    uint32_t mask_opt_bits = 0;
+    f16vec4 mask_cache[Bc * Br / 4 / WorkGroupSize];
 
     [[dont_unroll]]
     for (uint32_t j = start_j; j < end_j; ++j) {
 
-        f16vec4 mask_cache[Bc * Br / 4 / WorkGroupSize];
         [[unroll]] for (uint32_t idx = 0; idx < mask_cache.length(); ++idx) {
             mask_cache[idx] = f16vec4(0);
         }
 
         if (MASK_ENABLE) {
-
             if (USE_MASK_OPT && mask_opt_idx != j / 16) {
                 mask_opt_idx = j / 16;
                 mask_opt = data_mask_opt[mo_offset + mask_opt_idx];
             }
-            uint32_t mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3;
+            mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3;
             if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) {
                 // skip this block
                 continue;
@@ -231,24 +220,24 @@ void main() {
             }
         }
 
-        if (K_LOAD_SHMEM != 0) {
-            [[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
-                uint32_t d = (idx + tid) % (HSK / 4);
-                uint32_t c = (idx + tid) / (HSK / 4);
-                if (c < Bc && d < HSK / 4) {
+        if (SHMEM_STAGING != 0) {
+            [[unroll]] for (uint32_t idx = 0; idx < Bc * HSK_pad / 4; idx += gl_WorkGroupSize.x) {
+                uint32_t d = (idx + tid) % (HSK_pad / 4);
+                uint32_t c = (idx + tid) / (HSK_pad / 4);
+                if (idx + gl_WorkGroupSize.x <= Bc * HSK_pad / 4 || c < Bc) {
                     f16vec4 K_Tf = f16vec4(0);
-                    if (!KV_bounds_check || j * Bc + c < KV) {
+                    if ((!KV_bounds_check || j * Bc + c < KV) && (HSK == HSK_pad || d < HSK / 4)) {
 #if BLOCK_SIZE > 1
                         uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
                         uint ib = coord / BLOCK_SIZE;
                         uint iqs = (coord % BLOCK_SIZE);
-                        K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K));
+                        K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
 #else
                         K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
 #endif
                     }
 
-                    ksh[c * kshstride + d] = K_Tf;
+                    kvsh[c * kvsh_stride + d] = K_Tf;
                 }
             }
             barrier();
@@ -262,7 +251,11 @@ void main() {
         coopmat<float16_t, gl_ScopeSubgroup, 16, MatBr, gl_MatrixUseB> QMat;
 
         [[unroll]] for (uint32_t d = 0; d < HSK_pad / 16; ++d) {
-            if (K_LOAD_SHMEM == 0) {
+            // If SHMEM_STAGING is set, a Bc * HSK_pad size tile of K is loaded to shmem
+            // If not, f16 K is loaded directly from global memory if aligned, otherwise
+            // staged through a Bc * MatBr size staging buffer.
+            // If K is not type f16, then it is always staged for dequantization.
+            if (SHMEM_STAGING == 0) {
 #if BLOCK_SIZE == 1
             if (KV_bounds_check || d * 16 + 16 > HSK) {
 #endif
@@ -277,13 +270,13 @@ void main() {
                         uint coord = (j * Bc + row) * k_stride * BLOCK_SIZE + d * 16 + col_vec * 4;
                         uint ib = coord / BLOCK_SIZE;
                         uint iqs = (coord % BLOCK_SIZE);
-                        K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K));
+                        K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K);
 #else
                         K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + row) * k_stride / 4 + d * 16 / 4 + col_vec]);
 #endif
                     }
 
-                    ksh[row * kshstride + col_vec] = K_Tf;
+                    kvsh[row * kvsh_stride + col_vec] = K_Tf;
                 }
             }
             barrier();
@@ -295,8 +288,8 @@ void main() {
             if (KV_bounds_check || d * 16 + 16 > HSK)
 #endif
             {
-                uint coord = (gl_SubgroupID * MatBc) * kshstride;
-                coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor);
+                uint coord = (gl_SubgroupID * MatBc) * kvsh_stride;
+                coopMatLoad(KMat, kvsh, coord, kvsh_stride, gl_CooperativeMatrixLayoutRowMajor);
             }
 #if BLOCK_SIZE == 1
             else {
@@ -305,8 +298,8 @@ void main() {
             }
 #endif
             } else {
-                uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4;
-                coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor);
+                uint coord = (gl_SubgroupID * MatBc) * kvsh_stride + d * 16 / 4;
+                coopMatLoad(KMat, kvsh, coord, kvsh_stride, gl_CooperativeMatrixLayoutRowMajor);
             }
 
             coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor);
@@ -329,7 +322,7 @@ void main() {
             barrier();
         }
 
-        if (MASK_ENABLE) {
+        if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) {
             [[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) {
                 uint32_t c = (idx + tid) / (Br / 4);
                 uint32_t r = (idx + tid) % (Br / 4);
@@ -374,7 +367,7 @@ void main() {
         [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
             const uint d_local = d0 / threads_per_rowgroup;
             [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
-                Of[r][d_local] = ACC_TYPE(eMf[r]) * Of[r][d_local];
+                Of[r][d_local] = float16_t(eMf[r]) * Of[r][d_local];
             }
         }
 
@@ -397,19 +390,47 @@ void main() {
             }
         }
 
+        if (SHMEM_STAGING != 0) {
+            [[unroll]] for (uint32_t idx = 0; idx < Bc * HSV_pad / 4; idx += gl_WorkGroupSize.x) {
+                uint32_t d = (idx + tid) % (HSV_pad / 4);
+                uint32_t c = (idx + tid) / (HSV_pad / 4);
+                if (idx + gl_WorkGroupSize.x <= Bc * HSV_pad / 4 || c < Bc) {
+                    f16vec4 V_Tf = f16vec4(0);
+                    if ((!KV_bounds_check || j * Bc + c < KV) && (HSV == HSV_pad || d < HSV / 4)) {
+#if BLOCK_SIZE > 1
+                        uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE + 4 * d;
+                        uint ib = coord / BLOCK_SIZE;
+                        uint iqs = (coord % BLOCK_SIZE);
+                        V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
+#else
+                        V_Tf = f16vec4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]);
+#endif
+                    }
+
+                    kvsh[c * kvsh_stride + d] = V_Tf;
+                }
+            }
+        }
+        barrier();
+
         const uint num_hsv_tiles = (HSV + MatBc * row_split - 1) / (MatBc * row_split); // round up
 
         // Each subgroup handles HSV/4 columns
         [[unroll]] for (uint32_t hsv_tile = 0; hsv_tile < num_hsv_tiles; ++hsv_tile) {
             const uint hsv_offset = (hsv_tile * row_split + gl_SubgroupID) * 16;
 
-            SfMat = coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0);
+            coopmat<float16_t, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator> PVMat = coopmat<float16_t, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0);
 
             // Preload V tiles for [Bc, 16 * num subgroups]
             const uint v_rows = Bc;
             const uint v_total = v_rows * v_cols;
             const uint v_loads_per_thread = v_total / gl_WorkGroupSize.x;
 
+            // If SHMEM_STAGING is set, a Bc * HSV_pad size tile of V is loaded to shmem.
+            // If not, f16 V is loaded directly from global memory if aligned, otherwise
+            // staged through a Bc * MatBr size staging buffer.
+            // If V is not type f16, then it is always staged for dequantization.
+            if (SHMEM_STAGING == 0) {
 #if BLOCK_SIZE == 1
             // For f16, only preload if not aligned
             if (KV_bounds_check) {
@@ -428,44 +449,52 @@ void main() {
 
                 if (!KV_bounds_check || (v_row < KV && v_col < HSV)) {
 #if BLOCK_SIZE > 1
-                    ksh[row * vsh_stride + col] = f16vec4(dequantize4(ib, iqs, v_offset, BINDING_IDX_V));
+                    kvsh[row * vsh_stride + col] = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
 #else
-                    ksh[row * vsh_stride + col] = data_vv4[(v_offset + v_row * v_stride + v_col) / 4];
+                    kvsh[row * vsh_stride + col] = data_vv4[(v_offset + v_row * v_stride + v_col) / 4];
 #endif
                 } else {
-                    ksh[row * vsh_stride + col] = f16vec4(0.0f);
+                    kvsh[row * vsh_stride + col] = f16vec4(0.0f);
                 }
             }
+
 #if BLOCK_SIZE == 1
             }
 #endif
-
+            }
             barrier();
 
-            [[unroll]] for (uint32_t bc_chunk = 0; bc_chunk < Bc / MatBc; ++bc_chunk) {
-                coopMatLoad(KMat, Psh, bc_chunk * MatBc * psh_stride, psh_stride, gl_CooperativeMatrixLayoutColumnMajor);
+            const uint o_offset = gl_SubgroupID * MatBr / 4;
+
+            if (hsv_offset < HSV_pad) {
+                [[unroll]] for (uint32_t bc_chunk = 0; bc_chunk < Bc / MatBc; ++bc_chunk) {
+                    coopMatLoad(KMat, Psh, bc_chunk * MatBc * psh_stride, psh_stride, gl_CooperativeMatrixLayoutColumnMajor);
 
+                    if (SHMEM_STAGING == 0) {
 #if BLOCK_SIZE == 1
-                if (!KV_bounds_check) {
-                    // F16 values can be loaded directly from global memory
-                    const uint v_tile_row = j * Bc + bc_chunk * MatBc;
-                    const uint v_tile_offset = v_offset / 4 + v_tile_row * v_stride / 4 + hsv_offset / 4;
-                    coopMatLoad(QMat, data_vv4, v_tile_offset, v_stride / 4, gl_CooperativeMatrixLayoutRowMajor);
-                } else
+                    if (!KV_bounds_check) {
+                        // F16 values can be loaded directly from global memory
+                        const uint v_tile_row = j * Bc + bc_chunk * MatBc;
+                        const uint v_tile_offset = v_offset / 4 + v_tile_row * v_stride / 4 + hsv_offset / 4;
+                        coopMatLoad(QMat, data_vv4, v_tile_offset, v_stride / 4, gl_CooperativeMatrixLayoutRowMajor);
+                    } else
 #endif
-                {
-                    const uint v_tile_offset = bc_chunk * MatBr * v_cols + gl_SubgroupID * (MatBc / 4);
-                    coopMatLoad(QMat, ksh, v_tile_offset, vsh_stride, gl_CooperativeMatrixLayoutRowMajor);
+                    {
+                        const uint v_tile_offset = bc_chunk * MatBr * v_cols + gl_SubgroupID * (MatBc / 4);
+                        coopMatLoad(QMat, kvsh, v_tile_offset, vsh_stride, gl_CooperativeMatrixLayoutRowMajor);
+                    }
+                    } else {
+                        const uint v_tile_offset = bc_chunk * MatBc * kvsh_stride + (hsv_tile * row_split + gl_SubgroupID) * (MatBc / 4);
+                        coopMatLoad(QMat, kvsh, v_tile_offset, kvsh_stride, gl_CooperativeMatrixLayoutRowMajor);
+                    }
+
+                    PVMat = coopMatMulAdd(KMat, QMat, PVMat);
                 }
 
-                SfMat = coopMatMulAdd(KMat, QMat, SfMat);
+                // Store PVMat to pvsh and load into Of
+                coopMatStore(PVMat, pvsh, o_offset, osh_stride, gl_CooperativeMatrixLayoutRowMajor);
             }
 
-            // Store SfMat to sfsh and load into Of
-            const uint osh_stride = row_split * MatBc / 4;
-            const uint o_offset = gl_SubgroupID * MatBc / 4;
-            coopMatStore(SfMat, sfsh, o_offset, osh_stride, gl_CooperativeMatrixLayoutRowMajor);
-
             barrier();
 
             const uint hsv_per_tile = row_split * MatBc;
@@ -484,7 +513,7 @@ void main() {
 
                     if (hsv_col >= hsv_base && hsv_col < hsv_base + hsv_per_tile && hsv_col < HSV) {
                         const uint local_hsv = (hsv_col - hsv_base) / 4;
-                        Of[r][d_local] += ACC_TYPEV4(sfsh[row * osh_stride + local_hsv]);
+                        Of[r][d_local] += pvsh[row * osh_stride + local_hsv];
                     }
                 }
             }
@@ -500,27 +529,48 @@ void main() {
     // If there is split_k, then the split_k resolve shader does the final
     // division by L. Store the intermediate O value and per-row m and L values.
     if (p.k_num > 1) {
-        // note: O and Q have swapped coord 1,2.
-        uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
+        if (p.gqa_ratio > 1) {
+            // note: O and Q have swapped coord 1,2.
+            uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)) / 4;
 
-        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
-            if (tile_row(r) < N) {
-                [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
-                    const uint d = d0 + col_tid;
-                    if (d >= HSV/4) break;
-                    const uint d_local = d0 / threads_per_rowgroup;
-                    [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
-                        perElemOpGqaStore(tile_row(r), 4 * d + comp, float(Of[r][d_local][comp]), o_offset, iq2, N);
+            [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+                if (tile_row(r) < N) {
+                    [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
+                        const uint d = d0 + col_tid;
+                        if (d >= HSV/4) break;
+                        const uint d_local = d0 / threads_per_rowgroup;
+                        gqaStore(tile_row(r), d, Of[r][d_local], o_offset, iq2, N);
                     }
                 }
             }
-        }
 
-        o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
-        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
-            if (tile_row(r) < N) {
-                perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
-                perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);
+            o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
+            [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+                if (tile_row(r) < N) {
+                    perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N);
+                    perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N);
+                }
+            }
+        } else {
+            [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+                const uint row = tile_row(r);
+                const uint global_row = i * Br + row;
+
+                if (global_row < N) {
+                    uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3)) / 4;
+
+                    [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
+                        const uint d = d0 + col_tid;
+                        if (d >= HSV/4) break;
+                        data_ov4[o_offset + iq2 * HSV/4 + d] = D_TYPEV4(Of[r][d/threads_per_rowgroup]);
+                    }
+                }
+
+                if (global_row < N && col_tid == 0) {
+                    uint32_t lm_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3));
+                    data_o[lm_offset + iq2] = D_TYPE(Lf[r]);
+                    data_o[lm_offset + p.ne1 + iq2] = D_TYPE(Mf[r]);
+                }
             }
         }
 
@@ -539,7 +589,7 @@ void main() {
 
                 [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
                     const uint d_local = d0 / threads_per_rowgroup;
-                    Of[r][d_local] *= ACC_TYPE(ms);
+                    Of[r][d_local] *= float16_t(ms);
                 }
             } else {
                 vs = exp(sink - Mf[r]);
@@ -557,14 +607,14 @@ void main() {
     [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
         const uint d_local = d0 / threads_per_rowgroup;
         [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
-            Of[r][d_local] *= ACC_TYPE(Lfrcp[r]);
-#if defined(ACC_TYPE_MAX)
-            Of[r][d_local] = clamp(Of[r][d_local], -ACC_TYPE_MAX, ACC_TYPE_MAX);
+            Of[r][d_local] *= float16_t(Lfrcp[r]);
+#if defined(FLOAT_TYPE_MAX)
+            Of[r][d_local] = clamp(Of[r][d_local], -FLOAT_TYPE_MAX, FLOAT_TYPE_MAX);
 #endif
         }
     }
 
-    uint32_t o_offset = gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV;
+    uint32_t o_offset = (gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV) / 4;
 
     if (p.gqa_ratio > 1) {
         [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
@@ -573,9 +623,7 @@ void main() {
                     const uint d = d0 + col_tid;
                     if (d >= HSV / 4) break;
                     const uint d_local = d0 / threads_per_rowgroup;
-                    [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
-                        perElemOpGqaStore(tile_row(r), 4 * d + comp, float(Of[r][d_local][comp]), o_offset, iq2, N);
-                    }
+                    gqaStore(tile_row(r), d, Of[r][d_local], o_offset, iq2, N);
                 }
             }
         }
@@ -586,9 +634,7 @@ void main() {
                     const uint d = d0 + col_tid;
                     if (d >= HSV / 4) break;
                     const uint d_local = d0 / threads_per_rowgroup;
-                    [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
-                        data_o[o_offset + iq2 * HSV + (i * Br + tile_row(r)) * p.ne1 * HSV + 4 * d + comp] = D_TYPE(Of[r][d_local][comp]);
-                    }
+                    data_ov4[o_offset + (iq2 * HSV + (i * Br + tile_row(r)) * p.ne1 * HSV) / 4 + d] = D_TYPEV4(Of[r][d_local]);
                 }
             }
         }
index 853f17fa16ee136257f44863fd0ee7d6bbb6d75d..0ea181342ceab444fea1a4d6bded33d66ea73a79 100644 (file)
@@ -72,6 +72,28 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY
     return elem;
 }
 
+// Store O values for non-GQA split_k. Rows are tokens, not heads.
+D_TYPE perElemOpNonGqaSplitKStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t unused, const in uint32_t iq2, const in uint32_t N) {
+    uint32_t global_row = i * Br + r;
+    if (global_row < N && c < HSV) {
+        uint32_t o_off = HSV * p.ne1
+            * (split_k_index + p.k_num * (global_row + p.ne2 * iq3));
+        data_o[o_off + iq2 * HSV + c] = D_TYPE(elem);
+    }
+    return elem;
+}
+
+// Store L/M values for non-GQA split_k.
+ACC_TYPE perElemOpNonGqaSplitKStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t lm_base, const in uint32_t iq2, const in uint32_t N) {
+    uint32_t global_row = i * Br + r;
+    if (global_row < N && c == 0) {
+        uint32_t lm_off = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3
+            + p.ne1 * 2 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3));
+        data_o[lm_off + lm_base + iq2] = D_TYPE(elem);
+    }
+    return elem;
+}
+
 void main() {
 #ifdef NEEDS_INIT_IQ_SHMEM
     init_iq_shmem(gl_WorkGroupSize);
@@ -290,13 +312,19 @@ void main() {
     if (p.k_num > 1) {
         coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(O);
 
-        // note: O and Q have swapped coord 1,2.
-        uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
-        coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
-
-        o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
-        coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N);
-        coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N);
+        if (p.gqa_ratio > 1) {
+            // note: O and Q have swapped coord 1,2.
+            uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
+            coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
+
+            o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3));
+            coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N);
+            coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N);
+        } else {
+            coopMatPerElementNV(O_D, O_D, perElemOpNonGqaSplitKStore, 0u, iq2, N);
+            coopMatPerElementNV(L, L, perElemOpNonGqaSplitKStoreCol0, 0u, iq2, N);
+            coopMatPerElementNV(M, M, perElemOpNonGqaSplitKStoreCol0, p.ne1, iq2, N);
+        }
         return;
     }
 
index 42ebc21e2a6eb6df303d885bdccfd3c061509df0..85455988c57cb531c5510da8ce61efd7b7839d5d 100644 (file)
@@ -595,8 +595,6 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
 }
 
 void process_shaders() {
-    std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}};
-
     // matmul
     for (const MatMulIdType& matmul_id_type : {MatMulIdType::NONE, MatMulIdType::DEFAULT, MatMulIdType::SUBGROUP}) {
         // No coopmats
@@ -622,49 +620,63 @@ void process_shaders() {
         }
     }
 
-    // flash attention
-    for (const auto& f16acc : {false, true}) {
-        std::map<std::string, std::string> fa_base_dict = base_dict;
-        fa_base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
-        fa_base_dict["ACC_TYPEV4"] = f16acc ? "f16vec4" : "vec4";
-        if (f16acc) {
-            fa_base_dict["ACC_TYPE_MAX"] = "float16_t(65504.0)";
+    for (const bool& fp16 : {false, true}) {
+        std::map<std::string, std::string> base_dict;
+        if (fp16) {
+            base_dict = {{"FLOAT_TYPE", "float16_t"}, {"FLOAT_TYPEV4", "f16vec4"}, {"FLOAT16", "1"}, {"FLOAT_TYPE_MAX", "float16_t(65504.0)"}};
+        } else {
+            base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV4", "vec4"}};
         }
 
-        for (const auto& tname : type_names) {
-            if (tname == "bf16") continue;
+        // flash attention
+        for (const bool& f16acc : {false, true}) {
+            std::map<std::string, std::string> fa_base_dict = base_dict;
+            fa_base_dict["ACC_TYPE"] = fp16 && f16acc ? "float16_t" : "float";
+            fa_base_dict["ACC_TYPEV4"] = fp16 && f16acc ? "f16vec4" : "vec4";
+            if (fp16 && f16acc) {
+                fa_base_dict["ACC_TYPE_MAX"] = "float16_t(65504.0)";
+            }
+
+            for (const auto& tname : type_names) {
+                if (tname == "bf16") continue;
 
+                if (fp16) {
 #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
-            if (tname == "f16") {
-                string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
-                    merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, true, f16acc);
-            } else {
-                std::string data_a_key = "DATA_A_" + to_uppercase(tname);
-                string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
-                    merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc);
-            }
+                if (tname == "f16") {
+                    string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
+                        merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, true, f16acc);
+                } else {
+                    std::string data_a_key = "DATA_A_" + to_uppercase(tname);
+                    string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
+                        merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), fp16, false, true, f16acc);
+                }
 #endif
 #if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
-            if (tname == "f16") {
-                string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
-                    merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"COOPMAT", "1"}}), true, true, false, f16acc);
-            } else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") {
-                std::string data_a_key = "DATA_A_" + to_uppercase(tname);
-                string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
-                    merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc);
-            }
+                if (tname == "f16") {
+                    string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
+                        merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"COOPMAT", "1"}}), fp16, true, false, f16acc);
+                } else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") {
+                    std::string data_a_key = "DATA_A_" + to_uppercase(tname);
+                    string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
+                        merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), fp16, true, false, f16acc);
+                }
 #endif
-            if (tname == "f16") {
-                string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
-                    merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, false, f16acc);
-            } else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") {
-                std::string data_a_key = "DATA_A_" + to_uppercase(tname);
-                string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
-                    merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc);
+                }
+
+                if (tname == "f16") {
+                    string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
+                        merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, false, f16acc);
+                } else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") {
+                    std::string data_a_key = "DATA_A_" + to_uppercase(tname);
+                    string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
+                        merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), fp16, false, false, f16acc);
+                }
             }
         }
     }
 
+    std::map<std::string, std::string> base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}};
+
     for (const auto& tname : type_names) {
         // mul mat vec
         std::string data_a_key = "DATA_A_" + to_uppercase(tname);