(K_h2 + int64_t(kb0)*nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K, k_VKQ_sup);
}
- for (; kb0 < kb0_stop-1; ++kb0) {
- constexpr bool last_iter = false;
- constexpr bool oob_check = false;
- constexpr int k_VKQ_sup = nbatch_fa;
- flash_attn_ext_f16_iter
- <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
- T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
- (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
- ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
- KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
- }
// kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
if constexpr (ncols2 == 1) {
- if (ne11 % nbatch_fa == 0) {
- constexpr bool last_iter = true;
- constexpr bool oob_check = false;
+ constexpr bool oob_check = true;
+ for (; kb0 < kb0_stop-1; ++kb0) {
+ constexpr bool last_iter = false;
constexpr int k_VKQ_sup = nbatch_fa;
flash_attn_ext_f16_iter
<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
(Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
- } else {
- constexpr bool last_iter = true;
- constexpr bool oob_check = true;
- const int k_VKQ_sup = ne11 - kb0*nbatch_fa;
+ }
+ constexpr bool last_iter = true;
+ const int k_VKQ_sup = ne11 - kb0*nbatch_fa;
+ flash_attn_ext_f16_iter
+ <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
+ T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
+ (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
+ ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
+ KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
+ } else {
+ constexpr bool oob_check = false;
+ for (; kb0 < kb0_stop-1; ++kb0) {
+ constexpr bool last_iter = false;
+ constexpr int k_VKQ_sup = nbatch_fa;
flash_attn_ext_f16_iter
<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
}
- } else {
constexpr bool last_iter = true;
- constexpr bool oob_check = false;
constexpr int k_VKQ_sup = nbatch_fa;
flash_attn_ext_f16_iter
<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,