static constexpr int nwarps_max = 4;
static constexpr bool Q_in_reg = true;
static constexpr int nstages_target = 2;
- static constexpr int nbatch_K2 = 32;
- static constexpr int nbatch_V2 = 32;
- static constexpr int nbatch_combine = 32;
+ // Host-side K batch size (half2 units), used when sizing shared memory at launch; fixed at 32, cc/ncols are ignored.
+ static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
+ return 32;
+ }
+ // Device-side K batch size; constexpr so kernels can use it in constant expressions. Fixed at 32.
+ static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
+ return 32;
+ }
+ // Host-side V batch size (half2 units), used when sizing shared memory at launch; fixed at 32.
+ static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
+ return 32;
+ }
+ // Device-side V batch size; constexpr counterpart of get_nbatch_V2_host. Fixed at 32.
+ static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
+ return 32;
+ }
+ // Host-side combine batch size; feeds nbytes_shared_combine in the launch code. Fixed at 32.
+ static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
+ return 32;
+ }
+ // Device-side combine batch size; constexpr counterpart of get_nbatch_combine_host. Fixed at 32.
+ static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
+ return 32;
+ }
};
template <>
static constexpr int nwarps_max = 4;
static constexpr bool Q_in_reg = true;
static constexpr int nstages_target = 2;
- static constexpr int nbatch_K2 = 40;
- static constexpr int nbatch_V2 = 40;
- static constexpr int nbatch_combine = 40;
+ // Host-side K batch size (half2 units), used when sizing shared memory at launch; fixed at 40, cc/ncols are ignored.
+ static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
+ return 40;
+ }
+ // Device-side K batch size; constexpr so kernels can use it in constant expressions. Fixed at 40.
+ static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
+ return 40;
+ }
+ // Host-side V batch size (half2 units), used when sizing shared memory at launch; fixed at 40.
+ static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
+ return 40;
+ }
+ // Device-side V batch size; constexpr counterpart of get_nbatch_V2_host. Fixed at 40.
+ static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
+ return 40;
+ }
+ // Host-side combine batch size; feeds nbytes_shared_combine in the launch code. Fixed at 40.
+ static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
+ return 40;
+ }
+ // Device-side combine batch size; constexpr counterpart of get_nbatch_combine_host. Fixed at 40.
+ static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
+ return 40;
+ }
};
template <>
static constexpr int nwarps_max = 4;
static constexpr bool Q_in_reg = true;
static constexpr int nstages_target = 2;
- static constexpr int nbatch_K2 = 48;
- static constexpr int nbatch_V2 = 48;
- static constexpr int nbatch_combine = 48;
+ // Host-side K batch size (half2 units), used when sizing shared memory at launch; fixed at 48, cc/ncols are ignored.
+ static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
+ return 48;
+ }
+ // Device-side K batch size; constexpr so kernels can use it in constant expressions. Fixed at 48.
+ static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
+ return 48;
+ }
+ // Host-side V batch size (half2 units), used when sizing shared memory at launch; fixed at 48.
+ static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
+ return 48;
+ }
+ // Device-side V batch size; constexpr counterpart of get_nbatch_V2_host. Fixed at 48.
+ static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
+ return 48;
+ }
+ // Host-side combine batch size; feeds nbytes_shared_combine in the launch code. Fixed at 48.
+ static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
+ return 48;
+ }
+ // Device-side combine batch size; constexpr counterpart of get_nbatch_combine_host. Fixed at 48.
+ static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
+ return 48;
+ }
};
template <>
static constexpr int nwarps_max = 4;
static constexpr bool Q_in_reg = true;
static constexpr int nstages_target = 2;
- static constexpr int nbatch_K2 = 56;
- static constexpr int nbatch_V2 = 56;
- static constexpr int nbatch_combine = 56;
+ // Host-side K batch size (half2 units), used when sizing shared memory at launch; fixed at 56, cc/ncols are ignored.
+ static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
+ return 56;
+ }
+ // Device-side K batch size; constexpr so kernels can use it in constant expressions. Fixed at 56.
+ static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
+ return 56;
+ }
+ // Host-side V batch size (half2 units), used when sizing shared memory at launch; fixed at 56.
+ static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
+ return 56;
+ }
+ // Device-side V batch size; constexpr counterpart of get_nbatch_V2_host. Fixed at 56.
+ static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
+ return 56;
+ }
+ // Host-side combine batch size; feeds nbytes_shared_combine in the launch code. Fixed at 56.
+ static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
+ return 56;
+ }
+ // Device-side combine batch size; constexpr counterpart of get_nbatch_combine_host. Fixed at 56.
+ static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
+ return 56;
+ }
};
template <>
static constexpr int nwarps_max = 4;
static constexpr bool Q_in_reg = true;
static constexpr int nstages_target = 2;
- static constexpr int nbatch_K2 = 64;
- static constexpr int nbatch_V2 = 64;
- static constexpr int nbatch_combine = 64;
+ // Host-side K batch size (half2 units), used when sizing shared memory at launch; fixed at 64, cc/ncols are ignored.
+ static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
+ return 64;
+ }
+ // Device-side K batch size; constexpr so kernels can use it in constant expressions. Fixed at 64.
+ static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
+ return 64;
+ }
+ // Host-side V batch size (half2 units), used when sizing shared memory at launch; fixed at 64.
+ static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
+ return 64;
+ }
+ // Device-side V batch size; constexpr counterpart of get_nbatch_V2_host. Fixed at 64.
+ static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
+ return 64;
+ }
+ // Host-side combine batch size; feeds nbytes_shared_combine in the launch code. Fixed at 64.
+ static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
+ return 64;
+ }
+ // Device-side combine batch size; constexpr counterpart of get_nbatch_combine_host. Fixed at 64.
+ static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
+ return 64;
+ }
};
template <>
static constexpr int nwarps_max = 4;
static constexpr bool Q_in_reg = true;
static constexpr int nstages_target = 2;
- static constexpr int nbatch_K2 = 128;
- static constexpr int nbatch_V2 = 128;
- static constexpr int nbatch_combine = 128;
+ // Host-side K batch size (half2 units), used when sizing shared memory at launch; fixed at 128, cc/ncols are ignored.
+ static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
+ return 128;
+ }
+ // Device-side K batch size; constexpr so kernels can use it in constant expressions. Fixed at 128.
+ static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
+ return 128;
+ }
+ // Host-side V batch size (half2 units), used when sizing shared memory at launch; fixed at 128.
+ static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
+ return 128;
+ }
+ // Device-side V batch size; constexpr counterpart of get_nbatch_V2_host. Fixed at 128.
+ static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
+ return 128;
+ }
+
+ // Host-side combine batch size, used to size the shared-memory combine buffer at launch.
+ // Must return the same value that get_nbatch_combine_device yields for the architecture the kernel
+ // was compiled for, otherwise the buffer is sized for a different row width than the kernel uses.
+ // Turing: 128 for ncols <= 16, 64 otherwise. Every other architecture: 128.
+ static int get_nbatch_combine_host(const int cc, const int ncols) {
+ if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) {
+ return ncols <= 16 ? 128 : 64;
+ }
+ // Was 64, which disagreed with the device-side getter (128) and with the previous constant (128).
+ return 128;
+ }
+ // Device-side combine batch size: Turing builds use 128 for ncols <= 16 and 64 otherwise; every other architecture uses 128.
+ static constexpr __device__ int get_nbatch_combine_device(int ncols) {
+#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
+ return ncols <= 16 ? 128 : 64;
+#else
+ GGML_UNUSED(ncols); // ncols only matters on the Turing branch
+ return 128;
+#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
+ }
};
template <>
static constexpr int nwarps_max = 8;
static constexpr bool Q_in_reg = false;
static constexpr int nstages_target = 1;
- static constexpr int nbatch_K2 = 160;
- static constexpr int nbatch_V2 = 128;
- static constexpr int nbatch_combine = 128;
+ // Host-side K batch size (half2 units). If the highest compiled arch for cc is Turing: 96 for ncols <= 16, else 160. Otherwise: 288 for ncols <= 16, else 160.
+ static int get_nbatch_K2_host(const int cc, const int ncols) {
+ if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) {
+ return ncols <= 16 ? 96 : 160;
+ }
+ return ncols <= 16 ? 288 : 160;
+ }
+ // Device-side K batch size; same table as get_nbatch_K2_host, selected at compile time via __CUDA_ARCH__.
+ static constexpr __device__ int get_nbatch_K2_device(int ncols) {
+#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
+ return ncols <= 16 ? 96 : 160;
+#else
+ return ncols <= 16 ? 288 : 160;
+#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
+ }
+ // Host-side V batch size (half2 units). If the highest compiled arch for cc is Turing: 64 for ncols <= 16, else 128. Otherwise: 256 for ncols <= 16, else 128.
+ static int get_nbatch_V2_host(const int cc, const int ncols) {
+ if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) {
+ return ncols <= 16 ? 64 : 128;
+ }
+ return ncols <= 16 ? 256 : 128;
+ }
+ // Device-side V batch size; same table as get_nbatch_V2_host, selected at compile time via __CUDA_ARCH__.
+ static constexpr __device__ int get_nbatch_V2_device(int ncols) {
+#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
+ return ncols <= 16 ? 64 : 128;
+#else
+ return ncols <= 16 ? 256 : 128;
+#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
+ }
+ // Host-side combine batch size; feeds nbytes_shared_combine in the launch code. Fixed at 128, cc/ncols are ignored.
+ static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
+ return 128;
+ }
+ // Device-side combine batch size; constexpr counterpart of get_nbatch_combine_host. Fixed at 128.
+ static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
+ return 128;
+ }
};
// ------------------------------------------------------------------------------------------------------------------
const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV);
- auto load = [&] __device__ (const int n) {
+ auto load = [&] __device__ (auto n) {
const int stride_k = WARP_SIZE >> n;
const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k);
}
}
-template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool needs_fixup, bool is_fixup, bool last_iter>
+template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup, bool last_iter>
static __device__ __forceinline__ void flash_attn_ext_f16_iter(
const float2 * const __restrict__ Q_f2,
const half2 * const __restrict__ K_h2,
constexpr int cols_per_warp = ntiles * tile_B::I;
constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
+ constexpr int ncols = ncols1 * ncols2;
+ constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols);
+ constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols);
- constexpr int stride_tile_Q = DKQ/2 + 4;
- constexpr int stride_tile_K = c::nbatch_K2 + 4;
- constexpr int stride_tile_V = c::nbatch_V2 + 4;
+ constexpr int stride_tile_Q = DKQ/2 + 4;
+ constexpr int stride_tile_K = nbatch_K2 + 4;
+
+ static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA");
+ constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4;
const int k_VKQ_0 = kb0 * c::nbatch_fa;
tile_C_KQ KQ_C[c::nbatch_fa/(np*tile_C_KQ::I) * ntiles];
tile_C_KQ_16 * KQ_C_16 = (tile_C_KQ_16 *) KQ_C;
if constexpr (nstages > 1) {
- static_assert(c::nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading");
+ static_assert(!mla, "multi-stage loading not implemented for MLA");
+ static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading");
constexpr bool use_cp_async = true;
cp_async_wait_all();
__syncthreads();
flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
- (V_h2 + k_VKQ_0*stride_V, tile_V, c::nbatch_V2, stride_V);
+ (V_h2 + k_VKQ_0*stride_V, tile_V, nbatch_V2, stride_V);
} else {
constexpr bool use_cp_async = nstages == 1;
if (ncols2 > 1 || mask_h2) {
}
#pragma unroll
- for (int k0_start = 0; k0_start < DKQ/2; k0_start += c::nbatch_K2) {
- const int k0_stop = k0_start + c::nbatch_K2 < DKQ/2 ? k0_start + c::nbatch_K2 : DKQ/2;
+ for (int k0_start = 0; k0_start < DKQ/2; k0_start += nbatch_K2) {
+ const int k0_stop = k0_start + nbatch_K2 < DKQ/2 ? k0_start + nbatch_K2 : DKQ/2;
const int k0_diff = k0_stop - k0_start;
if (nstages <= 1) {
(mask_h2 + (k_VKQ_0 + c::nbatch_fa)/2, tile_mask, stride_mask);
}
flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
- (K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, c::nbatch_K2, stride_K);
+ (K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K);
}
}
+
+ // For MLA K and V have the same data (the kernel sets V_h2 = K_h2 + (DKQ/2 - DV/2)).
+ // Therefore, iterate over V in reverse and re-use the data if possible.
+ static_assert(!mla || nstages <= 1, "combination of MLA and multi-stage loading not implemented");
+ constexpr int reusable_cutoff = mla ? (DKQ - 1) - (DKQ - 1) % (2*nbatch_K2) - (DKQ - DV) : DV; // MLA: start of the last K batch (in half units) shifted by the K->V offset (DKQ - DV); V batches with i0_start >= cutoff skip the reload and read the tile at an offset instead (presumably tile_K and tile_V share storage - confirm). Non-MLA: DV, so every V batch is loaded.
#pragma unroll
- for (int i0_start = 0; i0_start < DV; i0_start += 2*c::nbatch_V2) {
- const int i0_stop = i0_start + 2*c::nbatch_V2 < DV ? i0_start + 2*c::nbatch_V2 : DV;
- const int i0_diff = i0_stop - i0_start;
+ for (int i0_stop = DV; i0_stop > 0; i0_stop -= 2*nbatch_V2) {
+ const int i0_start = i0_stop - 2*nbatch_V2 > 0 ? i0_stop - 2*nbatch_V2 : 0;
+ const int i0_diff = i0_stop - i0_start;
- if (nstages <= 1) {
+ if (nstages <= 1 && i0_start < reusable_cutoff) {
constexpr bool use_cp_async = nstages == 1;
flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
(V_h2 + k_VKQ_0*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
}
__syncthreads();
}
+ const half2 * tile_V_i = i0_start < reusable_cutoff ? tile_V : tile_V + (i0_start - reusable_cutoff)/2;
// Calculate VKQ tile:
#pragma unroll
const int k0 = k00 + (threadIdx.y % np)*tile_A::J;
tile_A A;
- load_ldmatrix_trans(A, tile_V + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
+ load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
if (ntiles == 1) {
mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]);
} else {
#endif // NEW_MMA_AVAILABLE
}
-template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool needs_fixup, bool is_fixup>
+template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup>
static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
const float2 * const __restrict__ Q_f2,
const half2 * const __restrict__ K_h2,
constexpr int cols_per_warp = ntiles * tile_B::I;
constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
+ constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols);
+ constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols);
static_assert(nwarps * (cols_per_warp/ncols2) % ncols1 == 0, "bad nwarps");
- constexpr int stride_tile_Q = DKQ/2 + 4;
- constexpr int stride_tile_K = c::nbatch_K2 + 4;
- constexpr int stride_tile_V = c::nbatch_V2 + 4;
+ constexpr int stride_tile_Q = DKQ/2 + 4;
+ constexpr int stride_tile_K = nbatch_K2 + 4;
+ static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA");
+ constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4;
constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V;
extern __shared__ half2 tile_Q[];
// Preload mask and K data for first iteration when using cp_async with multiple stages:
if constexpr (nstages > 1) {
- static_assert(c::nbatch_K2 == DKQ/2, "batching not implemented for multi-stage pipeline");
+ static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi-stage pipeline");
constexpr bool use_cp_async = true;
if (ncols2 > 1 || mask_h2) {
flash_attn_ext_f16_load_mask<ncols1, nwarps, c::nbatch_fa, use_cp_async>
(mask_h2 + kb0_start*c::nbatch_fa/2, tile_mask, stride_mask);
}
flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
- (K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, c::nbatch_K2, stride_K);
+ (K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K);
}
// Iterate over ne11 == previous tokens:
for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) {
constexpr bool last_iter = false;
- flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup, last_iter>
+ flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
}
{ // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
constexpr bool last_iter = true;
- flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup, last_iter>
+ flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
}
// It's also faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
// So also write VKQ accumulators to shared memory in column-major format if np == 1.
- constexpr int nbatch_combine = c::Q_in_reg ? DV/2 : DV/4;
+ constexpr int nbatch_combine = c::get_nbatch_combine_device(ncols);
constexpr int tile_stride = nbatch_combine + 4;
static_assert((DV/2) % nbatch_combine == 0, "bad nbatch_combine");
#endif // NEW_MMA_AVAILABLE
}
-template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap>
+template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla>
__launch_bounds__(nwarps*WARP_SIZE, 1)
static __global__ void flash_attn_ext_f16(
const char * __restrict__ Q,
NO_DEVICE_CODE;
return;
}
+#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
+ // Turing builds: variants with more than 32 columns (ncols1*ncols2) bail out via NO_DEVICE_CODE.
+ if (ncols1*ncols2 > 32) {
+ NO_DEVICE_CODE;
+ return;
+ }
+#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
+
+ static_assert(!mla || DKQ >= DV, "MLA needs DKQ >= DV");
typedef fattn_mma_f16_config<DKQ, DV> c;
const int stride_Q1 = nb01 / sizeof(float2);
const int stride_Q2 = nb02 / sizeof(float2);
const int stride_K = nb11 / sizeof(half2);
- const int stride_V = nb21 / sizeof(half2);
const int stride_mask = nb31 / sizeof(half2);
+ const int stride_V = mla ? stride_K : nb21 / sizeof(half2);
+
const int iter_k = ne11 / FATTN_KQ_STRIDE;
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
- const half2 * V_h2 = (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
+ const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
+
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
const int kb0_start_kernel = kb0_start * kb_niter;
constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
if (kb0_start == 0) {
constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
- flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup>
+ flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
} else {
constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
- flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup>
+ flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
}
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
- const half2 * V_h2 = (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio)); // K and V have same shape
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
+ const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
+
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
const int kb0_start_kernel = kb0_start * kb_niter;
constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
constexpr bool needs_fixup = false;
- flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup>
+ flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
#else
typedef fattn_mma_f16_config<DKQ, DV> c;
- constexpr int nbatch_K2 = c::nbatch_K2 < 1 ? DKQ/2 : c::nbatch_K2;
- constexpr int nbatch_V2 = c::nbatch_V2 < 1 ? DV /2 : c::nbatch_V2;
- constexpr int nbatch_combine = c::nbatch_combine < 1 ? DV /2 : c::nbatch_combine;
-
const int nstages = cp_async_available(cc) ? c::nstages_target : 0;
constexpr int ncols = ncols1 * ncols2;
constexpr int nwarps_max_y = c::nbatch_fa / tile_A::I;
constexpr int nwarps = nwarps_max_x*nwarps_max_y <= c::nwarps_max ? nwarps_max_x*nwarps_max_y : c::nwarps_max;
+ constexpr bool mla = DKQ == 576;
+
+ // Host-side batch sizes used for the shared-memory size computation below.
+ // Each value must come from the getter that mirrors its device-side counterpart (K2 -> K2, V2 -> V2, combine -> combine).
+ const int nbatch_K2 = c::get_nbatch_K2_host (cc, ncols);
+ const int nbatch_V2 = c::get_nbatch_V2_host (cc, ncols);
+ const int nbatch_combine = c::get_nbatch_combine_host(cc, ncols);
+
static_assert(DKQ % tile_B::J == 0, "bad DKQ");
static_assert(DV % tile_A::J == 0, "bad DV");
static_assert(ncols % cols_per_warp == 0, "bad ncols");
- const size_t nbytes_shared_KV_1stage = c::nbatch_fa * std::max(c::nbatch_K2 + 4, c::nbatch_V2 + 4) * sizeof(half2);
- const size_t nbytes_shared_KV_2stage = c::nbatch_fa * (c::nbatch_K2 + 4 + c::nbatch_V2 + 4) * sizeof(half2);
- const size_t nbytes_shared_Q = ncols * (DKQ/2 + 4) * sizeof(half2);
- const size_t nbytes_shared_mask = ncols1 * (c::nbatch_fa/2 + 4) * sizeof(half2);
- const size_t nbytes_shared_combine = nwarps*cols_per_warp * (nbatch_combine + 4) * sizeof(half2);
+ const size_t nbytes_shared_KV_1stage = c::nbatch_fa * std::max(nbatch_K2 + 4, nbatch_V2 + 4) * sizeof(half2);
+ const size_t nbytes_shared_KV_2stage = c::nbatch_fa * (nbatch_K2 + 4 + nbatch_V2 + 4) * sizeof(half2);
+ const size_t nbytes_shared_Q = ncols * (DKQ/2 + 4) * sizeof(half2);
+ const size_t nbytes_shared_mask = ncols1 * (c::nbatch_fa/2 + 4) * sizeof(half2);
+ const size_t nbytes_shared_combine = nwarps*cols_per_warp * (nbatch_combine + 4) * sizeof(half2);
const size_t nbytes_shared_KV = nstages <= 1 ? nbytes_shared_KV_1stage : nbytes_shared_KV_2stage;
fattn_kernel_t fattn_kernel;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
- fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap>;
+ fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla>;
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
} else {
constexpr bool use_logit_softcap = true;
- fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap>;
+ fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla>;
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};