template<int D, int ncols1, int ncols2> // D == head size
__launch_bounds__(D, 1)
-static __global__ void flash_attn_stream_k_fixup(
- float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03,
- const int ne11, const int ne12, const int nbatch_fa) {
+static __global__ void flash_attn_stream_k_fixup_uniform(
+ float * __restrict__ dst,
+ const float2 * __restrict__ dst_fixup,
+ const int ne01, const int ne02,
+ const int nblocks_stream_k,
+ const int gqa_ratio,
+ const int blocks_per_tile,
+ const uint3 fd_iter_j_z_ne12,
+ const uint3 fd_iter_j_z,
+ const uint3 fd_iter_j) {
+ constexpr int ncols = ncols1*ncols2;
+
+ const int tile_idx = blockIdx.x; // One block per output tile.
+ const int j = blockIdx.y;
+ const int c = blockIdx.z;
+ const int jc = j*ncols2 + c;
+ const int tid = threadIdx.x;
+
+ // nblocks_stream_k is a multiple of ntiles_dst (== gridDim.x), so each tile gets the same number of blocks.
+ const int b_first = tile_idx * blocks_per_tile;
+ const int b_last = b_first + blocks_per_tile - 1;
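+ // Illustrative example (hypothetical sizes): with nblocks_stream_k == 96 and
+ // ntiles_dst == 32 each tile is covered by blocks_per_tile == 3 blocks, so
+ // tile_idx == 5 combines the partial results of blocks 15, 16, and 17.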
+
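+ // Note: gridDim.x of this kernel is ntiles_dst, so the meta/data offset is
+ // computed from nblocks_stream_k, the grid size of the kernel that wrote dst_fixup.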
+ const float * dst_fixup_data = ((const float *) dst_fixup) + nblocks_stream_k*(2*2*ncols);
+
+ // z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index
+ const uint2 dm0 = fast_div_modulo(tile_idx, fd_iter_j_z_ne12);
+ const uint2 dm1 = fast_div_modulo(dm0.y, fd_iter_j_z);
+ const uint2 dm2 = fast_div_modulo(dm1.y, fd_iter_j);
+
+ const int sequence = dm0.x;
+ const int z_KV = dm1.x;
+ const int zt_gqa = dm2.x;
+ const int jt = dm2.y;
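+
+ // These divisions invert the flattening
+ //     tile_idx == ((sequence*ne12 + z_KV)*iter_z_gqa + zt_gqa)*iter_j + jt,
+ // where iter_j == ceil(ne01/ncols1) and iter_z_gqa == ceil(gqa_ratio/ncols2).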
+
+ const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.
+
+ if (jt*ncols1 + j >= ne01 || zt_gqa*ncols2 + c >= gqa_ratio) {
+ return;
+ }
+
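+ // dst has layout [ne03, ne01, ne02, D]; the offset below is equivalent to
+ //     ((sequence*ne01 + jt*ncols1 + j)*ne02 + zt_Q + c)*D + tid.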
+ dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + zt_Q*D + (j*ne02 + c)*D + tid;
+
+ // Load the partial result that needs a fixup:
+ float dst_val = *dst;
+ float max_val;
+ float rowsum;
+ {
+ const float2 tmp = dst_fixup[b_last*ncols + jc];
+ max_val = tmp.x;
+ rowsum = tmp.y;
+ }
+
+ // Combine with all previous blocks in this tile.
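+ // This is the standard online-softmax merge: with m = max(max_val, tmp.x)
+ // the accumulators are rescaled as
+ //     dst_val = exp(max_val - m)*dst_val + exp(tmp.x - m)*dst_add
+ // and analogously for rowsum; scale factors with an exponent below
+ // SOFTMAX_FTZ_THRESHOLD are flushed to zero.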
+ for (int bidx = b_last - 1; bidx >= b_first; --bidx) {
+ const float dst_add = dst_fixup_data[bidx*ncols*D + jc*D + tid];
+
+ const float2 tmp = dst_fixup[(nblocks_stream_k + bidx)*ncols + jc];
+
+ const float max_val_new = fmaxf(max_val, tmp.x);
+
+ const float diff_val = max_val - max_val_new;
+ const float diff_add = tmp.x - max_val_new;
+
+ const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f;
+ const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f;
+
+ dst_val = scale_val*dst_val + scale_add*dst_add;
+ rowsum = scale_val*rowsum + scale_add*tmp.y;
+
+ max_val = max_val_new;
+ }
+
+ // Write back final result:
+ *dst = dst_val / rowsum;
+}
+
+// General fixup kernel for the case where the number of blocks per tile is not uniform across tiles
+// (blocks_num.x not a multiple of ntiles_dst)
+template <int D, int ncols1, int ncols2> // D == head size
+__launch_bounds__(D, 1)
+static __global__ void flash_attn_stream_k_fixup_general(
+ float * __restrict__ dst,
+ const float2 * __restrict__ dst_fixup,
+ const int ne01, const int ne02,
+ const int gqa_ratio,
+ const int total_work,
+ const uint3 fd_iter_k_j_z_ne12,
+ const uint3 fd_iter_k_j_z,
+ const uint3 fd_iter_k_j,
+ const uint3 fd_iter_k) {
constexpr int ncols = ncols1*ncols2;
const int bidx0 = blockIdx.x;
const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
- const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
-
- const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
- const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
- const int iter_z_gqa = (gqa_ratio + (ncols2 - 1)) / ncols2;
-
- const int kbc0 = int64_t(bidx0 + 0)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
- const int kbc0_stop = int64_t(bidx0 + 1)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
+ const int kbc0 = int64_t(bidx0 + 0)*total_work / gridDim.x;
+ const int kbc0_stop = int64_t(bidx0 + 1)*total_work / gridDim.x;
const bool did_not_have_any_data = kbc0 == kbc0_stop;
- const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
- const bool did_not_write_last = kbc0/iter_k == kbc0_stop/iter_k && kbc0_stop % iter_k != 0;
+ const bool wrote_beginning_of_tile = fastmodulo(kbc0, fd_iter_k) == 0;
+ const bool did_not_write_last = fastdiv(kbc0, fd_iter_k) == fastdiv(kbc0_stop, fd_iter_k) && fastmodulo(kbc0_stop, fd_iter_k) != 0;
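+ // fastdiv/fastmodulo replace the runtime divisions by iter_k with a
+ // multiply-shift using magic constants precomputed by init_fastdiv_values.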
if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) {
return;
}
// z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index
- const int sequence = kbc0 /(iter_k*iter_j*iter_z_gqa*ne12);
- const int z_KV = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa);
- const int zt_gqa = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j);
- const int jt = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k;
+ const uint2 dm0 = fast_div_modulo(kbc0, fd_iter_k_j_z_ne12);
+ const uint2 dm1 = fast_div_modulo(dm0.y, fd_iter_k_j_z);
+ const uint2 dm2 = fast_div_modulo(dm1.y, fd_iter_k_j);
+ const uint2 dm3 = fast_div_modulo(dm2.y, fd_iter_k);
+
+ const int sequence = dm0.x;
+ const int z_KV = dm1.x;
+ const int zt_gqa = dm2.x;
+ const int jt = dm3.x;
const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.
// Iterate over previous blocks and compute the combined results.
// All CUDA blocks that get here must have a previous block that needs a fixup.
+ const int tile_kbc0 = fastdiv(kbc0, fd_iter_k);
int bidx = bidx0 - 1;
int kbc_stop = kbc0;
while(true) {
- const int kbc = int64_t(bidx)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
+ const int kbc = int64_t(bidx)*total_work / gridDim.x;
if (kbc == kbc_stop) { // Did not have any data.
bidx--;
kbc_stop = kbc;
max_val = max_val_new;
// If this block started in a previous tile we are done and don't need to combine additional partial results.
- if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) {
+ if (fastmodulo(kbc, fd_iter_k) == 0 || fastdiv(kbc, fd_iter_k) < tile_kbc0) {
break;
}
bidx--;
const int tiles_nwaves = (ntiles_dst + max_blocks - 1) / max_blocks;
const int tiles_efficiency_percent = 100 * ntiles_dst / (max_blocks*tiles_nwaves);
- const int nblocks_stream_k = std::min(max_blocks, ntiles_KV*ntiles_dst);
-
const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || amd_wmma_available(cc) || tiles_efficiency_percent < 75;
- blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_dst;
+ blocks_num.x = ntiles_dst;
blocks_num.y = 1;
blocks_num.z = 1;
+ if (use_stream_k) {
+ const int nblocks_stream_k_raw = std::min(max_blocks, ntiles_KV*ntiles_dst);
+ // Round down to a multiple of ntiles_dst so that each output tile gets the same number of blocks,
+ // which enables the cheaper uniform fixup (or no fixup at all if the result equals ntiles_dst).
+ // Only do this if the occupancy loss from rounding is acceptable.
+ const int nblocks_stream_k_rounded = (nblocks_stream_k_raw / ntiles_dst) * ntiles_dst;
+ const int max_efficiency_loss_percent = 5;
+ const int efficiency_loss_percent = nblocks_stream_k_rounded > 0
+ ? 100 * (nblocks_stream_k_raw - nblocks_stream_k_rounded) / nblocks_stream_k_raw
+ : 100;
+ const int nblocks_stream_k = efficiency_loss_percent <= max_efficiency_loss_percent
+ ? nblocks_stream_k_rounded
+ : nblocks_stream_k_raw;
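+
+ // Illustrative example (hypothetical sizes): with max_blocks == 132 and
+ // ntiles_dst == 40, a raw value of 132 would be rounded down to 120, an
+ // efficiency loss of 100*12/132 == ~9% > 5%, so the raw value is kept and
+ // the general fixup path is used instead.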
+
+ blocks_num.x = nblocks_stream_k;
+ }
+
if (ntiles_dst % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
dst_tmp_meta.alloc((size_t(blocks_num.x) * ncols * (2 + DV/2)));
}
CUDA_CHECK(cudaGetLastError());
if (stream_k) {
- if (ntiles_dst % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
+ if ((int)blocks_num.x % ntiles_dst == 0 && (int)blocks_num.x > ntiles_dst) {
+ // Optimized fixup: nblocks_stream_k is a multiple of ntiles_dst, launch one block per tile.
+ const int nblocks_sk = (int)blocks_num.x;
+ const int blocks_per_tile = nblocks_sk / ntiles_dst;
+
+ const uint3 fd_iter_j_z_ne12 = init_fastdiv_values(ntiles_x * ntiles_z_gqa * K->ne[2]);
+ const uint3 fd_iter_j_z = init_fastdiv_values(ntiles_x * ntiles_z_gqa);
+ const uint3 fd_iter_j = init_fastdiv_values(ntiles_x);
+
+ const dim3 block_dim_combine(DV, 1, 1);
+ const dim3 blocks_num_combine = {(unsigned)ntiles_dst, ncols1, ncols2};
+
+ flash_attn_stream_k_fixup_uniform<DV, ncols1, ncols2>
+ <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
+ ((float *) KQV->data, dst_tmp_meta.ptr,
+ Q->ne[1], Q->ne[2], nblocks_sk,
+ gqa_ratio, blocks_per_tile, fd_iter_j_z_ne12, fd_iter_j_z, fd_iter_j);
+ } else if (ntiles_dst % blocks_num.x != 0) {
+ // General fixup for the remaining cases, i.e. whenever the output tiles are split unevenly across blocks.
+ const int total_work = ntiles_KV * ntiles_dst;
+
+ const uint3 fd_k_j_z_ne12 = init_fastdiv_values(ntiles_KV * ntiles_x * ntiles_z_gqa * K->ne[2]);
+ const uint3 fd_k_j_z = init_fastdiv_values(ntiles_KV * ntiles_x * ntiles_z_gqa);
+ const uint3 fd_k_j = init_fastdiv_values(ntiles_KV * ntiles_x);
+ const uint3 fd_k = init_fastdiv_values(ntiles_KV);
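+ // The divisor order must mirror the kernel's nested decomposition of kbc0:
+ // sequence first, then K/V head, then GQA head block, then token block, then k.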
+
const dim3 block_dim_combine(DV, 1, 1);
const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};
- flash_attn_stream_k_fixup<DV, ncols1, ncols2>
+ flash_attn_stream_k_fixup_general<DV, ncols1, ncols2>
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
- ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1], K->ne[2], nbatch_fa);
+ ((float *) KQV->data, dst_tmp_meta.ptr,
+ Q->ne[1], Q->ne[2], gqa_ratio, total_work,
+ fd_k_j_z_ne12, fd_k_j_z, fd_k_j, fd_k);
}
} else if (parallel_blocks > 1) {
const dim3 block_dim_combine(DV, 1, 1);