#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#extension GL_KHR_shader_subgroup_shuffle : enable
+#extension GL_KHR_shader_subgroup_vote : enable
#include "types.glsl"
#include "flash_attn_base.glsl"
[[dont_unroll]]
for (uint32_t j = start_j; j < end_j; ++j) {
+ if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
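+ // Load the mask tile into shared memory before computing S, tracking a running max so a fully -inf block can be skipped early.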
+ bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
+
+ float max_mask = NEG_FLT_MAX_OVER_2;
+ [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
+ uint32_t c = (idx + tid) % Bc;
+ uint32_t r = (idx + tid) / Bc;
+ if (idx + tid < Bc * Br) {
+ if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
+ float m = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
+ masksh[c][r] = m;
+ max_mask = max(max_mask, m);
+ } else {
+ masksh[c][r] = float(0);
+ }
+ }
+ }
+ // skip the block if the mask is entirely -inf
+ bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
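+ // Combine the per-subgroup votes through shared memory; 0.0 marks a subgroup that saw at least one usable (non -inf) mask value.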
+ barrier();
+ if (gl_SubgroupInvocationID == 0) {
+ tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
+ }
+ barrier();
+ [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
+ max_mask = max(max_mask, tmpsh[s]);
+ }
+ if (max_mask <= NEG_FLT_MAX_OVER_2) {
+ continue;
+ }
+ }
+
float Sf[Br][cols_per_thread];
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
}
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
- bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
-
- [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
- uint32_t c = (idx + tid) % Bc;
- uint32_t r = (idx + tid) / Bc;
- if (idx + tid < Bc * Br) {
- if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
- masksh[c][r] = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
- } else {
- masksh[c][r] = float(0);
- }
- }
- }
- barrier();
-
[[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
float mvf = masksh[c * cols_per_iter + col_tid][r];
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#extension GL_KHR_shader_subgroup_basic : enable
+#extension GL_KHR_shader_subgroup_vote : enable
#extension GL_KHR_memory_scope_semantics : enable
#extension GL_KHR_cooperative_matrix : enable
[[dont_unroll]]
for (uint32_t j = start_j; j < end_j; ++j) {
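+ // Cache this thread's mask values in registers; the mask loop further down reuses them instead of reloading data_m.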
+ float mask_cache[Bc * Br / WorkGroupSize];
+ if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
+ bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
+
+ float max_mask = NEG_FLT_MAX_OVER_2;
+ [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
+ uint32_t c = (idx + tid) % Bc;
+ uint32_t r = (idx + tid) / Bc;
+ if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
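+ // In range either per-thread, or because the whole stride fits in the tile for this unrolled iteration.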
+ if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
+ float m = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
+ mask_cache[idx / WorkGroupSize] = m;
+ max_mask = max(max_mask, m);
+ }
+ }
+ }
+ // skip the block if the mask is entirely -inf
+ bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2);
+ barrier();
+ if (gl_SubgroupInvocationID == 0) {
+ tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f;
+ }
+ barrier();
+ [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
+ max_mask = max(max_mask, tmpsh[s]);
+ }
+ if (max_mask <= NEG_FLT_MAX_OVER_2) {
+ continue;
+ }
+ }
+
[[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) {
uint32_t d = (idx + tid) % (HSK / 4);
uint32_t c = (idx + tid) / (HSK / 4);
uint32_t r = (idx + tid) / Bc;
if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) {
if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
- sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]));
+ float f = mask_cache[idx / WorkGroupSize];
+ sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * f);
}
}
}
return max(x, y);
}
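+// float16_t max, used as the reduction op for coopMatReduceNV on the mask tile below.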
+float16_t maxReduceFp16(const in float16_t x, const in float16_t y) {
+ return max(x, y);
+}
+
ACC_TYPE smearReduce(const in ACC_TYPE x, const in ACC_TYPE y) {
return x;
}
[[dont_unroll]]
for (uint32_t j = start_j; j < end_j; ++j) {
- coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
-
- coopmat<float16_t, gl_ScopeWorkgroup, HSK_pad, Bc, gl_MatrixUseB> K_T;
-
- uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13;
- coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose DECODEFUNC);
- S = coopMatMulAdd(Qf16, K_T, S);
-
- if (p.logit_softcap != 0.0f) {
- [[unroll]]
- for (int k = 0; k < S.length(); ++k) {
- S[k] = ACC_TYPE(p.logit_softcap)*tanh(S[k]);
- }
- }
-
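+ // Load the mask before the QK^T matmul so a fully -inf block can be skipped before any matrix math; mv is added to S after the softcap below.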
+ coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
+ tensorLayoutM = setTensorLayoutClampValueNV(tensorLayoutM, 0xfc00); // -inf in float16_t
- coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
+ coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mvmax;
coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
- S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
+ // skip the block if the mask is entirely -inf
+ coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
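+ // RowAndColumn reduction leaves the max over the whole tile in every element of mvmax, so this check is uniform across the workgroup.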
+ if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
+ continue;
+ }
} else {
tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
// Don't clamp against nem1 when GQA is enabled
tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV);
tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
- coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
+ coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mvmax;
coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
- S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
+ // skip the block if the mask is entirely -inf
+ coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
+ if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
+ continue;
+ }
}
}
+ coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
+
+ coopmat<float16_t, gl_ScopeWorkgroup, HSK_pad, Bc, gl_MatrixUseB> K_T;
+
+ uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13;
+ coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose DECODEFUNC);
+ S = coopMatMulAdd(Qf16, K_T, S);
+
+ if (p.logit_softcap != 0.0f) {
+ [[unroll]]
+ for (int k = 0; k < S.length(); ++k) {
+ S[k] = ACC_TYPE(p.logit_softcap)*tanh(S[k]);
+ }
+ }
+
+ if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
+ S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
+ }
+
// Clear padding elements to -inf, so they don't contribute to rowmax
if (Clamp != 0 &&
((j + 1) * Bc > KV ||