return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
}
+static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_cdna(const int DKQ, const int DV, const int ncols) {
+ // Conservative configs for CDNA (MI100+): 64KB LDS, wavefront64, nstages=1 (no cp.async).
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 8, 128, 2, 128, 32, 32, 32, 1, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 16, 128, 2, 64, 32, 32, 32, 1, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 32, 128, 2, 64, 32, 32, 32, 1, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 64, 256, 2, 64, 32, 32, 32, 1, true);
+
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 8, 128, 2, 128, 40, 40, 40, 1, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 16, 128, 2, 64, 40, 40, 40, 1, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 32, 128, 2, 64, 40, 40, 40, 1, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 64, 256, 2, 64, 40, 40, 40, 1, true);
+
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 8, 128, 2, 128, 48, 48, 48, 1, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 16, 128, 2, 64, 48, 48, 48, 1, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 32, 128, 2, 64, 48, 48, 48, 1, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 64, 256, 2, 64, 48, 48, 48, 1, true);
+
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 8, 128, 2, 128, 56, 56, 56, 1, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 16, 128, 2, 64, 56, 56, 56, 1, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 32, 128, 2, 64, 56, 56, 56, 1, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 64, 256, 2, 64, 56, 56, 56, 1, true);
+
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 8, 128, 2, 128, 64, 64, 64, 1, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 16, 128, 2, 64, 64, 64, 64, 1, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 32, 128, 2, 64, 64, 64, 64, 1, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 64, 256, 2, 64, 64, 64, 64, 1, true);
+
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 8, 64, 4, 64, 128, 128, 128, 1, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 64, 4, 32, 128, 128, 128, 1, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 1, true);
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 256, 2, 32, 128, 128, 128, 1, true);
+
+ // Fallback for unsupported DKQ values (e.g. 576). Must return non-zero values to satisfy
+ // compile-time static_asserts even though the kernel guard prevents runtime execution.
+ // nthreads=256 gives nwarps=4 (warp_size=64) or 8 (warp_size=32), nbatch_fa=128 satisfies np*16 divisibility.
+ return fattn_mma_config(256, 1, 128, 4, 4, 4, 1, false);
+}
+
static __host__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols, const int cc) {
if (ampere_mma_available(cc)) {
return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
if (turing_mma_available(cc)) {
return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols);
}
+ if (amd_mfma_available(cc)) {
+ return ggml_cuda_fattn_mma_get_config_cdna(DKQ, DV, ncols);
+ }
if (amd_wmma_available(cc)) {
return ggml_cuda_fattn_mma_get_config_rdna(DKQ, DV, ncols);
}
return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
#elif defined(TURING_MMA_AVAILABLE)
return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols);
+#elif defined(AMD_MFMA_AVAILABLE)
+ return ggml_cuda_fattn_mma_get_config_cdna(DKQ, DV, ncols);
#elif defined(VOLTA_MMA_AVAILABLE)
return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols);
#elif defined(AMD_WMMA_AVAILABLE)
}
static constexpr __device__ int get_cols_per_thread() {
-#if defined(AMD_WMMA_AVAILABLE)
- return 1; // RDNA has a single column.
+#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+ return 1; // AMD has a single column per thread.
#else
return 2; // This is specifically KQ columns, Volta only has a single VKQ column.
-#endif // defined(AMD_WMMA_AVAILABLE)
+#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
}
static __host__ int get_cols_per_warp(const int cc) {
- if (turing_mma_available(cc) || amd_wmma_available(cc)) {
+ if (turing_mma_available(cc) || amd_wmma_available(cc) || amd_mfma_available(cc)) {
return 16;
} else {
// Volta
template<int stride_tile, int nwarps, int nbatch_fa, bool use_cp_async, bool oob_check>
static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV, const int i_sup) {
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
// K/V data is loaded with decreasing granularity for D for better memory bandwidth.
// The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes.
if constexpr (use_cp_async) {
const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV);
auto load = [&] __device__ (auto n) {
- const int stride_k = WARP_SIZE >> n;
- const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
+ const int stride_k = warp_size >> n;
+ const int k0_start = stride_k == warp_size ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k);
- const int stride_i = WARP_SIZE / stride_k;
+ const int stride_i = warp_size / stride_k;
if (k0_start == k0_stop) {
return;
#pragma unroll
for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
- const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
+ const int i = i0 + threadIdx.y*stride_i + (stride_k == warp_size ? 0 : threadIdx.x / stride_k);
if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
break;
#pragma unroll
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
- const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
+ const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);
cp_async_cg_16<preload>(tile_KV_32 + i*(stride_tile*sizeof(half2)) + k*16, KV + i*stride_KV + k*h2_per_chunk);
}
} else {
// TODO use ggml_cuda_memcpy_1
auto load = [&] __device__ (const int n) {
- const int stride_k = WARP_SIZE >> n;
- const int k0_start = stride_k == WARP_SIZE ? 0 : D2 - D2 % (2*stride_k);
+ const int stride_k = warp_size >> n;
+ const int k0_start = stride_k == warp_size ? 0 : D2 - D2 % (2*stride_k);
const int k0_stop = D2 - D2 % (1*stride_k);
- const int stride_i = WARP_SIZE / stride_k;
+ const int stride_i = warp_size / stride_k;
if (k0_start == k0_stop) {
return;
#pragma unroll
for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
- const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
+ const int i = i0 + threadIdx.y*stride_i + (stride_k == warp_size ? 0 : threadIdx.x / stride_k);
if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
break;
#pragma unroll
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
- const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
+ const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);
tile_KV[i*stride_tile + k] = !oob_check || i < i_sup ? KV[i*stride_KV + k] : make_half2(0.0f, 0.0f);
}
static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
const half * const __restrict__ mask_h, half * const __restrict__ tile_mask,
const int stride_mask, const int i_sup, const int j0, const uint3 ne01) {
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
if constexpr (use_cp_async) {
- static_assert(nbatch_fa <= 8*WARP_SIZE && nbatch_fa % 8 == 0, "bad nbatch_fa");
+ static_assert(nbatch_fa <= 8*warp_size && nbatch_fa % 8 == 0, "bad nbatch_fa");
static_assert(!oob_check, "OOB check incompatible with cp_async");
constexpr int preload = nbatch_fa >= 32 ? nbatch_fa * sizeof(half) : 64;
- constexpr int cols_per_warp = 8*WARP_SIZE/nbatch_fa;
+ constexpr int cols_per_warp = 8*warp_size/nbatch_fa;
constexpr int stride_j = nwarps * cols_per_warp;
const unsigned int tile_mask_32 = ggml_cuda_cvta_generic_to_shared(tile_mask);
#pragma unroll
for (int j1 = 0; j1 < ncols1; j1 += stride_j) {
- const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (WARP_SIZE/cols_per_warp);
+ const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (warp_size/cols_per_warp);
const int j_vram = fastmodulo(j0 + j_sram, ne01);
if (j1 + stride_j > ncols1 && j_sram >= ncols1) {
}
#pragma unroll
- for (int i0 = 0; i0 < nbatch_fa; i0 += WARP_SIZE) {
+ for (int i0 = 0; i0 < nbatch_fa; i0 += warp_size) {
const int i = i0 + threadIdx.x;
tile_mask[j_sram*(nbatch_fa + 8) + i] = i < i_sup ? mask_h[j_vram*stride_mask + i] : half(0.0f);
}
}
- } else if constexpr (nbatch_fa < 2*WARP_SIZE) {
- constexpr int cols_per_warp = 2*WARP_SIZE/nbatch_fa;
+ } else if constexpr (nbatch_fa < 2*warp_size) {
+ constexpr int cols_per_warp = 2*warp_size/nbatch_fa;
constexpr int stride_j = nwarps * cols_per_warp;
#pragma unroll
for (int j1 = 0; j1 < ncols1; j1 += stride_j) {
- const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (WARP_SIZE/cols_per_warp);
+ const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (warp_size/cols_per_warp);
const int j_vram = fastmodulo(j0 + j_sram, ne01);
if (j1 + stride_j > ncols1 && j_sram >= ncols1) {
break;
}
- const int i = threadIdx.x % (WARP_SIZE/cols_per_warp);
+ const int i = threadIdx.x % (warp_size/cols_per_warp);
ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + 2*i, mask_h + j_vram*stride_mask + 2*i);
}
}
#pragma unroll
- for (int i0 = 0; i0 < nbatch_fa; i0 += 2*WARP_SIZE) {
+ for (int i0 = 0; i0 < nbatch_fa; i0 += 2*warp_size) {
const int i = i0 + 2*threadIdx.x;
ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + i, mask_h + j_vram*stride_mask + i);
const int jt,
const int kb0,
const int k_VKQ_sup) {
-#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
+#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE)
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
constexpr int ncols = ncols1 * ncols2;
constexpr int cols_per_warp = T_B_KQ::I;
constexpr int cols_per_thread = get_cols_per_thread();
const int k_VKQ_0 = kb0 * nbatch_fa;
#if defined(TURING_MMA_AVAILABLE)
T_C_KQ KQ_C[nbatch_fa/(np*(cols_per_warp == 8 ? T_C_KQ::I : T_C_KQ::J))];
-#elif defined(AMD_WMMA_AVAILABLE)
+#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)];
#else // Volta
T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)];
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[k_KQ_0/T_A_KQ::J]);
} else {
// Wide version of KQ_C is column-major
-#if defined(AMD_WMMA_AVAILABLE)
- // RDNA matrix C is column-major.
+#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+ // AMD matrix C is column-major.
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[k_KQ_0/T_A_KQ::J]);
#else
// swap A and B for CUDA.
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[k_KQ_0/T_A_KQ::J], K_A);
-#endif // defined(AMD_WMMA_AVAILABLE)
+#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
}
}
}
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
} else {
// Wide version of KQ_C is column-major
-#if defined(AMD_WMMA_AVAILABLE)
- // RDNA matrix C is column-major.
+#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+ // AMD matrix C is column-major.
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
#else
// swap A and B for CUDA.
mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
-#endif // defined(AMD_WMMA_AVAILABLE)
+#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
}
}
}
#pragma unroll
for (int l = 0; l < T_C_KQ::ne; ++l) {
if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) {
-#if defined(AMD_WMMA_AVAILABLE)
+#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
constexpr int KQ_idx = 0;
#else
// Turing + Volta:
const int KQ_idx = l % 2;
-#endif // defined(AMD_WMMA_AVAILABLE)
+#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[k0/(np*T_C_KQ::I)].x[l] + FATTN_KQ_MAX_OFFSET);
}
}
for (int col = 0; col < cols_per_thread; ++col) {
#pragma unroll
for (int offset = 16; offset >= 4; offset >>= 1) {
- KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE));
+ KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, warp_size));
}
}
#pragma unroll
for (int l = 0; l < T_C_KQ::ne; ++l) {
if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) {
-#if defined(AMD_WMMA_AVAILABLE)
+#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
constexpr int KQ_idx = 0;
#else
// Turing + Volta:
const int KQ_idx = l % 2;
-#endif // defined(AMD_WMMA_AVAILABLE)
+#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
KQ_C[k0/(np*T_C_KQ::I)].x[l] = expf(KQ_C[k0/(np*T_C_KQ::I)].x[l] - KQ_max_new[KQ_idx]);
KQ_rowsum_add[KQ_idx] += KQ_C[k0/(np*T_C_KQ::I)].x[l];
} else {
#pragma unroll
for (int l = 0; l < T_C_KQ::ne; ++l) {
if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) {
-#if defined(AMD_WMMA_AVAILABLE)
+#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
constexpr int KQ_idx = 0;
#else
// Turing + Volta:
const int KQ_idx = (l/2) % 2;
-#endif // defined(AMD_WMMA_AVAILABLE)
+#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[(k0/(np*T_C_KQ::J))].x[l] + FATTN_KQ_MAX_OFFSET);
}
}
// Values per KQ column are spread across 4 threads:
constexpr int offset_first = 2;
constexpr int offset_last = 1;
+#elif defined(AMD_MFMA_AVAILABLE)
+ // MFMA: 4 threads per Q column (threadIdx.x % 16 == col, spaced by 16).
+ constexpr int offset_first = 32;
+ constexpr int offset_last = 16;
#elif defined(AMD_WMMA_AVAILABLE)
// Values per KQ column are spread across 2 threads:
constexpr int offset_first = 16;
#endif // defined(TURING_MMA_AVAILABLE)
#pragma unroll
for (int offset = offset_first; offset >= offset_last; offset >>= 1) {
- KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE));
+ KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, warp_size));
}
}
#pragma unroll
for (int l = 0; l < T_C_KQ::ne; ++l) {
if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) {
-#if defined(AMD_WMMA_AVAILABLE)
+#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
constexpr int KQ_idx = 0;
#else
// Turing + Volta:
const int KQ_idx = (l/2) % 2;
-#endif // defined(AMD_WMMA_AVAILABLE)
+#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
KQ_C[(k0/(np*T_C_KQ::J))].x[l] = expf(KQ_C[(k0/(np*T_C_KQ::J))].x[l] - KQ_max_new[KQ_idx]);
KQ_rowsum_add[KQ_idx] += KQ_C[(k0/(np*T_C_KQ::J))].x[l];
} else {
}
}
}
-#elif defined(AMD_WMMA_AVAILABLE)
+#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
const half2 KQ_max_scale_h2 = make_half2(
KQ_max_scale[0], KQ_max_scale[0]);
#pragma unroll
}
const half2 * tile_V_i = !V_is_K_view || i0_stop > 2*nbatch_K2 ? tile_V : tile_V + i0_start/2;
-#if defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+#if defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
constexpr int i0_stride = cols_per_warp == 8 ? T_C_VKQ::I : 2*T_C_VKQ::J;
#pragma unroll
for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += i0_stride) {
T_A_VKQ A; // Transposed in SRAM but not in registers, gets transposed on load.
#if defined(LDMATRIX_TRANS_AVAILABLE)
load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
+#elif defined(AMD_MFMA_AVAILABLE)
+ // MFMA A register layout: A_mat[i=lane%16][k=4*(lane/16)+reg].
+ // Normal load gives A_mat[seq][dv] but we need A_mat[dv][seq] = V^T.
+ // Load with transposed addressing: 4 strided half loads.
+ {
+ const half2 * xs0 = tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2;
+ const half * xs0_h = (const half *) xs0;
+ const int stride_h = stride_tile_V * 2; // stride in half units
+ half * A_h = (half *) A.x;
+#pragma unroll
+ for (int l = 0; l < 4; ++l) {
+ A_h[l] = xs0_h[(4*(threadIdx.x / 16) + l) * stride_h + threadIdx.x % 16];
+ }
+ }
#else
// TODO: Try to transpose tile_V when loading gmem to smem.
// Use mma to transpose T_A_VKQ for RDNA.
T_A_VKQ A_trans;
load_ldmatrix(A_trans, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
mma(A, A_trans, A_identity);
-#endif // defined(TURING_MMA_AVAILABLE)
+#endif // defined(LDMATRIX_TRANS_AVAILABLE)
if constexpr (T_B_KQ::I == 8) {
mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]);
} else {
// Wide version of VKQ_C is column-major.
-#if defined(AMD_WMMA_AVAILABLE)
- // RDNA matrix C is column-major.
+#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
+ // AMD matrix C is column-major.
mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]);
#else
// swap A and B for CUDA.
mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::J)], A);
-#endif // defined(AMD_WMMA_AVAILABLE)
+#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
}
}
}
mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::I)], A);
}
}
-#endif // defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+#endif // defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
if constexpr (nstages <= 1) {
__syncthreads(); // Only needed if tile_K == tile_V.
tile_Q, tile_K, tile_V, tile_mask,
Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
NO_DEVICE_CODE;
-#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
+#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE)
}
#if defined(TURING_MMA_AVAILABLE)
using T_B_VKQ = tile< 8, 8, half2>; // column-major
using T_C_VKQ = tile<16, 4, half2>; // row-major
};
-#elif defined(AMD_WMMA_AVAILABLE)
+#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
template<int ncols> struct mma_tile_sizes {
using T_A_KQ = tile<16, 8, half2>; // row-major
using T_B_KQ = tile<16, 8, half2>; // column-major
const int zt_gqa,
const int kb0_start,
const int kb0_stop) {
-#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
+#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE)
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
constexpr int ncols = ncols1 * ncols2;
using T_A_KQ = typename mma_tile_sizes<ncols>::T_A_KQ;
using T_B_KQ = typename mma_tile_sizes<ncols>::T_B_KQ;
T_B_KQ Q_B[(Q_in_reg ? DKQ/(2*T_B_KQ::J) : 1)];
#if defined(TURING_MMA_AVAILABLE)
T_C_VKQ VKQ_C[cols_per_warp == 8 ? DV/T_C_VKQ::I : DV/(2*T_C_VKQ::J)];
-#elif defined(AMD_WMMA_AVAILABLE)
+#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
T_C_VKQ VKQ_C[ DV/(2*T_C_VKQ::J)];
#else // Volta
T_C_VKQ VKQ_C[ DV/(2*T_C_VKQ::J)];
// The loading is done with decreasing granularity for D for better memory bandwidth.
const half2 scale_h2 = make_half2(scale, scale);
#pragma unroll
- for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
- const int k0_start = stride_k == WARP_SIZE ? 0 : DKQ/2 - (DKQ/2) % (2*stride_k);
+ for (int stride_k : {warp_size, warp_size/2, warp_size/4, warp_size/8}) {
+ const int k0_start = stride_k == warp_size ? 0 : DKQ/2 - (DKQ/2) % (2*stride_k);
const int k0_stop = DKQ/2 - (DKQ/2) % (1*stride_k);
- const int stride_jc = WARP_SIZE / stride_k;
+ const int stride_jc = warp_size / stride_k;
if (k0_start == k0_stop) {
continue;
#pragma unroll
for (int jc0 = 0; jc0 < ncols; jc0 += nwarps*stride_jc) {
- const int jc = jc0 + threadIdx.y*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
+ const int jc = jc0 + threadIdx.y*stride_jc + (stride_k == warp_size ? 0 : threadIdx.x / stride_k);
if (jc0 + nwarps*stride_jc > ncols && jc >= ncols) {
break;
if ((ncols1 == 1 || jt*ncols1 + j < int(ne01.z)) && (ncols2 == 1 || zt_gqa*ncols2 + c < gqa_ratio)) {
#pragma unroll
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
- const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
+ const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);
const float2 tmp = Q_f2[(jt*ncols1 + j)*stride_Q1 + c*stride_Q2 + k];
tile_Q[jc*stride_tile_Q + k] = scale_h2 * make_half2(tmp.x, tmp.y);
} else {
#pragma unroll
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
- const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
+ const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);
tile_Q[jc*stride_tile_Q + k] = make_half2(0.0f, 0.0f);
}
// The partial sums are spread across 8/4 threads.
constexpr int offset_first = cols_per_warp == 8 ? 16 : 2;
constexpr int offset_last = cols_per_warp == 8 ? 4 : 1;
+#elif defined(AMD_MFMA_AVAILABLE)
+ // The partial sums are spread across 4 threads (wavefront64, 16 cols).
+ constexpr int offset_first = 32;
+ constexpr int offset_last = 16;
#elif defined(AMD_WMMA_AVAILABLE)
// The partial sums are spread across 2 threads.
constexpr int offset_first = 16;
for (int col = 0; col < cols_per_thread; ++col) {
#pragma unroll
for (int offset = offset_first; offset >= offset_last; offset >>= 1) {
- KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, WARP_SIZE);
+ KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, warp_size);
}
}
}
}
}
}
-#elif defined(AMD_WMMA_AVAILABLE)
+#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[0]);
#pragma unroll
for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
const int jc_cwm = threadIdx.y*cols_per_warp + T_C_VKQ::get_i(threadIdx.x % 4);
const float2 KQ_cmr = make_float2(KQ_max[threadIdx.x % cols_per_thread], KQ_rowsum[threadIdx.x % cols_per_thread]);
const bool thread_should_write = threadIdx.x % 4 < cols_per_thread;
-#elif defined(AMD_WMMA_AVAILABLE)
+#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
const int jc_cwm = threadIdx.y*cols_per_warp + T_C_VKQ::get_i(0);
const float2 KQ_cmr = make_float2(KQ_max[0], KQ_rowsum[0]);
const bool thread_should_write = threadIdx.x / 16 < cols_per_thread;
// Warps with threadIdx.y % np != 0 must NOT return early.
// All threads must return simultaneously to avoid race conditions with work on the next tile.
- constexpr int nmeta = np*cols_per_warp >= WARP_SIZE ? np*cols_per_warp/WARP_SIZE : 1;
+ constexpr int nmeta = np*cols_per_warp >= warp_size ? np*cols_per_warp/warp_size : 1;
- const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < WARP_SIZE ? threadIdx.x % (np*cols_per_warp) : threadIdx.x);
+ const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < warp_size ? threadIdx.x % (np*cols_per_warp) : threadIdx.x);
float2 * const meta_ptr = ((float2 *) tile_Q) + jc_meta*(tile_stride/2) + nbatch_combine/2;
float2 meta[nmeta];
#pragma unroll
for (int imeta = 0; imeta < nmeta; ++imeta) {
- meta[imeta] = meta_ptr[imeta * WARP_SIZE * tile_stride/2];
+ meta[imeta] = meta_ptr[imeta * warp_size * tile_stride/2];
}
float KQ_cmn = meta[0].x; // KQ combine max new, max between all parallel warps.
}
#pragma unroll
for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
- if (offset < WARP_SIZE) {
- KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE));
+ if (offset < warp_size) {
+ KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, warp_size));
}
}
}
#pragma unroll
for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
- if (offset < WARP_SIZE) {
- KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE);
+ if (offset < warp_size) {
+ KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, warp_size);
}
}
// Write back combined meta data:
#pragma unroll
for (int imeta = 0; imeta < nmeta; ++imeta) {
- if (np*cols_per_warp >= WARP_SIZE || threadIdx.x < np*cols_per_warp) {
+ if (np*cols_per_warp >= warp_size || threadIdx.x < np*cols_per_warp) {
// Combined KQ max scale + rowsum.
- meta_ptr[imeta * WARP_SIZE * tile_stride/2] = make_float2(KQ_cms[imeta], KQ_crs);
+ meta_ptr[imeta * warp_size * tile_stride/2] = make_float2(KQ_cms[imeta], KQ_crs);
}
}
// Combined KQ max + rowsum.
- static_assert(cols_per_warp <= WARP_SIZE);
- if (needs_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) {
+ static_assert(cols_per_warp <= warp_size);
+ if (needs_fixup && (cols_per_warp == warp_size || threadIdx.x < cols_per_warp)) {
float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
}
- if (is_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) {
+ if (is_fixup && (cols_per_warp == warp_size || threadIdx.x < cols_per_warp)) {
float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
}
float2 * dstk_fixup_data = dstk_fixup + gridDim.x*(2*ncols) + blockIdx.x*(ncols*(DV/2));
#pragma unroll
- for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
- const int k0_start = stride_k == WARP_SIZE ? 0 : nbatch_combine - nbatch_combine % (2*stride_k);
+ for (int stride_k : {warp_size, warp_size/2, warp_size/4, warp_size/8}) {
+ const int k0_start = stride_k == warp_size ? 0 : nbatch_combine - nbatch_combine % (2*stride_k);
const int k0_stop = nbatch_combine - nbatch_combine % (1*stride_k);
- const int stride_jc = WARP_SIZE / stride_k;
+ const int stride_jc = warp_size / stride_k;
if (k0_start == k0_stop) {
continue;
#pragma unroll
for (int jc0_dst = 0; jc0_dst < ncols; jc0_dst += (nwarps/np)*stride_jc) {
- const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
+ const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k == warp_size ? 0 : threadIdx.x / stride_k);
if (jc0_dst + (nwarps/np)*stride_jc > ncols && jc_dst >= ncols) {
break;
const float * meta_j = (const float *) tile_Q + jc_tile_K*tile_stride + nbatch_combine;
#pragma unroll
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
- const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
+ const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);
float2 dstk_val = make_float2(0.0f, 0.0f);
#pragma unroll
stride_Q1, stride_Q2, stride_K, stride_V, stride_mask,
jt, kb0_start, kb0_stop);
NO_DEVICE_CODE;
-#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))
+#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE)
}
template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap, bool V_is_K_view>
const int32_t nb21, const int32_t nb22, const int64_t nb23,
const int32_t ne31, const int32_t ne32, const int32_t ne33,
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
-#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)))
+#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE))
// Skip unused kernel variants for faster compilation:
if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) {
}
#endif // defined(AMD_WMMA_AVAILABLE)
+#if defined(AMD_MFMA_AVAILABLE)
+ if (DKQ != 64 && DKQ != 80 && DKQ != 96 && DKQ != 112 && DKQ != 128) {
+ NO_DEVICE_CODE;
+ return;
+ }
+#endif // defined(AMD_MFMA_AVAILABLE)
+
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
constexpr int ncols = ncols1 * ncols2;
constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
constexpr int nthreads = ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols);
- constexpr int nwarps = nthreads / WARP_SIZE;
+ constexpr int nwarps = nthreads / warp_size;
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
ne31, ne32, ne33,
nb31, nb32, nb33);
NO_DEVICE_CODE;
-#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)))
+#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE))
}
template <int DKQ, int DV, int ncols1, int ncols2>
const int nstages = ggml_cuda_fattn_mma_get_nstages (DKQ, DV, ncols1, ncols2, cc);
const int cols_per_warp = std::min(ncols, get_cols_per_warp(cc));
- const int nwarps = nthreads / WARP_SIZE;
+ const int warp_size_host = ggml_cuda_info().devices[ctx.device].warp_size;
+ const int nwarps = nthreads / warp_size_host;
constexpr bool V_is_K_view = DKQ == 576; // Guaranteed by the kernel selection logic in fattn.cu
}
launch_fattn<DV, ncols1, ncols2>
- (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, nbatch_fa, true, true, true);
+ (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, nbatch_fa, true, true, true, warp_size_host);
}