template<int D, int ncols1, int ncols2> // D == head size
__launch_bounds__(D, 1)
static __global__ void flash_attn_stream_k_fixup(
- float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11,
- const int nbatch_fa) {
+ float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03,
+ const int ne11, const int ne12, const int nbatch_fa) {
constexpr int ncols = ncols1*ncols2;
const int bidx0 = blockIdx.x;
const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
- const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
- const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
- const int iter_z = (ne02 + (ncols2 - 1)) / ncols2;
+ const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
- const int kbc0 = int64_t(bidx0 + 0)*(iter_k*iter_j*iter_z*ne03) / gridDim.x;
- const int kbc0_stop = int64_t(bidx0 + 1)*(iter_k*iter_j*iter_z*ne03) / gridDim.x;
+ const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
+ const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
+ const int iter_z_gqa = (gqa_ratio + (ncols2 - 1)) / ncols2;
+
+ const int kbc0 = int64_t(bidx0 + 0)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
+ const int kbc0_stop = int64_t(bidx0 + 1)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
const bool did_not_have_any_data = kbc0 == kbc0_stop;
const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
return;
}
- const int sequence = kbc0 / (iter_k*iter_j*iter_z);
- const int zt = (kbc0 - iter_k*iter_j*iter_z*sequence) / (iter_k*iter_j);
- const int jt = (kbc0 - iter_k*iter_j*iter_z*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
+ // z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index
+ const int sequence = kbc0 /(iter_k*iter_j*iter_z_gqa*ne12);
+ const int z_KV = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa);
+ const int zt_gqa = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j);
+ const int jt = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k;
+
+ const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.
- if (jt*ncols1 + j >= ne01 || zt*ncols2 + c >= ne02) {
+ if (jt*ncols1 + j >= ne01 || zt_gqa*ncols2 + c >= gqa_ratio) {
return;
}
- dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + zt*(ncols2*D) + (j*ne02 + c)*D + tid;
+ dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + zt_Q*D + (j*ne02 + c)*D + tid;
// Load the partial result that needs a fixup:
float dst_val = 0.0f;
int bidx = bidx0 - 1;
int kbc_stop = kbc0;
while(true) {
- const int kbc = int64_t(bidx)*(iter_k*iter_j*iter_z*ne03) / gridDim.x;
+ const int kbc = int64_t(bidx)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
if (kbc == kbc_stop) { // Did not have any data.
bidx--;
kbc_stop = kbc;
}
}
- const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
- const int ntiles_z = ((Q->ne[2] + ncols2 - 1) / ncols2);
- const int ntiles_total = ntiles_x * ntiles_z * Q->ne[3];
+ const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
+ const int gqa_ratio = Q->ne[2] / K->ne[2];
+ const int ntiles_z_gqa = ((gqa_ratio + ncols2 - 1) / ncols2);
+ const int ntiles_total = ntiles_x * ntiles_z_gqa * K->ne[2] * Q->ne[3];
// Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped.
// Only worth the overhead if there is at lease one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or
blocks_num.x = ntiles_x;
blocks_num.y = parallel_blocks;
- blocks_num.z = ntiles_z*Q->ne[3];
+ blocks_num.z = ntiles_z_gqa*K->ne[2]*Q->ne[3];
if (parallel_blocks > 1) {
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
flash_attn_stream_k_fixup<DV, ncols1, ncols2>
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
- ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1], nbatch_fa);
+ ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1], K->ne[2], nbatch_fa);
}
} else if (parallel_blocks > 1) {
const dim3 block_dim_combine(DV, 1, 1);
const float logit_softcap,
const uint3 ne01,
const int ne02,
+ const int gqa_ratio,
const int ne11,
const int stride_Q1,
const int stride_Q2,
const int stride_V,
const int stride_mask,
const int jt,
- const int zt,
+ const int zt_gqa,
const int kb0_start,
const int kb0_stop) {
#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
const int j = jc / ncols2;
const int c = jc % ncols2;
- if ((ncols1 == 1 || jt*ncols1 + j < int(ne01.z)) && (ncols2 == 1 || zt*ncols2 + c < ne02)) {
+ if ((ncols1 == 1 || jt*ncols1 + j < int(ne01.z)) && (ncols2 == 1 || zt_gqa*ncols2 + c < gqa_ratio)) {
#pragma unroll
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
const int j_dst = jc_dst / ncols2;
const int c_dst = jc_dst % ncols2;
- if (!is_fixup && ((ncols1 > 1 && jt*ncols1 + j_dst >= int(ne01.z)) || (ncols2 > 1 && zt*ncols2 + c_dst >= ne02))) {
+ if (!is_fixup && ((ncols1 > 1 && jt*ncols1 + j_dst >= int(ne01.z)) || (ncols2 > 1 && zt_gqa*ncols2 + c_dst >= gqa_ratio))) {
continue;
}
}
#else
GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dstk_fixup,
- scale, slope, logit_softcap, ne01, ne02,
+ scale, slope, logit_softcap, ne01, ne02, gqa_ratio,
stride_Q1, stride_Q2, stride_K, stride_V, stride_mask,
jt, kb0_start, kb0_stop);
NO_DEVICE_CODE;
const int stride_V = V_is_K_view ? stride_K : nb21 / sizeof(half2);
- const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
- const int iter_j = (ne01.z + (ncols1 - 1)) / ncols1;
- const int iter_z = (ne02 + (ncols2 - 1)) / ncols2;
+ const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
+ const int iter_j = (ne01.z + (ncols1 - 1)) / ncols1;
+ const int iter_z_gqa = (gqa_ratio + (ncols2 - 1)) / ncols2;
// kbc == k block continuous, current index in continuous ijk space.
- int kbc = int64_t(blockIdx.x + 0)*(iter_k*iter_j*iter_z*ne03) / gridDim.x;
- const int kbc_stop = int64_t(blockIdx.x + 1)*(iter_k*iter_j*iter_z*ne03) / gridDim.x;
+ int kbc = int64_t(blockIdx.x + 0)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
+ const int kbc_stop = int64_t(blockIdx.x + 1)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
// If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
// For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc);
while (kbc < kbc_stop && kb0_stop == iter_k) {
- const int sequence = kbc / (iter_k*iter_j*iter_z);
- const int zt = (kbc - iter_k*iter_j*iter_z*sequence) / (iter_k*iter_j); // head in units of ncols2
- const int jt = (kbc - iter_k*iter_j*iter_z*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
+ // z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index
+ const int sequence = kbc /(iter_k*iter_j*iter_z_gqa*ne12);
+ const int z_KV = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa);
+ const int zt_gqa = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j);
+ const int jt = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k;
- const int head0 = zt * ncols2;
+ const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.
- const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0);
- const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
+ const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*zt_Q);
+ const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*z_KV);
const half * mask_h = ncols2 == 1 && !mask ? nullptr :
(const half *) (mask + nb33*(sequence % ne33));
- float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2);
+ float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + zt_Q) * (DV/2);
- const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
- const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
+ const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*z_KV);
+ const float * sinks_f = sinks ? (const float *) sinks + zt_Q : nullptr;
- const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
+ const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, zt_Q, n_head_log2, m0, m1) : 1.0f;
if (KV_max) {
kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa);
constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
- ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt, kb0_start, kb0_stop);
+ ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop);
} else {
constexpr bool needs_fixup = true; // CUDA block is missing the beginning of a tile.
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
- ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt, kb0_start, kb0_stop);
+ ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop);
}
kbc += iter_k;
return;
}
- const int sequence = kbc / (iter_k*iter_j*iter_z);
- const int zt = (kbc - iter_k*iter_j*iter_z*sequence) / (iter_k*iter_j); // head in units of ncols2
- const int jt = (kbc - iter_k*iter_j*iter_z*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
+ // z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index.
+ const int sequence = kbc /(iter_k*iter_j*iter_z_gqa*ne12);
+ const int z_KV = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa);
+ const int zt_gqa = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j);
+ const int jt = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k;
- const int head0 = zt * ncols2;
+ const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.
- const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0);
- const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
+ const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*zt_Q);
+ const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*z_KV);
const half * mask_h = ncols2 == 1 && !mask ? nullptr :
(const half *) (mask + nb33*(sequence % ne33));
- float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2);
+ float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + zt_Q) * (DV/2);
- const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
- const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
+ const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*z_KV);
+ const float * sinks_f = sinks ? (const float *) sinks + zt_Q : nullptr;
- const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
+ const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, zt_Q, n_head_log2, m0, m1) : 1.0f;
if (KV_max) {
kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa);
constexpr bool needs_fixup = false;
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
- ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt, kb0_start, kb0_stop);
+ ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop);
#else
GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
max_bias, m0, m1, n_head_log2, logit_softcap,