]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Vulkan Flash Attention Coopmat1 Refactor (#19075)
authorRuben Ortlam <redacted>
Wed, 28 Jan 2026 17:52:45 +0000 (18:52 +0100)
committerGitHub <redacted>
Wed, 28 Jan 2026 17:52:45 +0000 (18:52 +0100)
* vulkan: use coopmat for flash attention p*v matrix multiplication

* fix P loading issue

* fix barrier position

* remove reduction that is no longer needed

* move max thread reduction into loop

* remove osh padding

* add bounds checks and padding

* remove unused code

* fix shmem sizes, loop duration and accesses

* don't overwrite Qf, add new shared psh buffer instead

* add missing bounds checks

* use subgroup reductions

* optimize

* move bounds check, reduce barriers

* support other Bc values and other subgroup sizes

* remove D_split

* replace Of register array with shared memory Ofsh array

* parallelize HSV across the rowgroups

* go back to Of in registers, not shmem

* vectorize sfsh

* don't store entire K tile in shmem

* fixes

* load large k tiles to shmem on Nvidia

* adapt shared memory host check function to shader changes

* remove Bc 32 case

* remove unused variable

* fix missing mask reduction tmspsh barrier

* fix mask bounds check

* fix rowmax f16 under/overflow to inf

* fix flash_attn_cm2 BLOCK_SIZE preprocessor directives

ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp

index 514f290d0980599978a337ce79c472802fe3248d..3852867c29107f8fbd709d03a07382c8bc31b8e5 100644 (file)
@@ -3162,17 +3162,31 @@ static void ggml_vk_load_shaders(vk_device& device) {
         // For scalar, use 128 (arbitrary)
         // The same D_split value is used for both HSK and HSV, so just base it on the union of the LSBs.
         const uint32_t D = (hsk|hsv);
-        uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1)
-                            ? scalar_flash_attention_workgroup_size
-                            : ((small_rows && (D % 32) == 0) ? 256 : 128);
         auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache);
 
+        uint32_t wg_size;
+        switch (path) {
+        case FA_COOPMAT2:
+            wg_size = ((small_rows && (D % 32) == 0) ? 256 : 128);
+            break;
+        case FA_COOPMAT1:
+            wg_size = (rows_cols[1] / 16) * device->subgroup_size; // enough subgroups for Bc/MatBc
+            break;
+        default:
+            wg_size = scalar_flash_attention_workgroup_size;
+            break;
+        }
+
         // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
         // D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader.
         const uint32_t D_lsb = D ^ (D & (D-1));
         uint32_t D_split = std::min(std::min(device->subgroup_size, 8u), D_lsb / 4);
 
-        return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split};
+        // Nvidia prefers shared memory use to load large tiles of K
+        // AMD prefers loading K directly from global memory
+        const uint32_t k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA ? 1 : 0;
+
+        return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, device->subgroup_size, k_load_shmem};
     };
 
 #define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
@@ -3187,15 +3201,15 @@ static void ggml_vk_load_shaders(vk_device& device) {
             if (path == FAPATH) { \
                 if (aligned) { \
                     if (f32acc) { \
-                        ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ##            SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ##            SUFFIX ## _data,  "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0));     \
+                        ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ##            SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ##            SUFFIX ## _data,  "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0));     \
                     } else { \
-                        ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data,  "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0));     \
+                        ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data,  "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0));     \
                     } \
                 } else { \
                     if (f32acc) { \
-                        ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc"         #NAMELC, flash_attn_f32_f16_ ## NAMELC ##            SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ##            SUFFIX ## _data,  "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1,                                        true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0));     \
+                        ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc"         #NAMELC, flash_attn_f32_f16_ ## NAMELC ##            SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ##            SUFFIX ## _data,  "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1,                                        true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0));     \
                     } else { \
-                        ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc"         #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data,  "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1,                                        true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0));     \
+                        ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc"         #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len,  flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data,  "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1,                                        true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0));     \
                     } \
                 } \
             } \
@@ -8344,41 +8358,49 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con
     const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf;
     const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
 
-    VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported);
+    VK_LOG_DEBUG("ggml_vk_flash_attn_scalar_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported);
 
     return supported;
 }
 
-static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool f32acc) {
+static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type kv_type) {
     // Needs to be kept up to date on shader changes
     GGML_UNUSED(hsv);
-    const uint32_t wg_size = scalar_flash_attention_workgroup_size;
-    const uint32_t Br = coopmat1_flash_attention_num_large_rows;
-    const uint32_t Bc = scalar_flash_attention_Bc;
+    const auto rows_cols = fa_rows_cols(FA_COOPMAT1, hsk, hsv, 0, kv_type, false, false);
+    const uint32_t Br = rows_cols[0];
+    const uint32_t Bc = rows_cols[1];
+
+    const uint32_t MatBr = 16, MatBc = 16;
+
+    const uint32_t row_split = Bc / MatBc;
 
     const uint32_t hsk_pad = ROUNDUP_POW2(hsk, 16);
 
     const uint32_t acctype = f32acc ? 4 : 2;
     const uint32_t f16vec4 = 8;
 
-    const uint32_t tmpsh = wg_size * sizeof(float);
-    const uint32_t tmpshv4 = wg_size * 4 * acctype;
+    const uint32_t tmpsh = (Bc / MatBc) * sizeof(float);
 
     const uint32_t qstride = hsk_pad / 4 + 2;
     const uint32_t Qf = Br * qstride * f16vec4;
 
+    const uint32_t psh_stride = Br / 4 + 2;
+    const uint32_t Psh = Bc * psh_stride * f16vec4;
+
     const uint32_t sfshstride = (hsk <= 128) ? (Br + 8) : Br;
     const uint32_t sfsh = Bc * sfshstride * acctype;
 
-    const uint32_t kshstride = hsk_pad / 4 + 2;
-    const uint32_t ksh = Bc * kshstride * f16vec4;
+    const bool k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA;
+    const uint32_t kshstride = (k_load_shmem ? hsk_pad : MatBr) / 4 + 2;
+    const uint32_t vsh_stride = MatBc / 4 * row_split;
+    const uint32_t ksh = ((kshstride >= vsh_stride) ? (Bc * kshstride) : (Bc * vsh_stride)) * f16vec4;
 
-    const uint32_t slope = Br * sizeof(float);
+    const uint32_t slope = Br * acctype;
 
-    const uint32_t total_size = tmpsh + tmpshv4 + Qf + sfsh + ksh + slope;
+    const uint32_t total_size = tmpsh + Qf + Psh + sfsh + ksh + slope;
     const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
 
-    VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported);
+    VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", kv_type=" << kv_type << ", total_size=" << total_size << ", supported=" << supported);
 
     return supported;
 }
@@ -8442,7 +8464,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         const bool coopmat_shape_supported = (dst->op_params[3] == GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f32acc) ||
                                              (dst->op_params[3] != GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f16acc);
 
-        const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, HSK, HSV, dst->op_params[3] == GGML_PREC_F32);
+        const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, HSK, HSV, dst->op_params[3] == GGML_PREC_F32, k->type);
 
         if (!coopmat_shape_supported || !coopmat_shmem_supported) {
             path = FA_SCALAR;
index 29b5c7c3a4185ef71022df700a1880f081ee82a8..23a4d2c00586b3be4ef4e23fc4e39d73bc0edc18 100644 (file)
@@ -8,6 +8,8 @@ layout (constant_id = 3) const uint32_t HSK = 32;
 layout (constant_id = 4) const uint32_t HSV = 32;
 layout (constant_id = 5) const uint32_t Clamp = 0;
 layout (constant_id = 6) const uint32_t D_split = 16;
+layout (constant_id = 7) const uint32_t SubGroupSize = 32;
+layout (constant_id = 8) const uint32_t K_LOAD_SHMEM = 0;
 
 // Round up head sizes to a multiple of 16, for coopmat1/coopmat2 paths
 const uint32_t HSK_pad = (HSK + 15) & ~15;
@@ -74,6 +76,10 @@ layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16
 layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed;
 #endif
 
+#ifndef BLOCK_SIZE
+#define BLOCK_SIZE 1
+#endif
+
 #if defined(DATA_A_F32)
 #undef BLOCK_SIZE
 #define BLOCK_SIZE 4
index 0eb50fe58f9d643790f97728c37a2335fb28c873..83d52d19d672b330e762f8ccb5a02c82cb6d350e 100644 (file)
@@ -7,6 +7,7 @@
 #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
 
 #extension GL_KHR_shader_subgroup_basic : enable
+#extension GL_KHR_shader_subgroup_arithmetic : enable
 #extension GL_KHR_shader_subgroup_vote : enable
 #extension GL_KHR_memory_scope_semantics : enable
 #extension GL_KHR_cooperative_matrix : enable
 #include "types.glsl"
 #include "flash_attn_base.glsl"
 
-const uint32_t HSK_per_thread = HSK / D_split;
-const uint32_t HSV_per_thread = HSV / D_split;
+// These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd
+const uint32_t MatBr = 16;
+const uint32_t MatBc = 16;
 
-const uint32_t row_split = 4;
+const uint32_t row_split = Bc / MatBc;
 const uint32_t rows_per_thread = Br / row_split;
-const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split / row_split;
+const uint32_t cols_per_iter = gl_WorkGroupSize.x / row_split;
 const uint32_t cols_per_thread = Bc / cols_per_iter;
 
 
@@ -40,24 +42,24 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY
     return elem;
 }
 
-// These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd
-const uint32_t MatBr = 16;
-const uint32_t MatBc = 16;
-
-shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x];
-shared ACC_TYPEV4 tmpshv4[gl_WorkGroupSize.x];
+shared float tmpsh[row_split];
 
 const uint32_t qstride = HSK_pad / 4 + 2; // in units of f16vec4
 shared f16vec4 Qf[Br * qstride];
 
+const uint psh_stride = Br / 4 + 2;
+shared f16vec4 Psh[Bc * psh_stride];
+
 // Avoid padding for hsk==256 to make it fit in 48KB shmem.
-const uint32_t sfshstride = (HSK <= 128) ? (Br + 8) : Br;
-shared ACC_TYPE sfsh[Bc * sfshstride];
+const uint32_t sfshstride = (HSK <= 128) ? (Br / 4 + 2) : Br / 4;
+shared ACC_TYPEV4 sfsh[Bc * sfshstride];
 
-const uint32_t kshstride = HSK_pad / 4 + 2; // in units of f16vec4
-shared f16vec4 ksh[Bc * kshstride];
+const uint32_t kshstride = (K_LOAD_SHMEM != 0 ? HSK_pad : MatBr) / 4 + 2; // in units of f16vec4
+const uint v_cols = MatBc / 4 * row_split; // total cols, 4 vec4s per MatBc * number of subgroups
+const uint vsh_stride = v_cols;
+shared f16vec4 ksh[(kshstride >= vsh_stride) ? (Bc * kshstride) : (Bc * vsh_stride)];
 
-shared float slope[Br];
+shared ACC_TYPE slope[Br];
 
 void main() {
 #ifdef NEEDS_INIT_IQ_SHMEM
@@ -69,9 +71,9 @@ void main() {
     const uint32_t tid = gl_LocalInvocationIndex;
 
     const uint32_t threads_per_rowgroup = gl_WorkGroupSize.x / row_split;
+    const uint32_t d_per_thread = (HSV/4 + threads_per_rowgroup - 1) / threads_per_rowgroup;
     const uint32_t row_tid = gl_LocalInvocationIndex / threads_per_rowgroup;
-    const uint32_t d_tid = gl_LocalInvocationIndex % D_split;
-    const uint32_t col_tid = (gl_LocalInvocationIndex % threads_per_rowgroup) / D_split;
+    const uint32_t col_tid = gl_LocalInvocationIndex % threads_per_rowgroup;
 
 #define tile_row(r) (row_tid * rows_per_thread + (r))
 
@@ -102,9 +104,9 @@ void main() {
     }
     barrier();
 
-    ACC_TYPEV4 Of[rows_per_thread][HSV_per_thread / 4];
-    [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
-        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+    ACC_TYPEV4 Of[rows_per_thread][d_per_thread];
+    [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+        [[unroll]] for (uint32_t d = 0; d < d_per_thread; ++d) {
             Of[r][d] = ACC_TYPEV4(0.0);
         }
     }
@@ -125,13 +127,11 @@ void main() {
             uint r = tid;
             slope[r] = perElemOpComputeSlope(r, col_tid, ACC_TYPE(0), iq2);
         }
-        barrier();
     } else {
         if (tid < Br) {
             uint r = tid;
-            slope[r] = 1.0;
+            slope[r] = ACC_TYPE(1.0);
         }
-        barrier();
     }
 
 #if BLOCK_SIZE > 1
@@ -149,19 +149,45 @@ void main() {
     [[dont_unroll]]
     for (uint32_t j = start_j; j < end_j; ++j) {
 
-        float mask_cache[Bc * Br / WorkGroupSize];
+        f16vec4 mask_cache[Bc * Br / 4 / WorkGroupSize];
         if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
             bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
 
             float max_mask = NEG_FLT_MAX_OVER_2;
-            [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
-                uint32_t c = (idx + tid) % Bc;
-                uint32_t r = (idx + tid) / Bc;
-                if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
-                    if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
-                        float m = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
+            [[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) {
+                uint32_t c = (idx + tid) / (Br / 4);
+                uint32_t r = (idx + tid) % (Br / 4);
+                if (idx + tid < Bc * Br / 4 || idx + gl_WorkGroupSize.x <= Bc * Br / 4) {
+                    if ((!KV_bounds_check || j * Bc + c < KV)) {
+                        f16vec4 m;
+                        if (!nem1_bounds_check || i * Br + r * 4 + 3 < p.nem1) {
+                            m = f16vec4(data_m[m_offset + (i * Br + r * 4    ) * m_stride + (j * Bc + c)],
+                                        data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)],
+                                        data_m[m_offset + (i * Br + r * 4 + 2) * m_stride + (j * Bc + c)],
+                                        data_m[m_offset + (i * Br + r * 4 + 3) * m_stride + (j * Bc + c)]);
+                            max_mask = max(max(max(max(max_mask, float(m[0])), float(m[1])), float(m[2])), float(m[3]));
+                        } else if (i * Br + r * 4 + 2 < p.nem1) {
+                            m = f16vec4(data_m[m_offset + (i * Br + r * 4    ) * m_stride + (j * Bc + c)],
+                                        data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)],
+                                        data_m[m_offset + (i * Br + r * 4 + 2) * m_stride + (j * Bc + c)],
+                                        0.0);
+                            max_mask = max(max(max(max_mask, float(m[0])), float(m[1])), float(m[2]));
+                        } else if (i * Br + r * 4 + 1 < p.nem1) {
+                            m = f16vec4(data_m[m_offset + (i * Br + r * 4    ) * m_stride + (j * Bc + c)],
+                                        data_m[m_offset + (i * Br + r * 4 + 1) * m_stride + (j * Bc + c)],
+                                        0.0,
+                                        0.0);
+                            max_mask = max(max(max_mask, float(m[0])), float(m[1]));
+                        } else if (i * Br + r * 4 < p.nem1) {
+                            m = f16vec4(data_m[m_offset + (i * Br + r * 4    ) * m_stride + (j * Bc + c)],
+                                        0.0,
+                                        0.0,
+                                        0.0);
+                            max_mask = max(max_mask, float(m[0]));
+                        } else {
+                            m = f16vec4(0.0);
+                        }
                         mask_cache[idx / WorkGroupSize] = m;
-                        max_mask = max(max_mask, m);
                     }
                 }
             }
@@ -180,26 +206,28 @@ void main() {
             }
         }
 
-        [[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
-            uint32_t d = (idx + tid) % (HSK / 4);
-            uint32_t c = (idx + tid) / (HSK / 4);
-            if (c < Bc && d < HSK / 4) {
-                f16vec4 K_Tf = f16vec4(0);
-                if (!KV_bounds_check || j * Bc + c < KV) {
+        if (K_LOAD_SHMEM != 0) {
+            [[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
+                uint32_t d = (idx + tid) % (HSK / 4);
+                uint32_t c = (idx + tid) / (HSK / 4);
+                if (c < Bc && d < HSK / 4) {
+                    f16vec4 K_Tf = f16vec4(0);
+                    if (!KV_bounds_check || j * Bc + c < KV) {
 #if BLOCK_SIZE > 1
-                    uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
-                    uint ib = coord / BLOCK_SIZE;
-                    uint iqs = (coord % BLOCK_SIZE);
-                    K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K));
+                        uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d;
+                        uint ib = coord / BLOCK_SIZE;
+                        uint iqs = (coord % BLOCK_SIZE);
+                        K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K));
 #else
-                    K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
+                        K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]);
 #endif
-                }
+                    }
 
-                ksh[c * kshstride + d] = K_Tf;
+                    ksh[c * kshstride + d] = K_Tf;
+                }
             }
+            barrier();
         }
-        barrier();
 
         // K * Q^T -> S^T: Bc x HSK_pad * HSK_pad x Br -> Bc x Br
         // Bc split across workgroup (four subgroups), loop over HSK in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16
@@ -208,11 +236,55 @@ void main() {
         coopmat<float16_t, gl_ScopeSubgroup, MatBc, 16, gl_MatrixUseA> KMat;
         coopmat<float16_t, gl_ScopeSubgroup, 16, MatBr, gl_MatrixUseB> QMat;
 
-        for (uint32_t d = 0; d < HSK_pad / 16; ++d) {
-            coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor);
+        [[unroll]] for (uint32_t d = 0; d < HSK_pad / 16; ++d) {
+            if (K_LOAD_SHMEM == 0) {
+#if BLOCK_SIZE == 1
+            if (KV_bounds_check || d * 16 + 16 > HSK) {
+#endif
+            barrier();
+            [[unroll]] for (uint32_t idx = 0; idx < Bc * MatBr / 4; idx += gl_WorkGroupSize.x) {
+                uint32_t col_vec = (idx + tid) % (MatBr / 4);
+                uint32_t row = (idx + tid) / (MatBr / 4);
+                if (idx + tid < Bc * MatBr / 4) {
+                    f16vec4 K_Tf = f16vec4(0);
+                    if ((!KV_bounds_check || j * Bc + row < KV) && (HSK == HSK_pad || d * 16 + col_vec * 4 < HSK)) {
+#if BLOCK_SIZE > 1
+                        uint coord = (j * Bc + row) * k_stride * BLOCK_SIZE + d * 16 + col_vec * 4;
+                        uint ib = coord / BLOCK_SIZE;
+                        uint iqs = (coord % BLOCK_SIZE);
+                        K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K));
+#else
+                        K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + row) * k_stride / 4 + d * 16 / 4 + col_vec]);
+#endif
+                    }
 
-            uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4;
-            coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor);
+                    ksh[row * kshstride + col_vec] = K_Tf;
+                }
+            }
+            barrier();
+#if BLOCK_SIZE == 1
+            }
+#endif
+
+#if BLOCK_SIZE == 1
+            if (KV_bounds_check || d * 16 + 16 > HSK)
+#endif
+            {
+                uint coord = (gl_SubgroupID * MatBc) * kshstride;
+                coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor);
+            }
+#if BLOCK_SIZE == 1
+            else {
+                const uint coord = k_offset / 4 + (j * Bc + gl_SubgroupID * MatBc) * k_stride / 4 + d * 16 / 4;
+                coopMatLoad(KMat, data_kv4, coord, k_stride / 4, gl_CooperativeMatrixLayoutRowMajor);
+            }
+#endif
+            } else {
+                uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4;
+                coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor);
+            }
+
+            coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor);
 
             SfMat = coopMatMulAdd(KMat, QMat, SfMat);
         }
@@ -222,26 +294,26 @@ void main() {
         barrier();
 
         if (p.logit_softcap != 0.0f) {
-            [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
-                uint32_t c = (idx + tid) / Br;
-                uint32_t r = (idx + tid) % Br;
-                if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
-                    sfsh[c * sfshstride + r] = ACC_TYPE(p.logit_softcap * tanh(sfsh[c * sfshstride + r]));
+            [[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) {
+                uint32_t c = (idx + tid) / (Br / 4);
+                uint32_t r = (idx + tid) % (Br / 4);
+                if (idx + tid < Bc * Br / 4 || idx + gl_WorkGroupSize.x <= Bc * Br / 4) {
+                    sfsh[c * sfshstride + r] = ACC_TYPEV4(p.logit_softcap * tanh(sfsh[c * sfshstride + r]));
                 }
             }
             barrier();
         }
 
         if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
-            bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
-
-            [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
-                uint32_t c = (idx + tid) % Bc;
-                uint32_t r = (idx + tid) / Bc;
-                if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
-                    if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
-                        float f = mask_cache[idx / WorkGroupSize];
-                        sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * f);
+            [[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) {
+                uint32_t c = (idx + tid) / (Br / 4);
+                uint32_t r = (idx + tid) % (Br / 4);
+                if (idx + tid < Bc * Br / 4 || idx + gl_WorkGroupSize.x <= Bc * Br / 4) {
+                    if (!KV_bounds_check || j * Bc + c < KV) {
+                        // Mask nem1 bounds check is handled when loading masks
+                        ACC_TYPEV4 masks = ACC_TYPEV4(mask_cache[idx / WorkGroupSize]);
+                        ACC_TYPEV4 slopes = ACC_TYPEV4(slope[r * 4], slope[r * 4 + 1], slope[r * 4 + 2], slope[r * 4 + 3]);
+                        sfsh[c * sfshstride + r] += slopes * masks;
                     }
                 }
             }
@@ -250,121 +322,154 @@ void main() {
 
         float eMf[rows_per_thread];
         [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+            const uint r_vec  = tile_row(r) / 4;
+            const uint r_comp = tile_row(r) % 4;
+
             float rowmaxf = NEG_FLT_MAX_OVER_2;
             [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
                 if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
                     continue;
                 }
-                rowmaxf = max(rowmaxf, float(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride]));
+                rowmaxf = max(rowmaxf, float(sfsh[r_vec + (c * cols_per_iter + col_tid) * sfshstride][r_comp]));
             }
             float Moldf = Mf[r];
 
+            // Compute max across the row
+            rowmaxf = subgroupMax(rowmaxf);
+
             // M = max(rowmax, Mold)
             // P = e^(S - M)
             // eM = e^(Mold - M)
             Mf[r] = max(rowmaxf, Moldf);
             eMf[r] = exp(Moldf - Mf[r]);
+
+            Lf[r] = eMf[r]*Lf[r];
         }
 
-        [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
+        [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
+            const uint d_local = d0 / threads_per_rowgroup;
             [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
-                Of[r][d] = ACC_TYPE(eMf[r]) * Of[r][d];
+                Of[r][d_local] = ACC_TYPE(eMf[r]) * Of[r][d_local];
             }
         }
-        [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
-            Lf[r] = eMf[r]*Lf[r];
-        }
 
+        // Calculate and store Pf in Psh
         [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
-            if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) {
-                continue;
-            }
-            float Pf[rows_per_thread];
-            [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
-                Pf[r] = exp(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride] - Mf[r]);
-                Lf[r] += Pf[r];
+            const uint col = c * cols_per_iter + col_tid;
+
+            [[unroll]] for (uint32_t r = 0; r < rows_per_thread; r += 4) {
+                const uint row = tile_row(r);
+                if (KV_bounds_check && j * Bc + col >= KV) {
+                    Psh[col * psh_stride + row / 4] = f16vec4(0.0f);
+                } else {
+                    const vec4 mfvec = vec4(Mf[r], Mf[r + 1], Mf[r + 2], Mf[r + 3]);
+                    const f16vec4 Pf = f16vec4(exp(vec4(sfsh[row / 4 + col * sfshstride]) - mfvec));
+                    [[unroll]] for (uint32_t vec_idx = 0; vec_idx < 4; ++vec_idx) {
+                        Lf[r + vec_idx] += Pf[vec_idx];
+                    }
+                    Psh[col * psh_stride + row / 4] = Pf;
+                }
             }
-            [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
+        }
+
+        const uint num_hsv_tiles = (HSV + MatBc * row_split - 1) / (MatBc * row_split); // round up
+
+        // Each subgroup handles HSV/4 columns
+        [[unroll]] for (uint32_t hsv_tile = 0; hsv_tile < num_hsv_tiles; ++hsv_tile) {
+            const uint hsv_offset = (hsv_tile * row_split + gl_SubgroupID) * 16;
+
+            SfMat = coopmat<ACC_TYPE, gl_ScopeSubgroup, MatBc, MatBr, gl_MatrixUseAccumulator>(0);
+
+            // Preload V tiles for [Bc, 16 * num subgroups]
+            const uint v_rows = Bc;
+            const uint v_total = v_rows * v_cols;
+            const uint v_loads_per_thread = v_total / gl_WorkGroupSize.x;
+
+#if BLOCK_SIZE == 1
+            // For f16, only preload if not aligned
+            if (KV_bounds_check) {
+#endif
+            [[unroll]] for (uint32_t i = 0; i < v_loads_per_thread; ++i) {
+                const uint idx = i * gl_WorkGroupSize.x + tid;
+                const uint row = idx / v_cols;
+                const uint col = idx % v_cols;
+
+                const uint v_row = j * Bc + row;
+                const uint v_col = hsv_tile * MatBc * row_split + col * 4;
+
+                const uint coord = v_row * v_stride * BLOCK_SIZE + v_col;
+                const uint ib = coord / BLOCK_SIZE;
+                const uint iqs = coord % BLOCK_SIZE;
+
+                if (!KV_bounds_check || (v_row < KV && v_col < HSV)) {
 #if BLOCK_SIZE > 1
-                uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid);
-                uint ib = coord / BLOCK_SIZE;
-                uint iqs = (coord % BLOCK_SIZE);
-                vec4 Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V);
+                    ksh[row * vsh_stride + col] = f16vec4(dequantize4(ib, iqs, v_offset, BINDING_IDX_V));
 #else
-                vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
+                    ksh[row * vsh_stride + col] = data_vv4[(v_offset + v_row * v_stride + v_col) / 4];
 #endif
-                [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
-                    Of[r][d] += ACC_TYPE(Pf[r]) * ACC_TYPEV4(Vf);
+                } else {
+                    ksh[row * vsh_stride + col] = f16vec4(0.0f);
                 }
             }
-        }
+#if BLOCK_SIZE == 1
+            }
+#endif
 
-        barrier();
-    }
+            barrier();
 
-    // prevent race on tmpsh
-    barrier();
+            [[unroll]] for (uint32_t bc_chunk = 0; bc_chunk < Bc / MatBc; ++bc_chunk) {
+                coopMatLoad(KMat, Psh, bc_chunk * MatBc * psh_stride, psh_stride, gl_CooperativeMatrixLayoutColumnMajor);
 
-    // reduce across threads
+#if BLOCK_SIZE == 1
+                if (!KV_bounds_check) {
+                    // F16 values can be loaded directly from global memory
+                    const uint v_tile_row = j * Bc + bc_chunk * MatBc;
+                    const uint v_tile_offset = v_offset / 4 + v_tile_row * v_stride / 4 + hsv_offset / 4;
+                    coopMatLoad(QMat, data_vv4, v_tile_offset, v_stride / 4, gl_CooperativeMatrixLayoutRowMajor);
+                } else
+#endif
+                {
+                    const uint v_tile_offset = bc_chunk * MatBr * v_cols + gl_SubgroupID * (MatBc / 4);
+                    coopMatLoad(QMat, ksh, v_tile_offset, vsh_stride, gl_CooperativeMatrixLayoutRowMajor);
+                }
+
+                SfMat = coopMatMulAdd(KMat, QMat, SfMat);
+            }
+
+            // Store SfMat to sfsh and load into Of
+            const uint osh_stride = row_split * MatBc / 4;
+            const uint o_offset = gl_SubgroupID * MatBc / 4;
+            coopMatStore(SfMat, sfsh, o_offset, osh_stride, gl_CooperativeMatrixLayoutRowMajor);
 
-    float rowmaxf[rows_per_thread], eMf[rows_per_thread], Moldf[rows_per_thread];
-    [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
-        FLOAT_TYPE M = Mf[r];
-        tmpsh[tid] = M;
-        // Compute max across the row
-        barrier();
-        [[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) {
-            M = max(M, tmpsh[tid ^ s]);
-            barrier();
-            tmpsh[tid] = M;
             barrier();
-        }
-        rowmaxf[r] = tmpsh[d_tid + row_tid * threads_per_rowgroup];
-        barrier();
-    }
 
-    [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
-        Moldf[r] = Mf[r];
+            const uint hsv_per_tile = row_split * MatBc;
+            const uint hsv_base = hsv_tile * hsv_per_tile;
+            const uint d_values_per_tile = hsv_per_tile / 4;
 
-        // M = max(rowmax, Mold)
-        // eM = e^(Mold - M)
-        Mf[r] = max(rowmaxf[r], Moldf[r]);
-        eMf[r] = exp(Moldf[r] - Mf[r]);
+            const uint d_start = hsv_tile * d_values_per_tile;
+            const uint d_end = min(d_start + d_values_per_tile, HSV / 4);
 
-        Lf[r] = eMf[r]*Lf[r];
-    }
+            [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
+                const uint row = tile_row(r);
 
-    [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
-        FLOAT_TYPE L = Lf[r];
-        tmpsh[tid] = L;
-        // Compute sum across the row
-        barrier();
-        [[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) {
-            L += tmpsh[tid ^ s];
-            barrier();
-            tmpsh[tid] = L;
-            barrier();
+                [[unroll]] for (uint32_t d_local = 0; d_local < d_per_thread; ++d_local) {
+                    const uint d = d_local * threads_per_rowgroup + col_tid;
+                    const uint hsv_col = 4 * d;
+
+                    if (hsv_col >= hsv_base && hsv_col < hsv_base + hsv_per_tile && hsv_col < HSV) {
+                        const uint local_hsv = (hsv_col - hsv_base) / 4;
+                        Of[r][d_local] += ACC_TYPEV4(sfsh[row * osh_stride + local_hsv]);
+                    }
+                }
+            }
         }
-        Lf[r] = tmpsh[d_tid + row_tid * threads_per_rowgroup];
+
         barrier();
     }
 
     [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
-        [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
-
-            Of[r][d] = ACC_TYPE(eMf[r]) * Of[r][d];
-            tmpshv4[tid] = Of[r][d];
-
-            barrier();
-            [[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) {
-                Of[r][d] += tmpshv4[tid ^ s];
-                barrier();
-                tmpshv4[tid] = Of[r][d];
-                barrier();
-            }
-            Of[r][d] = tmpshv4[d_tid + row_tid * threads_per_rowgroup];
-            barrier();
-        }
+        Lf[r] = subgroupAdd(Lf[r]);
     }
 
     // If there is split_k, then the split_k resolve shader does the final
@@ -375,9 +480,12 @@ void main() {
 
         [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
             if (tile_row(r) < N) {
-                [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
+                [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
+                    const uint d = d0 + col_tid;
+                    if (d >= HSV/4) break;
+                    const uint d_local = d0 / threads_per_rowgroup;
                     [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
-                        perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N);
+                        perElemOpGqaStore(tile_row(r), 4 * d + comp, float(Of[r][d_local][comp]), o_offset, iq2, N);
                     }
                 }
             }
@@ -404,8 +512,9 @@ void main() {
             if (sink > Mf[r]) {
                 ms = exp(Mf[r] - sink);
 
-                [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
-                    Of[r][d] *= ACC_TYPE(ms);
+                [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
+                    const uint d_local = d0 / threads_per_rowgroup;
+                    Of[r][d_local] *= ACC_TYPE(ms);
                 }
             } else {
                 vs = exp(sink - Mf[r]);
@@ -420,11 +529,12 @@ void main() {
         Lfrcp[r] = (Lf[r] == 0.0) ? 0.0 : (1.0 / Lf[r]);
     }
 
-    [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
+    [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
+        const uint d_local = d0 / threads_per_rowgroup;
         [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
-            Of[r][d] *= ACC_TYPE(Lfrcp[r]);
+            Of[r][d_local] *= ACC_TYPE(Lfrcp[r]);
 #if defined(ACC_TYPE_MAX)
-            Of[r][d] = clamp(Of[r][d], -ACC_TYPE_MAX, ACC_TYPE_MAX);
+            Of[r][d_local] = clamp(Of[r][d_local], -ACC_TYPE_MAX, ACC_TYPE_MAX);
 #endif
         }
     }
@@ -434,9 +544,12 @@ void main() {
     if (p.gqa_ratio > 1) {
         [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
             if (tile_row(r) < N) {
-                [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
+                [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
+                    const uint d = d0 + col_tid;
+                    if (d >= HSV / 4) break;
+                    const uint d_local = d0 / threads_per_rowgroup;
                     [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
-                        perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N);
+                        perElemOpGqaStore(tile_row(r), 4 * d + comp, float(Of[r][d_local][comp]), o_offset, iq2, N);
                     }
                 }
             }
@@ -444,9 +557,12 @@ void main() {
     } else {
         [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
             if (i * Br + tile_row(r) < N) {
-                [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
+                [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) {
+                    const uint d = d0 + col_tid;
+                    if (d >= HSV / 4) break;
+                    const uint d_local = d0 / threads_per_rowgroup;
                     [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) {
-                        data_o[o_offset + iq2 * HSV + (i * Br + tile_row(r)) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]);
+                        data_o[o_offset + iq2 * HSV + (i * Br + tile_row(r)) * p.ne1 * HSV + 4 * d + comp] = D_TYPE(Of[r][d_local][comp]);
                     }
                 }
             }
index d49a8da65fbee5e0f74d85e19a9519f42b746b15..54f1b0b62267e4a6e061320f41f974a99025c2f7 100644 (file)
@@ -55,7 +55,7 @@ ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE ele
     return max(elem0, elem1);
 }
 
-#if defined(BLOCK_SIZE)
+#if BLOCK_SIZE > 1
 #define DECODEFUNC , DEQUANTFUNC
 #else
 #define DECODEFUNC
@@ -85,7 +85,7 @@ void main() {
 
     tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0);
 
-#if defined(BLOCK_SIZE)
+#if BLOCK_SIZE > 1
     tensorLayoutK = setTensorLayoutBlockSizeNV(tensorLayoutK, 1, BLOCK_SIZE);
     tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, BLOCK_SIZE);
 #endif
@@ -98,7 +98,7 @@ void main() {
     if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
     {
         q_stride &= ~7;
-#if !defined(BLOCK_SIZE)
+#if BLOCK_SIZE == 1
         k_stride &= ~7;
         v_stride &= ~7;
 #endif