From: Johannes Gäßler Date: Sat, 6 Sep 2025 22:26:28 +0000 (+0200) Subject: CUDA: faster tile FA (Pascal/AMD), headsize 256 (#15769) X-Git-Tag: upstream/0.0.6527~125 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=79bc429262268ad2ac8a364cfe6c2d6b9c5f008a;p=pkg%2Fggml%2Fsources%2Fllama.cpp CUDA: faster tile FA (Pascal/AMD), headsize 256 (#15769) --- diff --git a/ggml/src/ggml-cuda/fattn-tile-f16.cu b/ggml/src/ggml-cuda/fattn-tile-f16.cu deleted file mode 100644 index a900799a..00000000 --- a/ggml/src/ggml-cuda/fattn-tile-f16.cu +++ /dev/null @@ -1,371 +0,0 @@ -#include "common.cuh" -#include "fattn-common.cuh" -#include "fattn-tile-f16.cuh" - -#define FATTN_KQ_STRIDE_TILE_F16 64 - -template // D == head size -#if !defined(GGML_USE_HIP) -__launch_bounds__(nwarps*WARP_SIZE, 2) -#endif // !defined(GGML_USE_HIP) -static __global__ void flash_attn_tile_ext_f16( - const char * __restrict__ Q, - const char * __restrict__ K, - const char * __restrict__ V, - const char * __restrict__ mask, - const char * __restrict__ sinks, - const int * __restrict__ KV_max, - float * __restrict__ dst, - float2 * __restrict__ dst_meta, - const float scale, - const float max_bias, - const float m0, - const float m1, - const uint32_t n_head_log2, - const float logit_softcap, - const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, - const int32_t nb01, const int32_t nb02, const int32_t nb03, - const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, - const int32_t nb11, const int32_t nb12, const int64_t nb13, - const int32_t nb21, const int32_t nb22, const int64_t nb23, - const int32_t ne31, const int32_t ne32, const int32_t ne33, - const int32_t nb31, const int32_t nb32, const int64_t nb33) { -#if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE) - - // Skip unused kernel variants for faster compilation: -#ifdef FP16_MMA_AVAILABLE - NO_DEVICE_CODE; - return; -#endif // FP16_MMA_AVAILABLE - if (use_logit_softcap && !(D == 128 || D == 256)) { - NO_DEVICE_CODE; - return; - } - - //In this kernel Q, K, V are matrices while i, j, k are matrix indices. - - const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on. - - const int sequence = blockIdx.z / ne02; - const int head = blockIdx.z - sequence*ne02; - const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0); - const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio)); - const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape - const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0); - const float * sinksf = (const float *) (sinks); - - const int stride_KV2 = nb11 / sizeof(half2); - - const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1); - const half slopeh = __float2half(slopef); - - static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64."); - - __shared__ half KQ[ncols*FATTN_KQ_STRIDE_TILE_F16]; - half2 * KQ2 = (half2 *) KQ; - - __shared__ half2 KV_tmp[FATTN_KQ_STRIDE_TILE_F16][D/2 + 1]; // Pad D to avoid memory bank conflicts. - - half kqmax[ncols/nwarps]; -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - kqmax[j0/nwarps] = -HALF_MAX_HALF; - } - half2 kqsum[ncols/nwarps] = {{0.0f, 0.0f}}; - - half2 VKQ[ncols/nwarps][(D/2)/WARP_SIZE] = {{{0.0f, 0.0f}}}; - - // Convert Q to half2 and store in registers: - __shared__ half2 Q_h2[ncols][D/2]; -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j = j0 + threadIdx.y; - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - const float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i] : make_float2(0.0f, 0.0f); - Q_h2[j][i] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y); - } - } - - __syncthreads(); - - const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11; - for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F16; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F16) { - // Calculate KQ tile and keep track of new maximum KQ values: - - half kqmax_new[ncols/nwarps]; -#pragma unroll - for (int j = 0; j < ncols/nwarps; ++j) { - kqmax_new[j] = kqmax[j]; - } - -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += nwarps) { - const int i_KQ = i_KQ_0 + threadIdx.y; - -#pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) { - const int k_KQ = k_KQ_0 + threadIdx.x; - - KV_tmp[i_KQ][k_KQ] = K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ]; - } - } - - __syncthreads(); - - half2 sum2[FATTN_KQ_STRIDE_TILE_F16/WARP_SIZE][ncols/nwarps] = {{{0.0f, 0.0f}}}; - -#pragma unroll - for (int k_KQ = 0; k_KQ < D/2; ++k_KQ) { - half2 K_k[FATTN_KQ_STRIDE_TILE_F16/WARP_SIZE]; - half2 Q_k[ncols/nwarps]; - -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += WARP_SIZE) { - const int i_KQ = i_KQ_0 + threadIdx.x; - - K_k[i_KQ_0/WARP_SIZE] = KV_tmp[i_KQ][k_KQ]; - } -#pragma unroll - for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) { - const int j_KQ = j_KQ_0 + threadIdx.y; - - Q_k[j_KQ_0/nwarps] = Q_h2[j_KQ][k_KQ]; - } - -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += WARP_SIZE) { -#pragma unroll - for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) { - sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += K_k[i_KQ_0/WARP_SIZE]*Q_k[j_KQ_0/nwarps]; - } - } - } - -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += WARP_SIZE) { - const int i_KQ = i_KQ_0 + threadIdx.x; - -#pragma unroll - for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) { - const int j_KQ = j_KQ_0 + threadIdx.y; - - half sum; - if (use_logit_softcap) { - const float2 tmp = __half22float2(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]); - sum = logit_softcap * tanhf(tmp.x + tmp.y); - } else { - sum = __low2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]) + __high2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]); - } - sum += mask ? slopeh*maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f); - - kqmax_new[j_KQ_0/nwarps] = ggml_cuda_hmax(kqmax_new[j_KQ_0/nwarps], sum); - - KQ[j_KQ*FATTN_KQ_STRIDE_TILE_F16 + i_KQ] = sum; - } - } - - __syncthreads(); - -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j = j0 + threadIdx.y; - - kqmax_new[j0/nwarps] = warp_reduce_max(kqmax_new[j0/nwarps]); - const half2 KQ_max_scale = __half2half2(hexp(kqmax[j0/nwarps] - kqmax_new[j0/nwarps])); - kqmax[j0/nwarps] = kqmax_new[j0/nwarps]; - -#pragma unroll - for (int i0 = 0; i0 < FATTN_KQ_STRIDE_TILE_F16/2; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - const half2 diff = KQ2[j*(FATTN_KQ_STRIDE_TILE_F16/2) + i] - __half2half2(kqmax[j0/nwarps]); - const half2 val = h2exp(diff); - kqsum[j0/nwarps] = kqsum[j0/nwarps]*KQ_max_scale + val; - KQ2[j*(FATTN_KQ_STRIDE_TILE_F16/2) + i] = val; - } - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { - VKQ[j0/nwarps][i0/WARP_SIZE] *= KQ_max_scale; - } - } - - __syncthreads(); - -#pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE_TILE_F16; k0 += nwarps) { - const int k = k0 + threadIdx.y; - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - KV_tmp[k][i] = V_h2[int64_t(k_VKQ_0 + k)*stride_KV2 + i]; - } - } - - __syncthreads(); - -#pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE_TILE_F16; k0 += 2) { - half2 V_k[(D/2)/WARP_SIZE][2]; - half2 KQ_k[ncols/nwarps]; - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - V_k[i0/WARP_SIZE][0] = KV_tmp[k0 + 0][i]; - V_k[i0/WARP_SIZE][1] = KV_tmp[k0 + 1][i]; - } -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j = j0 + threadIdx.y; - - KQ_k[j0/nwarps] = KQ2[j*(FATTN_KQ_STRIDE_TILE_F16/2) + k0/2]; - } - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - VKQ[j0/nwarps][i0/WARP_SIZE] += V_k[i0/WARP_SIZE][0]* __low2half2(KQ_k[j0/nwarps]); - VKQ[j0/nwarps][i0/WARP_SIZE] += V_k[i0/WARP_SIZE][1]*__high2half2(KQ_k[j0/nwarps]); - } - } - } - - __syncthreads(); - } - - //Attention sink: adjust running max and sum once per head - if (sinksf && blockIdx.y == 0) { - const half sink = __float2half(sinksf[head]); - -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - half kqmax_new_j = fmaxf(kqmax[j0/nwarps], sink); - kqmax_new_j = warp_reduce_max(kqmax_new_j); - - const half2 KQ_max_scale = __half2half2(hexp(kqmax[j0/nwarps] - kqmax_new_j)); - kqmax[j0/nwarps] = kqmax_new_j; - - const half val = hexp(sink - kqmax[j0/nwarps]); - kqsum[j0/nwarps] = kqsum[j0/nwarps] * KQ_max_scale; - if (threadIdx.x == 0) { - kqsum[j0/nwarps].x = __hadd(__low2half(kqsum[j0/nwarps]), val); - } - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { - VKQ[j0/nwarps][i0/WARP_SIZE] *= KQ_max_scale; - } - } - } - - float2 * dst2 = (float2 *) dst; - -#pragma unroll - for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) { - const int j_VKQ = j_VKQ_0 + threadIdx.y; - - if (ic0 + j_VKQ >= ne01) { - return; - } - - half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]); - kqsum_j = warp_reduce_sum((float)kqsum_j); - - const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y; - -#pragma unroll - for (int i00 = 0; i00 < D/2; i00 += WARP_SIZE) { - const int i0 = i00 + threadIdx.x; - - half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/WARP_SIZE]; - if (gridDim.y == 1) { - dst_val /= __half2half2(kqsum_j); - } - dst2[j_dst_unrolled*(D/2) + i0] = __half22float2(dst_val); - } - - if (gridDim.y != 1 && threadIdx.x == 0) { - dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j); - } - } -#else - GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, - max_bias, m0, m1, n_head_log2, logit_softcap, - ne00, ne01, ne02, ne03, - nb01, nb02, nb03, - ne10, ne11, ne12, ne13, - nb11, nb12, nb13, - nb21, nb22, nb23, - ne31, ne32, ne33, - nb31, nb32, nb33); - NO_DEVICE_CODE; -#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE) -} - -template -void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * Q = dst->src[0]; - switch (Q->ne[0]) { - case 64: { - constexpr int D = 64; - constexpr int nwarps = 8; - constexpr size_t nbytes_shared = 0; - fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; - launch_fattn - (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false); - } break; - case 128: { - constexpr int D = 128; - constexpr int nwarps = 8; - constexpr size_t nbytes_shared = 0; - fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16; - launch_fattn - (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false); - } break; - default: { - GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128."); - } break; - } -} - -void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * KQV = dst; - const ggml_tensor * Q = dst->src[0]; - - const int32_t precision = KQV->op_params[3]; - GGML_ASSERT(precision == GGML_PREC_DEFAULT); - - float logit_softcap; - memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); - - if (Q->ne[1] <= 16) { - constexpr int cols_per_block = 16; - if (logit_softcap == 0.0f) { - constexpr bool use_logit_softcap = false; - launch_fattn_tile_f16_64_128(ctx, dst); - } else { - constexpr bool use_logit_softcap = true; - launch_fattn_tile_f16_64_128(ctx, dst); - } - return; - } - - constexpr int cols_per_block = 32; - if (logit_softcap == 0.0f) { - constexpr bool use_logit_softcap = false; - launch_fattn_tile_f16_64_128(ctx, dst); - } else { - constexpr bool use_logit_softcap = true; - launch_fattn_tile_f16_64_128(ctx, dst); - } -} diff --git a/ggml/src/ggml-cuda/fattn-tile-f16.cuh b/ggml/src/ggml-cuda/fattn-tile-f16.cuh deleted file mode 100644 index ffc58784..00000000 --- a/ggml/src/ggml-cuda/fattn-tile-f16.cuh +++ /dev/null @@ -1,3 +0,0 @@ -#include "common.cuh" - -void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cu b/ggml/src/ggml-cuda/fattn-tile-f32.cu deleted file mode 100644 index b96a9ef9..00000000 --- a/ggml/src/ggml-cuda/fattn-tile-f32.cu +++ /dev/null @@ -1,379 +0,0 @@ -#include "common.cuh" -#include "fattn-common.cuh" -#include "fattn-tile-f32.cuh" - -#define FATTN_KQ_STRIDE_TILE_F32 32 - -template // D == head size -#if !defined(GGML_USE_HIP) -__launch_bounds__(nwarps*WARP_SIZE, 2) -#endif // !defined(GGML_USE_HIP) -static __global__ void flash_attn_tile_ext_f32( - const char * __restrict__ Q, - const char * __restrict__ K, - const char * __restrict__ V, - const char * __restrict__ mask, - const char * __restrict__ sinks, - const int * __restrict__ KV_max, - float * __restrict__ dst, - float2 * __restrict__ dst_meta, - const float scale, - const float max_bias, - const float m0, - const float m1, - const uint32_t n_head_log2, - const float logit_softcap, - const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, - const int32_t nb01, const int32_t nb02, const int32_t nb03, - const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, - const int32_t nb11, const int32_t nb12, const int64_t nb13, - const int32_t nb21, const int32_t nb22, const int64_t nb23, - const int32_t ne31, const int32_t ne32, const int32_t ne33, - const int32_t nb31, const int32_t nb32, const int64_t nb33) { -#ifdef FLASH_ATTN_AVAILABLE - - // Skip unused kernel variants for faster compilation: -#ifdef FP16_MMA_AVAILABLE - NO_DEVICE_CODE; - return; -#endif // FP16_MMA_AVAILABLE - if (use_logit_softcap && !(D == 128 || D == 256)) { - GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, - max_bias, m0, m1, n_head_log2, logit_softcap, - ne00, ne01, ne02, ne03, - nb01, nb02, nb03, - ne10, ne11, ne12, ne13, - nb11, nb12, nb13, - nb21, nb22, nb23, - ne31, ne32, ne33, - nb31, nb32, nb33); - NO_DEVICE_CODE; - return; - } - - // In this kernel Q, K, V are matrices while i, j, k are matrix indices. - - const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on. - - const int sequence = blockIdx.z / ne02; - const int head = blockIdx.z - sequence*ne02; - const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0); - const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio)); - const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape - const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0); - const float * sinksf = (const float *) (sinks); - - const int stride_KV2 = nb11 / sizeof(half2); - - const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1); - - static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64."); - - __shared__ float KQ[ncols*FATTN_KQ_STRIDE_TILE_F32]; - - __shared__ float KV_tmp[FATTN_KQ_STRIDE_TILE_F32][D + 1]; // Pad D to avoid memory bank conflicts. - float2 * KV_tmp2 = (float2 *) KV_tmp; - - float kqmax[ncols/nwarps]; -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - kqmax[j0/nwarps] = -FLT_MAX/2.0f; - } - float kqsum[ncols/nwarps] = {0.0f}; - - float2 VKQ[ncols/nwarps][(D/2)/WARP_SIZE] = {{{0.0f, 0.0f}}}; - - // Convert Q to half2 and store in registers: - __shared__ float Q_f[ncols][D]; -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j = j0 + threadIdx.y; - -#pragma unroll - for (int i0 = 0; i0 < D; i0 += 2*WARP_SIZE) { - float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i0/2 + threadIdx.x] : make_float2(0.0f, 0.0f); - Q_f[j][i0 + 0*WARP_SIZE + threadIdx.x] = tmp.x * scale; - Q_f[j][i0 + 1*WARP_SIZE + threadIdx.x] = tmp.y * scale; - } - } - - __syncthreads(); - - const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11; - for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F32; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F32) { - // Calculate KQ tile and keep track of new maximum KQ values: - - float kqmax_new[ncols/nwarps]; -#pragma unroll - for (int j = 0; j < ncols/nwarps; ++j) { - kqmax_new[j] = kqmax[j]; - } - -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F32; i_KQ_0 += nwarps) { - const int i_KQ = i_KQ_0 + threadIdx.y; - -#pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 2*WARP_SIZE) { - const half2 tmp = K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + threadIdx.x]; - KV_tmp[i_KQ][k_KQ_0 + 0*WARP_SIZE + threadIdx.x] = __low2float(tmp); - KV_tmp[i_KQ][k_KQ_0 + 1*WARP_SIZE + threadIdx.x] = __high2float(tmp); - } - } - - __syncthreads(); - - float sum[FATTN_KQ_STRIDE_TILE_F32/WARP_SIZE][ncols/nwarps] = {{0.0f}}; - -#pragma unroll - for (int k_KQ = 0; k_KQ < D; ++k_KQ) { - float K_k[FATTN_KQ_STRIDE_TILE_F32/WARP_SIZE]; - float Q_k[ncols/nwarps]; - -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F32; i_KQ_0 += WARP_SIZE) { - const int i_KQ = i_KQ_0 + threadIdx.x; - - K_k[i_KQ_0/WARP_SIZE] = KV_tmp[i_KQ][k_KQ]; - } -#pragma unroll - for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) { - const int j_KQ = j_KQ_0 + threadIdx.y; - - Q_k[j_KQ_0/nwarps] = Q_f[j_KQ][k_KQ]; - } - -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F32; i_KQ_0 += WARP_SIZE) { -#pragma unroll - for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) { - sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += K_k[i_KQ_0/WARP_SIZE] * Q_k[j_KQ_0/nwarps]; - } - } - } - -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F32; i_KQ_0 += WARP_SIZE) { - const int i_KQ = i_KQ_0 + threadIdx.x; - -#pragma unroll - for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) { - const int j_KQ = j_KQ_0 + threadIdx.y; - - if (use_logit_softcap) { - sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] = logit_softcap * tanhf(sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]); - } - - sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += mask ? slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f; - - kqmax_new[j_KQ_0/nwarps] = fmaxf(kqmax_new[j_KQ_0/nwarps], sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]); - - KQ[j_KQ*FATTN_KQ_STRIDE_TILE_F32 + i_KQ] = sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]; - } - } - - __syncthreads(); - -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j = j0 + threadIdx.y; - - kqmax_new[j0/nwarps] = warp_reduce_max(kqmax_new[j0/nwarps]); - const float KQ_max_scale = expf(kqmax[j0/nwarps] - kqmax_new[j0/nwarps]); - kqmax[j0/nwarps] = kqmax_new[j0/nwarps]; - - float kqsum_add = 0.0f; -#pragma unroll - for (int i0 = 0; i0 < FATTN_KQ_STRIDE_TILE_F32; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - const float diff = KQ[j*FATTN_KQ_STRIDE_TILE_F32 + i] - kqmax[j0/nwarps]; - const float val = expf(diff); - kqsum_add += val; - KQ[j*FATTN_KQ_STRIDE_TILE_F32 + i] = val; - } - kqsum[j0/nwarps] = kqsum[j0/nwarps]*KQ_max_scale + kqsum_add; - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { - VKQ[j0/nwarps][i0/WARP_SIZE].x *= KQ_max_scale; - VKQ[j0/nwarps][i0/WARP_SIZE].y *= KQ_max_scale; - } - } - - __syncthreads(); - -#pragma unroll - for (int k0 = 0; k0 < FATTN_KQ_STRIDE_TILE_F32; k0 += nwarps) { - const int k = k0 + threadIdx.y; - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - const half2 tmp = V_h2[int64_t(k_VKQ_0 + k)*stride_KV2 + i]; - KV_tmp2[k*(D/2) + i].x = __low2float(tmp); - KV_tmp2[k*(D/2) + i].y = __high2float(tmp); - } - } - - __syncthreads(); - -#pragma unroll - for (int k = 0; k < FATTN_KQ_STRIDE_TILE_F32; ++k) { - float2 V_k[(D/2)/WARP_SIZE]; - float KQ_k[ncols/nwarps]; - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { - const int i = i0 + threadIdx.x; - - V_k[i0/WARP_SIZE] = KV_tmp2[k*(D/2) + i]; - } -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - const int j = j0 + threadIdx.y; - - KQ_k[j0/nwarps] = KQ[j*FATTN_KQ_STRIDE_TILE_F32 + k]; - } - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - VKQ[j0/nwarps][i0/WARP_SIZE].x += V_k[i0/WARP_SIZE].x*KQ_k[j0/nwarps]; - VKQ[j0/nwarps][i0/WARP_SIZE].y += V_k[i0/WARP_SIZE].y*KQ_k[j0/nwarps]; - } - } - } - - __syncthreads(); - } - - - //Attention sink: adjust running max and sum once per head - if (sinksf && blockIdx.y == 0) { - const float sink = sinksf[head]; - -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - float kqmax_new_j = fmaxf(kqmax[j0/nwarps], sink); - kqmax_new_j = warp_reduce_max(kqmax_new_j); - - const float KQ_max_scale = expf(kqmax[j0/nwarps] - kqmax_new_j); - kqmax[j0/nwarps] = kqmax_new_j; - - const float val = expf(sink - kqmax[j0/nwarps]); - kqsum[j0/nwarps] = kqsum[j0/nwarps] * KQ_max_scale; - if (threadIdx.x == 0) { - kqsum[j0/nwarps] += val; - } - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { - VKQ[j0/nwarps][i0/WARP_SIZE].x *= KQ_max_scale; - VKQ[j0/nwarps][i0/WARP_SIZE].y *= KQ_max_scale; - } - } - } - - float2 * dst2 = (float2 *) dst; - -#pragma unroll - for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) { - const int j_VKQ = j_VKQ_0 + threadIdx.y; - - if (ic0 + j_VKQ >= ne01) { - return; - } - - float kqsum_j = kqsum[j_VKQ_0/nwarps]; - kqsum_j = warp_reduce_sum(kqsum_j); - - const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y; - -#pragma unroll - for (int i00 = 0; i00 < D/2; i00 += WARP_SIZE) { - const int i0 = i00 + threadIdx.x; - - float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/WARP_SIZE]; - if (gridDim.y == 1) { - dst_val.x /= kqsum_j; - dst_val.y /= kqsum_j; - } - dst2[j_dst_unrolled*(D/2) + i0] = dst_val; - } - - if (gridDim.y != 1 && threadIdx.x == 0) { - dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j); - } - } -#else - GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, - max_bias, m0, m1, n_head_log2, logit_softcap, - ne00, ne01, ne02, ne03, - nb01, nb02, nb03, - ne10, ne11, ne12, ne13, - nb11, nb12, nb13, - nb21, nb22, nb23, - ne31, ne32, ne33, - nb31, nb32, nb33); - NO_DEVICE_CODE; -#endif // FLASH_ATTN_AVAILABLE -} - -template -void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * Q = dst->src[0]; - switch (Q->ne[0]) { - case 64: { - constexpr int D = 64; - constexpr int nwarps = 8; - constexpr size_t nbytes_shared = 0; - fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32; - launch_fattn - (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F32, true, true, false); - } break; - case 128: { - constexpr int D = 128; - constexpr int nwarps = 8; - constexpr size_t nbytes_shared = 0; - fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32; - launch_fattn - (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F32, true, true, false); - } break; - default: { - GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128."); - } break; - } -} - -void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * KQV = dst; - const ggml_tensor * Q = dst->src[0]; - - float logit_softcap; - memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); - - if (Q->ne[1] <= 16) { - constexpr int cols_per_block = 16; - if (logit_softcap == 0.0f) { - constexpr bool use_logit_softcap = false; - launch_fattn_tile_f32_64_128(ctx, dst); - } else { - constexpr bool use_logit_softcap = true; - launch_fattn_tile_f32_64_128(ctx, dst); - } - return; - } - - constexpr int cols_per_block = 32; - if (logit_softcap == 0.0f) { - constexpr bool use_logit_softcap = false; - launch_fattn_tile_f32_64_128(ctx, dst); - } else { - constexpr bool use_logit_softcap = true; - launch_fattn_tile_f32_64_128(ctx, dst); - } -} diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cuh b/ggml/src/ggml-cuda/fattn-tile-f32.cuh deleted file mode 100644 index b1c546c8..00000000 --- a/ggml/src/ggml-cuda/fattn-tile-f32.cuh +++ /dev/null @@ -1,3 +0,0 @@ -#include "common.cuh" - -void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/fattn-tile.cu b/ggml/src/ggml-cuda/fattn-tile.cu new file mode 100644 index 00000000..fb2163ac --- /dev/null +++ b/ggml/src/ggml-cuda/fattn-tile.cu @@ -0,0 +1,596 @@ +#include "common.cuh" +#include "fattn-common.cuh" +#include "fattn-tile.cuh" + +#define FATTN_TILE_NTHREADS 256 + +static int fattn_tile_get_kq_stride_host(const int D, const int ncols, const int cc, const int warp_size) { + if (GGML_CUDA_CC_IS_AMD(cc)) { + switch (D) { + case 64: + return ncols <= 16 ? 32 : 64; + case 128: + return ncols <= 16 ? 64 : warp_size; + case 256: + return 64; + default: + GGML_ABORT("fatal error"); + return -1; + } + } + if (fast_fp16_available(cc)) { + switch (D) { + case 64: + case 128: + return 128; + case 256: + return ncols <= 16 ? 128 : 64; + default: + GGML_ABORT("fatal error"); + return -1; + } + } + switch (D) { + case 64: + return ncols <= 16 ? 128 : 64; + case 128: + return ncols <= 16 ? 64 : 32; + case 256: + return 32; + default: + GGML_ABORT("fatal error"); + return -1; + } +} + +static constexpr __device__ int fattn_tile_get_kq_stride_device(int D, int ncols, int warp_size) { +#ifdef GGML_USE_HIP + switch (D) { + case 64: + return ncols <= 16 ? 32 : 64; + case 128: + return ncols <= 16 ? 64 : warp_size; + case 256: + return 64; + default: + return -1; + } +#else +#ifdef FAST_FP16_AVAILABLE + switch (D) { + case 64: + case 128: + return 128; + case 256: + return ncols <= 16 ? 128 : 64; + default: + return -1; + } +#else + switch (D) { + case 64: + return ncols <= 16 ? 128 : 64; + case 128: + return ncols <= 16 ? 64 : 32; + case 256: + return 32; + default: + return -1; + } +#endif // FAST_FP16_AVAILABLE +#endif // GGML_USE_HIP + GGML_UNUSED_VARS(ncols, warp_size); +} + +static constexpr __device__ int fattn_tile_get_kq_nbatch_device(int D, int ncols, int warp_size) { +#ifdef GGML_USE_HIP + switch (D) { + case 64: + return 64; + case 128: + return ncols <= 16 ? 2*warp_size : 128; + case 256: + return ncols <= 16 ? 128 : 2*warp_size; + default: + return -1; + } +#else +#ifdef FAST_FP16_AVAILABLE + switch (D) { + case 64: + return 64; + case 128: + return ncols <= 16 ? 128 : 64; + case 256: + return ncols <= 16 ? 64 : 128; + default: + return -1; + } +#else + switch (D) { + case 64: + return 64; + case 128: + return 128; + case 256: + return ncols <= 16 ? 128 : 64; + default: + return -1; + } +#endif // FAST_FP16_AVAILABLE +#endif // GGML_USE_HIP + GGML_UNUSED_VARS(ncols, warp_size); +} + +template // D == head size +#ifdef GGML_USE_HIP +__launch_bounds__(FATTN_TILE_NTHREADS, 1) +#else +__launch_bounds__(FATTN_TILE_NTHREADS, 2) +#endif // GGML_USE_HIP +static __global__ void flash_attn_tile( + const char * __restrict__ Q, + const char * __restrict__ K, + const char * __restrict__ V, + const char * __restrict__ mask, + const char * __restrict__ sinks, + const int * __restrict__ KV_max, + float * __restrict__ dst, + float2 * __restrict__ dst_meta, + const float scale, + const float max_bias, + const float m0, + const float m1, + const uint32_t n_head_log2, + const float logit_softcap, + const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, + const int32_t nb01, const int32_t nb02, const int32_t nb03, + const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, + const int32_t nb11, const int32_t nb12, const int64_t nb13, + const int32_t nb21, const int32_t nb22, const int64_t nb23, + const int32_t ne31, const int32_t ne32, const int32_t ne33, + const int32_t nb31, const int32_t nb32, const int64_t nb33) { +#ifdef FLASH_ATTN_AVAILABLE + + // Skip unused kernel variants for faster compilation: +#ifdef FP16_MMA_AVAILABLE + NO_DEVICE_CODE; + return; +#endif // FP16_MMA_AVAILABLE + + if (use_logit_softcap && !(D == 128 || D == 256)) { + GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + max_bias, m0, m1, n_head_log2, logit_softcap, + ne00, ne01, ne02, ne03, + nb01, nb02, nb03, + ne10, ne11, ne12, ne13, + nb11, nb12, nb13, + nb21, nb22, nb23, + ne31, ne32, ne33, + nb31, nb32, nb33); + NO_DEVICE_CODE; + return; + } + + constexpr int warp_size = 32; + constexpr int nwarps = FATTN_TILE_NTHREADS / warp_size; + constexpr int kq_stride = fattn_tile_get_kq_stride_device(D, ncols, warp_size); + static_assert(kq_stride % warp_size == 0, "kq_stride not divisable by warp_size."); + constexpr int kq_nbatch = fattn_tile_get_kq_nbatch_device(D, ncols, warp_size); + static_assert(kq_nbatch % (2*warp_size) == 0, "bad kq_nbatch"); + + // In this kernel Q, K, V are matrices while i, j, k are matrix indices. + + const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on. + + const int sequence = blockIdx.z / ne02; + const int head = blockIdx.z - sequence*ne02; + const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. + const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0); + const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio)); + const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape + const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0); + const float * sinksf = (const float *) (sinks); + + const int stride_KV2 = nb11 / sizeof(half2); + + const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1); + + __shared__ float KQ[ncols][kq_stride]; +#ifdef FAST_FP16_AVAILABLE + __shared__ half2 Q_tmp[ncols][D/2]; + __shared__ half2 KV_tmp_h2[kq_stride * (kq_nbatch/2 + 1)]; // Padded to avoid memory bank conflicts. + half2 VKQ[ncols/nwarps][D/(2*warp_size)] = {{{0.0f, 0.0f}}}; +#else + __shared__ float Q_tmp[ncols][D]; + __shared__ float KV_tmp_f[kq_stride * (kq_nbatch + 1)]; // Padded to avoid memory bank conflicts. + float2 * KV_tmp_f2 = (float2 *) KV_tmp_f; + float2 VKQ[ncols/nwarps][D/(2*warp_size)] = {{{0.0f, 0.0f}}}; +#endif // FAST_FP16_AVAILABLE + + + float kqmax[ncols/nwarps]; +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + kqmax[j0/nwarps] = -FLT_MAX/2.0f; + } + float kqsum[ncols/nwarps] = {0.0f}; + +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += warp_size) { + const float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i0 + threadIdx.x] : make_float2(0.0f, 0.0f); +#ifdef FAST_FP16_AVAILABLE + Q_tmp[j][i0 + threadIdx.x] = make_half2(tmp.x * scale, tmp.y * scale); +#else + Q_tmp[j][2*i0 + threadIdx.x] = tmp.x * scale; + Q_tmp[j][2*i0 + warp_size + threadIdx.x] = tmp.y * scale; +#endif // FAST_FP16_AVAILABLE + } + } + + __syncthreads(); + + const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11; + for (int k_VKQ_0 = blockIdx.y*kq_stride; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*kq_stride) { + // Calculate KQ tile and keep track of new maximum KQ values: + + float kqmax_new[ncols/nwarps]; +#pragma unroll + for (int j = 0; j < ncols/nwarps; ++j) { + kqmax_new[j] = kqmax[j]; + } + + float sum[kq_stride/warp_size][ncols/nwarps] = {{0.0f}}; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += kq_nbatch) { +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += nwarps) { + const int i_KQ = i_KQ_0 + threadIdx.y; + +#pragma unroll + for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += warp_size) { + const half2 tmp_h2 = K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + k_KQ_1 + threadIdx.x]; +#ifdef FAST_FP16_AVAILABLE + KV_tmp_h2[i_KQ*(kq_nbatch/2 + 1) + k_KQ_1 + threadIdx.x] = tmp_h2; +#else + const float2 tmp_f2 = __half22float2(tmp_h2); + KV_tmp_f[i_KQ*(kq_nbatch + 1) + 2*k_KQ_1 + threadIdx.x] = tmp_f2.x; + KV_tmp_f[i_KQ*(kq_nbatch + 1) + 2*k_KQ_1 + warp_size + threadIdx.x] = tmp_f2.y; +#endif // FAST_FP16_AVAILABLE + } + } + + __syncthreads(); + +#ifdef FAST_FP16_AVAILABLE +#pragma unroll + for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; ++k_KQ_1) { + half2 K_k[kq_stride/warp_size]; + half2 Q_k[ncols/nwarps]; +#else +#pragma unroll + for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch; ++k_KQ_1) { + float K_k[kq_stride/warp_size]; + float Q_k[ncols/nwarps]; +#endif // FAST_FP16_AVAILABLE + +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) { + const int i_KQ = i_KQ_0 + threadIdx.x; + +#ifdef FAST_FP16_AVAILABLE + K_k[i_KQ_0/warp_size] = KV_tmp_h2[i_KQ*(kq_nbatch/2 + 1) + k_KQ_1]; +#else + K_k[i_KQ_0/warp_size] = KV_tmp_f [i_KQ*(kq_nbatch + 1) + k_KQ_1]; +#endif // FAST_FP16_AVAILABLE + } +#pragma unroll + for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) { + const int j_KQ = j_KQ_0 + threadIdx.y; + +#ifdef FAST_FP16_AVAILABLE + Q_k[j_KQ_0/nwarps] = Q_tmp[j_KQ][k_KQ_0/2 + k_KQ_1]; +#else + Q_k[j_KQ_0/nwarps] = Q_tmp[j_KQ][k_KQ_0 + k_KQ_1]; +#endif // FAST_FP16_AVAILABLE + } + +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) { +#pragma unroll + for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) { +#ifdef FAST_FP16_AVAILABLE + const float2 tmp = __half22float2(K_k[i_KQ_0/warp_size] * Q_k[j_KQ_0/nwarps]); + sum[i_KQ_0/warp_size][j_KQ_0/nwarps] += tmp.x + tmp.y; +#else + sum[i_KQ_0/warp_size][j_KQ_0/nwarps] += K_k[i_KQ_0/warp_size] * Q_k[j_KQ_0/nwarps]; +#endif // FAST_FP16_AVAILABLE + } + } + } + + if (k_KQ_0 + kq_nbatch < D) { + __syncthreads(); // Sync not needed on last iteration. + } + } + +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) { + const int i_KQ = i_KQ_0 + threadIdx.x; + +#pragma unroll + for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) { + const int j_KQ = j_KQ_0 + threadIdx.y; + + if (use_logit_softcap) { + sum[i_KQ_0/warp_size][j_KQ_0/nwarps] = logit_softcap * tanhf(sum[i_KQ_0/warp_size][j_KQ_0/nwarps]); + } + + sum[i_KQ_0/warp_size][j_KQ_0/nwarps] += mask ? slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f; + + kqmax_new[j_KQ_0/nwarps] = fmaxf(kqmax_new[j_KQ_0/nwarps], sum[i_KQ_0/warp_size][j_KQ_0/nwarps]); + + KQ[j_KQ][i_KQ] = sum[i_KQ_0/warp_size][j_KQ_0/nwarps]; + } + } + + __syncthreads(); + +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; + + kqmax_new[j0/nwarps] = warp_reduce_max(kqmax_new[j0/nwarps]); + const float KQ_max_scale = expf(kqmax[j0/nwarps] - kqmax_new[j0/nwarps]); + kqmax[j0/nwarps] = kqmax_new[j0/nwarps]; + + float kqsum_add = 0.0f; +#pragma unroll + for (int i0 = 0; i0 < kq_stride; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + const float diff = KQ[j][i] - kqmax[j0/nwarps]; + const float val = expf(diff); + kqsum_add += val; + KQ[j][i] = val; + } + kqsum[j0/nwarps] = kqsum[j0/nwarps]*KQ_max_scale + kqsum_add; + +#ifdef FAST_FP16_AVAILABLE + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale); +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += warp_size) { + VKQ[j0/nwarps][i0/warp_size] *= KQ_max_scale_h2; + } +#else +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += warp_size) { + VKQ[j0/nwarps][i0/warp_size].x *= KQ_max_scale; + VKQ[j0/nwarps][i0/warp_size].y *= KQ_max_scale; + } +#endif // FAST_FP16_AVAILABLE + } + + constexpr int V_cols_per_iter = kq_stride*kq_nbatch / D; + static_assert(kq_stride % V_cols_per_iter == 0, "bad V_cols_per_iter"); +#pragma unroll + for (int k0 = 0; k0 < kq_stride; k0 += V_cols_per_iter) { +#pragma unroll + for (int k1 = 0; k1 < V_cols_per_iter; k1 += nwarps) { + const int k_tile = k1 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += warp_size) { + const int i = i0 + threadIdx.x; + + const half2 tmp = V_h2[int64_t(k_VKQ_0 + k0 + k_tile)*stride_KV2 + i]; +#ifdef FAST_FP16_AVAILABLE + KV_tmp_h2[k_tile*(D/2) + i] = tmp; +#else + KV_tmp_f2[k_tile*(D/2) + i] = __half22float2(tmp); +#endif // FAST_FP16_AVAILABLE + } + } + + __syncthreads(); + +#pragma unroll + for (int k1 = 0; k1 < V_cols_per_iter; ++k1) { +#ifdef FAST_FP16_AVAILABLE + half2 V_k[(D/2)/warp_size]; + half2 KQ_k[ncols/nwarps]; +#else + float2 V_k[(D/2)/warp_size]; + float KQ_k[ncols/nwarps]; +#endif // FAST_FP16_AVAILABLE + +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += warp_size) { + const int i = i0 + threadIdx.x; + +#ifdef FAST_FP16_AVAILABLE + V_k[i0/warp_size] = KV_tmp_h2[k1*(D/2) + i]; +#else + V_k[i0/warp_size] = KV_tmp_f2[k1*(D/2) + i]; +#endif // FAST_FP16_AVAILABLE + } +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#ifdef FAST_FP16_AVAILABLE + const float tmp = KQ[j][k0 + k1]; + KQ_k[j0/nwarps] = make_half2(tmp, tmp); +#else + KQ_k[j0/nwarps] = KQ[j][k0 + k1]; +#endif // FAST_FP16_AVAILABLE + } + +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += warp_size) { +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { +#ifdef FAST_FP16_AVAILABLE + VKQ[j0/nwarps][i0/warp_size] += V_k[i0/warp_size] *KQ_k[j0/nwarps]; +#else + VKQ[j0/nwarps][i0/warp_size].x += V_k[i0/warp_size].x*KQ_k[j0/nwarps]; + VKQ[j0/nwarps][i0/warp_size].y += V_k[i0/warp_size].y*KQ_k[j0/nwarps]; +#endif // FAST_FP16_AVAILABLE + } + } + } + + __syncthreads(); + } + } + + + // Attention sink: adjust running max and sum once per head + if (sinksf && blockIdx.y == 0) { + const float sink = sinksf[head]; + +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + float kqmax_new_j = fmaxf(kqmax[j0/nwarps], sink); + kqmax_new_j = warp_reduce_max(kqmax_new_j); + + const float KQ_max_scale = expf(kqmax[j0/nwarps] - kqmax_new_j); + kqmax[j0/nwarps] = kqmax_new_j; + + const float val = expf(sink - kqmax[j0/nwarps]); + kqsum[j0/nwarps] = kqsum[j0/nwarps] * KQ_max_scale; + if (threadIdx.x == 0) { + kqsum[j0/nwarps] += val; + } + +#ifdef FAST_FP16_AVAILABLE + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale); +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += warp_size) { + VKQ[j0/nwarps][i0/warp_size] *= KQ_max_scale_h2; + } +#else +#pragma unroll + for (int i0 = 0; i0 < D/2; i0 += warp_size) { + VKQ[j0/nwarps][i0/warp_size].x *= KQ_max_scale; + VKQ[j0/nwarps][i0/warp_size].y *= KQ_max_scale; + } +#endif // FAST_FP16_AVAILABLE + } + } + + float2 * dst2 = (float2 *) dst; + +#pragma unroll + for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) { + const int j_VKQ = j_VKQ_0 + threadIdx.y; + + if (ic0 + j_VKQ >= ne01) { + return; + } + + float kqsum_j = kqsum[j_VKQ_0/nwarps]; + kqsum_j = warp_reduce_sum(kqsum_j); + + const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y; + +#pragma unroll + for (int i00 = 0; i00 < D/2; i00 += warp_size) { + const int i0 = i00 + threadIdx.x; + +#ifdef FAST_FP16_AVAILABLE + float2 dst_val = __half22float2(VKQ[j_VKQ_0/nwarps][i0/warp_size]); +#else + float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/warp_size]; +#endif // FAST_FP16_AVAILABLE + + if (gridDim.y == 1) { + dst_val.x /= kqsum_j; + dst_val.y /= kqsum_j; + } + dst2[j_dst_unrolled*(D/2) + i0] = dst_val; + } + + if (gridDim.y != 1 && threadIdx.x == 0) { + dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j); + } + } +#else + GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + max_bias, m0, m1, n_head_log2, logit_softcap, + ne00, ne01, ne02, ne03, + nb01, nb02, nb03, + ne10, ne11, ne12, ne13, + nb11, nb12, nb13, + nb21, nb22, nb23, + ne31, ne32, ne33, + nb31, nb32, nb33); + NO_DEVICE_CODE; +#endif // FLASH_ATTN_AVAILABLE +} + +template +static void launch_fattn_tile_switch_ncols(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * Q = dst->src[0]; + + const int id = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[id].cc; + const int warp_size = 32; + const int nwarps = FATTN_TILE_NTHREADS / warp_size; + + constexpr size_t nbytes_shared = 0; + + if (Q->ne[1] > 16) { + constexpr int cols_per_block = 32; + fattn_kernel_t fattn_kernel = flash_attn_tile; + const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size); + launch_fattn + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, kq_stride, true, true, false, warp_size); + return; + } + + constexpr int cols_per_block = 16; + fattn_kernel_t fattn_kernel = flash_attn_tile; + const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size); + launch_fattn + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, kq_stride, true, true, false, warp_size); +} + +template +static void launch_fattn_tile_switch_head_size(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * Q = dst->src[0]; + switch (Q->ne[0]) { + case 64: { + launch_fattn_tile_switch_ncols< 64, use_logit_softcap>(ctx, dst); + } break; + case 128: { + launch_fattn_tile_switch_ncols<128, use_logit_softcap>(ctx, dst); + } break; + case 256: { + launch_fattn_tile_switch_ncols<256, use_logit_softcap>(ctx, dst); + } break; + default: { + GGML_ABORT("Unsupported head size"); + } break; + } +} + +void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * KQV = dst; + + float logit_softcap; + memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); + + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + launch_fattn_tile_switch_head_size(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + launch_fattn_tile_switch_head_size(ctx, dst); + } +} diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh new file mode 100644 index 00000000..10dc22d1 --- /dev/null +++ b/ggml/src/ggml-cuda/fattn-tile.cuh @@ -0,0 +1,3 @@ +#include "common.cuh" + +void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 48834272..7626d89c 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -1,8 +1,7 @@ #include "common.cuh" #include "fattn-common.cuh" #include "fattn-mma-f16.cuh" -#include "fattn-tile-f16.cuh" -#include "fattn-tile-f32.cuh" +#include "fattn-tile.cuh" #include "fattn-vec-f16.cuh" #include "fattn-vec-f32.cuh" #include "fattn-wmma-f16.cuh" @@ -271,8 +270,7 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg // Best FlashAttention kernel for a specific GPU: enum best_fattn_kernel { BEST_FATTN_KERNEL_NONE = 0, - BEST_FATTN_KERNEL_TILE_F32 = 200, - BEST_FATTN_KERNEL_TILE_F16 = 210, + BEST_FATTN_KERNEL_TILE = 200, BEST_FATTN_KERNEL_VEC_F32 = 100, BEST_FATTN_KERNEL_VEC_F16 = 110, BEST_FATTN_KERNEL_WMMA_F16 = 300, @@ -411,10 +409,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const } // If there is no suitable kernel for tensor cores or small batch sizes, use the generic kernel for large batch sizes: - if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) { - return BEST_FATTN_KERNEL_TILE_F16; - } - return BEST_FATTN_KERNEL_TILE_F32; + return BEST_FATTN_KERNEL_TILE; } void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { @@ -422,11 +417,8 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst switch (ggml_cuda_get_best_fattn_kernel(ggml_cuda_get_device(), dst)) { case BEST_FATTN_KERNEL_NONE: GGML_ABORT("fatal error"); - case BEST_FATTN_KERNEL_TILE_F32: - ggml_cuda_flash_attn_ext_tile_f32(ctx, dst); - break; - case BEST_FATTN_KERNEL_TILE_F16: - ggml_cuda_flash_attn_ext_tile_f16(ctx, dst); + case BEST_FATTN_KERNEL_TILE: + ggml_cuda_flash_attn_ext_tile(ctx, dst); break; case BEST_FATTN_KERNEL_VEC_F32: ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);