const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c;
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
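+    // ggml_cuda_get_physical_warp_size() resolves to a compile-time constant for the target
+    // architecture: 32 on NVIDIA, and 64 on AMD wave64 GPUs.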
GGML_UNUSED(Q_v);
T sum = 0.0f;
#pragma unroll
- for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+ for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) {
const int k_KQ = k_KQ_0 + threadIdx.x;
const int ib = k_KQ / QI8_1;
const int shift = k_KQ & (QI8_1/2);
const int v = (get_int_b2(K_q4_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
- const int u = Q_q8[k_KQ_0/WARP_SIZE];
+ const int u = Q_q8[k_KQ_0/warp_size];
const int sumi = ggml_cuda_dp4a(v, u, 0);
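        // ggml_cuda_dp4a() accumulates the four byte-wise products of v and u into a single int.
        // q4_0 stores weights as d*(q - 8); the zero-point of 8 is corrected below using the
        // precomputed Q sums carried in Q_ds.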
if (std::is_same<T, half>::value) {
const half2 * Q_ds = (const half2 *) Q_ds_v;
- const half2 sum2 = __half2half2(K_q4_0[ib].d) * Q_ds[k_KQ_0/WARP_SIZE];
+ const half2 sum2 = __half2half2(K_q4_0[ib].d) * Q_ds[k_KQ_0/warp_size];
sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2) /* *8/QI8_1 == 1 */);
} else
#endif // FP16_AVAILABLE
{
const float2 * Q_ds = (const float2 *) Q_ds_v;
- sum += (T) (__half2float(K_q4_0[ib].d) * (sumi*Q_ds[k_KQ_0/WARP_SIZE].x - (8/QI8_1)*Q_ds[k_KQ_0/WARP_SIZE].y));
+ sum += (T) (__half2float(K_q4_0[ib].d) * (sumi*Q_ds[k_KQ_0/warp_size].x - (8/QI8_1)*Q_ds[k_KQ_0/warp_size].y));
}
}
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c;
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
GGML_UNUSED(Q_v);
T sum = 0.0f;
#pragma unroll
- for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+ for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) {
const int k_KQ = k_KQ_0 + threadIdx.x;
const int ib = k_KQ / QI8_1;
const int shift = k_KQ & (QI8_1/2);
const int v = (get_int_b4(K_q4_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
- const int u = Q_q8[k_KQ_0/WARP_SIZE];
+ const int u = Q_q8[k_KQ_0/warp_size];
const int sumi = ggml_cuda_dp4a(v, u, 0);
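        // q4_1 stores weights as d*q + m, so the dot product splits into a d4*d8*sumi term plus an
        // m4 term driven by the precomputed Q sums; dm packs {d, m} and Q_ds packs {d8, s8}.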
if (std::is_same<T, half>::value) {
const half2 * Q_ds = (const half2 *) Q_ds_v;
- const half2 d4d8_m4s8 = K_q4_1[ib].dm * Q_ds[k_KQ_0/WARP_SIZE];
+ const half2 d4d8_m4s8 = K_q4_1[ib].dm * Q_ds[k_KQ_0/warp_size];
const half2 sumid4d8_m4s8scaled = d4d8_m4s8 * make_half2(sumi, 1.0f/QI8_1);
sum += (T) (__low2half(sumid4d8_m4s8scaled) + __high2half(sumid4d8_m4s8scaled));
} else
{
const float2 * Q_ds = (const float2 *) Q_ds_v;
- const float sumid4d8 = __low2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].x * sumi;
- const float m4s8scaled = __high2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].y / QI8_1;
+ const float sumid4d8 = __low2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/warp_size].x * sumi;
+ const float m4s8scaled = __high2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/warp_size].y / QI8_1;
sum += (T) (sumid4d8 + m4s8scaled);
}
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c;
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
GGML_UNUSED(Q_v);
T sum = 0.0f;
#pragma unroll
- for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+ for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) {
const int k_KQ = k_KQ_0 + threadIdx.x;
const int ib = k_KQ / QI8_1;
v |= (vh << 18) & 0x00100000; // 2 -> 20
v |= (vh << 25) & 0x10000000; // 3 -> 28
- const int u = Q_q8[k_KQ_0/WARP_SIZE];
+ const int u = Q_q8[k_KQ_0/warp_size];
const int sumi = ggml_cuda_dp4a(v, u, 0);
if (std::is_same<T, half>::value) {
const half2 * Q_ds = (const half2 *) Q_ds_v;
- const half2 sum2 = __half2half2(K_q5_0[ib].d) * Q_ds[k_KQ_0/WARP_SIZE];
+ const half2 sum2 = __half2half2(K_q5_0[ib].d) * Q_ds[k_KQ_0/warp_size];
sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2)*__float2half(2.0f)) /* *16/QI8_1 == 2 */;
} else
#endif // FP16_AVAILABLE
{
const float2 * Q_ds = (const float2 *) Q_ds_v;
- sum += (T) (__half2float(K_q5_0[ib].d) * (sumi*Q_ds[k_KQ_0/WARP_SIZE].x - (16/QI8_1)*Q_ds[k_KQ_0/WARP_SIZE].y));
+ sum += (T) (__half2float(K_q5_0[ib].d) * (sumi*Q_ds[k_KQ_0/warp_size].x - (16/QI8_1)*Q_ds[k_KQ_0/warp_size].y));
}
}
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c;
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
GGML_UNUSED(Q_v);
T sum = 0.0f;
#pragma unroll
- for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+ for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) {
const int k_KQ = k_KQ_0 + threadIdx.x;
const int ib = k_KQ / QI8_1;
v |= (vh << 18) & 0x00100000; // 2 -> 20
v |= (vh << 25) & 0x10000000; // 3 -> 28
- const int u = Q_q8[k_KQ_0/WARP_SIZE];
+ const int u = Q_q8[k_KQ_0/warp_size];
const int sumi = ggml_cuda_dp4a(v, u, 0);
if (std::is_same<T, half>::value) {
const half2 * Q_ds = (const half2 *) Q_ds_v;
- const half2 d5d8_m5s8 = K_q5_1[ib].dm * Q_ds[k_KQ_0/WARP_SIZE];
+ const half2 d5d8_m5s8 = K_q5_1[ib].dm * Q_ds[k_KQ_0/warp_size];
const half2 sumid5d8_m5s8scaled = d5d8_m5s8 * make_half2(sumi, 1.0f/QI8_1);
sum += (T) (__low2half(sumid5d8_m5s8scaled) + __high2half(sumid5d8_m5s8scaled));
} else
{
const float2 * Q_ds = (const float2 *) Q_ds_v;
- const float sumid5d8 = __low2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].x * sumi;
- const float m5s8scaled = __high2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].y / QI8_1;
+ const float sumid5d8 = __low2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/warp_size].x * sumi;
+ const float m5s8scaled = __high2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/warp_size].y / QI8_1;
sum += (T) (sumid5d8 + m5s8scaled);
}
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c;
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
GGML_UNUSED(Q_v);
T sum = 0.0f;
#pragma unroll
- for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+ for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) {
const int k_KQ = k_KQ_0 + threadIdx.x;
const int ib = k_KQ / QI8_0;
T Q_d;
if (std::is_same<T, half>::value) {
const half2 * Q_ds = (const half2 *) Q_ds_v;
- Q_d = __low2half(Q_ds[k_KQ_0/WARP_SIZE]);
+ Q_d = __low2half(Q_ds[k_KQ_0/warp_size]);
} else {
const float2 * Q_ds = (const float2 *) Q_ds_v;
- Q_d = Q_ds[k_KQ_0/WARP_SIZE].x;
+ Q_d = Q_ds[k_KQ_0/warp_size].x;
}
- sum += vec_dot_q8_0_q8_1_impl<T, 1>(&v, &Q_q8[k_KQ_0/WARP_SIZE], K_q8_0[ib].d, Q_d);
+ sum += vec_dot_q8_0_q8_1_impl<T, 1>(&v, &Q_q8[k_KQ_0/warp_size], K_q8_0[ib].d, Q_d);
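+        // vec_dot_q8_0_q8_1_impl() scales the int8 dot product by both block scales; q8_0 has no
+        // offset term, so no further correction is needed.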
}
return sum;
const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) {
const half2 * K_h2 = (const half2 *) K_c;
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
GGML_UNUSED(Q_q8);
GGML_UNUSED(Q_ds_v);
half2 sum2 = make_half2(0.0f, 0.0f);
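    // F16 K path: accumulate the dot product in half2 and combine the low/high halves at the end.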
#pragma unroll
- for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
+ for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += warp_size) {
const int k_KQ = k_KQ_0 + threadIdx.x;
const half2 K_ik = K_h2[k_KQ];
- sum2 += K_ik * Q_h2[k_KQ_0/WARP_SIZE];
+ sum2 += K_ik * Q_h2[k_KQ_0/warp_size];
}
return __low2half(sum2) + __high2half(sum2);
float sum = 0.0f;
#pragma unroll
- for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
+ for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += warp_size) {
const int k_KQ = k_KQ_0 + threadIdx.x;
const half2 K_ik = K_h2[k_KQ];
- sum += __low2float(K_ik) * Q_f2[k_KQ_0/WARP_SIZE].x;
- sum += __high2float(K_ik) * Q_f2[k_KQ_0/WARP_SIZE].y;
+ sum += __low2float(K_ik) * Q_f2[k_KQ_0/warp_size].x;
+ sum += __high2float(K_ik) * Q_f2[k_KQ_0/warp_size].y;
}
return sum;
GGML_ASSERT(Q->ne[3] == 1);
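+    // On the host the warp size is read from the runtime device info, so the launch configuration
+    // below uses the warp width of the active device.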
+ const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size;
+
ggml_cuda_pool & pool = ctx.pool();
cudaStream_t main_stream = ctx.stream();
const int id = ggml_cuda_get_device();
const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
- const dim3 block_dim(WARP_SIZE, nwarps, 1);
+ const dim3 block_dim(warp_size, nwarps, 1);
dim3 blocks_num;
if (parallel_blocks == 0) {
// For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+ GGML_ASSERT(block_dim.x % warp_size == 0);
+ GGML_ASSERT(!GGML_CUDA_CC_IS_AMD(cc) || block_dim.x * block_dim.y <= 4 * (unsigned int)warp_size);
fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>(
(const char *) Q->data,
K_data,
#include "fattn-wmma-f16.cuh"
#ifdef FP16_MMA_AVAILABLE
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
#include <mma.h>
+namespace wmma = nvcuda::wmma;
+#elif defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)
+#undef HIP_ENABLE_WARP_SYNC_BUILTINS // conflicts with rocWMMA headers
+#include <rocwmma/rocwmma.hpp>
+namespace wmma = rocwmma;
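+// Either backend is referenced through the common wmma:: alias, so the kernel body below stays backend-agnostic.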
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
#endif // FP16_MMA_AVAILABLE
// D == head size, VKQ_stride == num VKQ rows calculated in parallel:
template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t, bool use_logit_softcap>
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
-__launch_bounds__(nwarps*WARP_SIZE, 1)
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+__launch_bounds__(nwarps*ggml_cuda_get_physical_warp_size(), 1)
static __global__ void flash_attn_ext_f16(
const char * __restrict__ Q,
const char * __restrict__ K,
const int ne1,
const int ne2,
const int ne3) {
-#if defined(FLASH_ATTN_AVAILABLE) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+#if defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
// Skip unused kernel variants for faster compilation:
if (use_logit_softcap && !(D == 128 || D == 256)) {
NO_DEVICE_CODE;
    // In this kernel Q, K and V are matrices while i, j and k are matrix indices.
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
constexpr int frag_m = ncols == 8 ? 32 : 16;
constexpr int frag_n = ncols == 8 ? 8 : 16;
static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
- typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::row_major> frag_a_K;
- typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_a_V;
- typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b;
- typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
- typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
+ typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, half, wmma::row_major> frag_a_K;
+ typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, half, wmma::col_major> frag_a_V;
+ typedef wmma::fragment<wmma::matrix_b, frag_m, frag_n, 16, half, wmma::col_major> frag_b;
+ typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
+ typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel.
constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
#pragma unroll
- for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+ for (int i0 = 0; i0 < D/2; i0 += warp_size) {
const int i = i0 + threadIdx.x;
- if (i0 + WARP_SIZE > D/2 && i >= D/2) {
+ if (i0 + warp_size > D/2 && i >= D/2) {
break;
}
VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f);
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
#pragma unroll
- for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
+ for (int i0 = 0; i0 < D; i0 += warp_size) {
const int i = i0 + threadIdx.x;
- if (i0 + WARP_SIZE > D && i >= D) {
+ if (i0 + warp_size > D && i >= D) {
break;
}
KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f;
for (int i0 = 0; i0 < D; i0 += 16) {
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
- nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
+ wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
}
}
frag_c_KQ KQ_c[ncols/frag_n];
#pragma unroll
for (int j = 0; j < ncols/frag_n; ++j) {
- nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f);
+ wmma::fill_fragment(KQ_c[j], static_cast<KQ_acc_t>(0.0f));
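+        // Casting the fill value to the fragment's element type avoids relying on an implicit
+        // float conversion, which rocWMMA may not accept.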
}
#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
frag_a_K K_a;
- nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
+ wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
#pragma unroll
for (int j = 0; j < ncols/frag_n; ++j) {
- nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
+ wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
}
}
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
- nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major);
+ wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, wmma::mem_col_major);
}
}
const int j = j0 + threadIdx.y;
if (std::is_same<KQ_acc_t, float>::value) {
- float KQ_f_tmp[FATTN_KQ_STRIDE / WARP_SIZE];
+ float KQ_f_tmp[FATTN_KQ_STRIDE / warp_size];
#pragma unroll
- for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
+ for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) {
const int k = k0 + threadIdx.x;
- KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k];
+ KQ_f_tmp[k0/warp_size] = KQ_f[j*kqs_padded + k];
if (use_logit_softcap) {
- KQ_f_tmp[k0/WARP_SIZE] = logit_softcap*tanhf(KQ_f_tmp[k0/WARP_SIZE]);
+ KQ_f_tmp[k0/warp_size] = logit_softcap*tanhf(KQ_f_tmp[k0/warp_size]);
}
}
float KQ_max_new = KQ_max_f[j0/nwarps];
#pragma unroll
- for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
+ for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) {
const int k = k0 + threadIdx.x;
- KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
- KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]);
+ KQ_f_tmp[k0/warp_size] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
+ KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/warp_size]);
}
- KQ_max_new = warp_reduce_max(KQ_max_new);
+ KQ_max_new = warp_reduce_max<warp_size>(KQ_max_new);
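+            // The reduction width is passed explicitly so that AMD wave64 devices reduce across all 64 lanes.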
const float diff = KQ_max_f[j0/nwarps] - KQ_max_new;
KQ_max_scale_f[j0/nwarps] = expf(diff);
float KQ_rowsum_add = 0.0f;
#pragma unroll
- for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
+ for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) {
const int k = k0 + threadIdx.x;
- const float diff = KQ_f_tmp[k0/WARP_SIZE] - KQ_max_f[j0/nwarps];
- KQ_f_tmp[k0/WARP_SIZE] = expf(diff);
+ const float diff = KQ_f_tmp[k0/warp_size] - KQ_max_f[j0/nwarps];
+ KQ_f_tmp[k0/warp_size] = expf(diff);
if (diff <= SOFTMAX_FTZ_THRESHOLD) {
- KQ_f_tmp[k0/WARP_SIZE] = 0.0f;
+ KQ_f_tmp[k0/warp_size] = 0.0f;
}
- KQ_rowsum_add += KQ_f_tmp[k0/WARP_SIZE];
- KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/WARP_SIZE];
+ KQ_rowsum_add += KQ_f_tmp[k0/warp_size];
+ KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/warp_size];
}
- KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
+ KQ_rowsum_add = warp_reduce_sum<warp_size>(KQ_rowsum_add);
// Scale previous KQ_rowsum to account for a potential increase in KQ_max:
KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add;
} else {
- half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)];
+ half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*warp_size)];
#pragma unroll
- for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
+ for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {
const int k = k0 + threadIdx.x;
- KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k];
+ KQ2_tmp[k0/warp_size] = KQ2[j*(kqs_padded/2) + k];
if (use_logit_softcap) {
                    // There is no dedicated hyperbolic tangent (tanh) function for half2.
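                    // Instead tanh(x) is computed as (exp(2*x) - 1) / (exp(2*x) + 1).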
- KQ2_tmp[k0/WARP_SIZE] = h2exp(KQ2_tmp[k0/WARP_SIZE]*make_half2(2.0f, 2.0f));
- KQ2_tmp[k0/WARP_SIZE] = (KQ2_tmp[k0/WARP_SIZE] - make_half2(1.0f, 1.0f))
- /(KQ2_tmp[k0/WARP_SIZE] + make_half2(1.0f, 1.0f));
+ KQ2_tmp[k0/warp_size] = h2exp(KQ2_tmp[k0/warp_size]*make_half2(2.0f, 2.0f));
+ KQ2_tmp[k0/warp_size] = (KQ2_tmp[k0/warp_size] - make_half2(1.0f, 1.0f))
+ /(KQ2_tmp[k0/warp_size] + make_half2(1.0f, 1.0f));
- KQ2_tmp[k0/WARP_SIZE] *= logit_softcap_2;
+ KQ2_tmp[k0/warp_size] *= logit_softcap_2;
}
}
half2 KQ_max_new = KQ_max_h2[j0/nwarps];
#pragma unroll
- for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
+ for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {
const int k = k0 + threadIdx.x;
- KQ2_tmp[k0/WARP_SIZE] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
- KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]);
+ KQ2_tmp[k0/warp_size] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
+ KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/warp_size]);
}
- KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
+ KQ_max_new = __half2half2(warp_reduce_max<warp_size>(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new;
KQ_max_scale_h2[j0/nwarps] = h2exp(diff);
const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
half2 KQ_rowsum_add = make_half2(0.0f, 0.0f);
#pragma unroll
- for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
+ for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {
const int k = k0 + threadIdx.x;
- const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max_h2[j0/nwarps];
- KQ2_tmp[k0/WARP_SIZE] = h2exp(diff);
+ const half2 diff = KQ2_tmp[k0/warp_size] - KQ_max_h2[j0/nwarps];
+ KQ2_tmp[k0/warp_size] = h2exp(diff);
const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
- *((uint32_t *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask;
- KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE];
- KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE];
+ *((uint32_t *) &KQ2_tmp[k0/warp_size]) &= ftz_mask;
+ KQ_rowsum_add += KQ2_tmp[k0/warp_size];
+ KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/warp_size];
}
- KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
+ KQ_rowsum_add = warp_reduce_sum<warp_size>(KQ_rowsum_add);
// Scale previous KQ_rowsum to account for a potential increase in KQ_max:
KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add;
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
- nvcuda::wmma::load_matrix_sync(
+ wmma::load_matrix_sync(
KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
KQ + j0*(kqar*kqs_padded) + k,
kqar*kqs_padded);
for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
#pragma unroll
for (int j = 0; j < ncols/frag_n; ++j) {
- nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
+ wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], static_cast<half>(0.0f));
}
#pragma unroll
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
frag_a_V v_a;
- nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
+ wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
#pragma unroll
for (int j = 0; j < ncols/frag_n; ++j) {
- nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
+ wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
}
}
}
for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
- nvcuda::wmma::store_matrix_sync(
+ wmma::store_matrix_sync(
KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
- D_padded, nvcuda::wmma::mem_col_major);
+ D_padded, wmma::mem_col_major);
}
}
}
#pragma unroll
- for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+ for (int i0 = 0; i0 < D/2; i0 += warp_size) {
const int i = i0 + threadIdx.x;
- if (i0 + WARP_SIZE > D/2 && i >= D/2) {
+ if (i0 + warp_size > D/2 && i >= D/2) {
break;
}
}
#pragma unroll
- for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
+ for (int i0 = 0; i0 < D; i0 += warp_size) {
const int i = i0 + threadIdx.x;
- if (i0 + WARP_SIZE > D && i >= D) {
+ if (i0 + warp_size > D && i >= D) {
break;
}
float dst_val = VKQ[j_VKQ*D_padded + i];
}
#else
NO_DEVICE_CODE;
-#endif // defined(FLASH_ATTN_AVAILABLE) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+#endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
}
constexpr int get_max_power_of_2(int x) {
const ggml_tensor * Q = dst->src[0];
const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
+ const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size;
if (prec != GGML_PREC_DEFAULT) {
if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
return;
}
- if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) {
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
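+    // The cols_per_block == 8 variant is compiled only for CUDA; AMD HIP builds fall through to the
+    // larger column counts below.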
+ if (Q->ne[1] <= 8 && Q->ne[0] % warp_size == 0) {
constexpr int cols_per_block = 8;
switch (Q->ne[0]) {
case 64:
}
return;
}
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
if (Q->ne[1] <= 32) {
constexpr int cols_per_block = 16;