From: Johannes Gäßler
Date: Wed, 20 Aug 2025 21:14:14 +0000 (+0200)
Subject: CUDA: refactor FA support/selection code (llama/15454)
X-Git-Tag: v0.9.1~181
X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=e6d142fd85aa7f75c7433e0f7214418ab0818d0c;p=pkg%2Fggml%2Fsources%2Fggml

CUDA: refactor FA support/selection code (llama/15454)
---

diff --git a/src/ggml-cuda/fattn-common.cuh b/src/ggml-cuda/fattn-common.cuh
index d4ed9383..b69f57d6 100644
--- a/src/ggml-cuda/fattn-common.cuh
+++ b/src/ggml-cuda/fattn-common.cuh
@@ -704,28 +704,6 @@ static __global__ void flash_attn_combine_results(
     dst[tid] = VKQ_numerator / VKQ_denominator;
 }

-[[noreturn]]
-static void on_no_fattn_vec_case(const int D) {
-    if (D == 64) {
-        fprintf(stderr, "Unsupported KV type combination for head_size 64.\n");
-        fprintf(stderr, "By default only f16 KV cache is supported.\n");
-        fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for V cache quantization support.\n");
-        GGML_ABORT("fatal error");
-    } else if (D == 128) {
-        fprintf(stderr, "Unsupported KV type combination for head_size 128.\n");
-        fprintf(stderr, "Supported combinations:\n");
-        fprintf(stderr, "  - K == q4_0, V == q4_0,  4.50 BPV\n");
-        fprintf(stderr, "  - K == q8_0, V == q8_0,  8.50 BPV\n");
-        fprintf(stderr, "  - K == f16,  V == f16,  16.00 BPV\n");
-        fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, q5_0, q5_1, q8_0, and f16.\n");
-        GGML_ABORT("fatal error");
-    } else {
-        fprintf(stderr, "Unsupported KV type combination for head_size %d.\n", D);
-        fprintf(stderr, "Only f16 is supported.\n");
-        GGML_ABORT("fatal error");
-    }
-}
-
 template <int DV, int ncols1, int ncols2>
 void launch_fattn(
         ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
diff --git a/src/ggml-cuda/fattn.cu b/src/ggml-cuda/fattn.cu
index 22e90d0e..48834272 100644
--- a/src/ggml-cuda/fattn.cu
+++ b/src/ggml-cuda/fattn.cu
@@ -190,7 +190,7 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
 #endif // GGML_CUDA_FA_ALL_QUANTS

-    on_no_fattn_vec_case(Q->ne[0]);
+    GGML_ABORT("fatal error");
 }

 #define FATTN_VEC_F32_CASE(D, type_K, type_V) \
@@ -265,74 +265,184 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
 #endif // GGML_CUDA_FA_ALL_QUANTS

-    on_no_fattn_vec_case(Q->ne[0]);
+    GGML_ABORT("fatal error");
 }

-void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+// Best FlashAttention kernel for a specific GPU:
+enum best_fattn_kernel {
+    BEST_FATTN_KERNEL_NONE     =   0,
+    BEST_FATTN_KERNEL_TILE_F32 = 200,
+    BEST_FATTN_KERNEL_TILE_F16 = 210,
+    BEST_FATTN_KERNEL_VEC_F32  = 100,
+    BEST_FATTN_KERNEL_VEC_F16  = 110,
+    BEST_FATTN_KERNEL_WMMA_F16 = 300,
+    BEST_FATTN_KERNEL_MMA_F16  = 400,
+};
+
+static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const ggml_tensor * dst) {
+#ifndef FLASH_ATTN_AVAILABLE
+    GGML_UNUSED(device); GGML_UNUSED(dst);
+    return BEST_FATTN_KERNEL_NONE;
+#endif // FLASH_ATTN_AVAILABLE
+
     const ggml_tensor * KQV  = dst;
     const ggml_tensor * Q    = dst->src[0];
     const ggml_tensor * K    = dst->src[1];
     const ggml_tensor * V    = dst->src[2];
     const ggml_tensor * mask = dst->src[3];

-    ggml_cuda_set_device(ctx.device);
-    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
-    const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
+    const int gqa_ratio = Q->ne[2] / K->ne[2];
+    GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
+
+    const int cc        = ggml_cuda_info().devices[device].cc;
+    const int warp_size = ggml_cuda_info().devices[device].warp_size;
     const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);

-#if defined(GGML_HIP_ROCWMMA_FATTN)
-    if (GGML_CUDA_CC_IS_AMD(cc) && fp16_mma_available(cc)) {
-        ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
-        return;
+    switch (K->ne[0]) {
+        case  64:
+        case 128:
+        case 256:
+            if (V->ne[0] != K->ne[0]) {
+                return BEST_FATTN_KERNEL_NONE;
+            }
+            break;
+        case  80:
+        case  96:
+        case 112:
+            if (V->ne[0] != K->ne[0]) {
+                return BEST_FATTN_KERNEL_NONE;
+            }
+            if (!fp16_mma_available(cc) && !turing_mma_available(cc)) {
+                return BEST_FATTN_KERNEL_NONE;
+            }
+            break;
+        case 576:
+            if (V->ne[0] != 512) {
+                return BEST_FATTN_KERNEL_NONE;
+            }
+            if (!turing_mma_available(cc) || gqa_ratio % 16 != 0) {
+                return BEST_FATTN_KERNEL_NONE;
+            }
+            break;
+        default:
+            return BEST_FATTN_KERNEL_NONE;
     }
-#endif // defined(GGML_HIP_ROCWMMA_FATTN)

-    if (!fast_fp16_available(cc)) {
-        if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
-            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
-        } else {
-            ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
-        }
-        return;
+#ifndef GGML_CUDA_FA_ALL_QUANTS
+    if (K->type != V->type) {
+        return BEST_FATTN_KERNEL_NONE;
     }
+#endif // GGML_CUDA_FA_ALL_QUANTS

-    if (!fp16_mma_available(cc)) {
-        if (prec == GGML_PREC_DEFAULT) {
-            if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
-                ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
-            } else {
-                ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
+    switch (K->type) {
+        case GGML_TYPE_F16:
+            break;
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+#ifndef GGML_CUDA_FA_ALL_QUANTS
+            return BEST_FATTN_KERNEL_NONE;
+#endif // GGML_CUDA_FA_ALL_QUANTS
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q8_0:
+#ifdef GGML_CUDA_FA_ALL_QUANTS
+            if (K->ne[0] != 128 && K->ne[0] != 64) {
+                return BEST_FATTN_KERNEL_NONE;
             }
-        } else {
-            if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
-                ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
-            } else {
-                ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
+#else
+            if (K->ne[0] != 128) {
+                return BEST_FATTN_KERNEL_NONE;
             }
-        }
-        return;
+#endif // GGML_CUDA_FA_ALL_QUANTS
+            break;
+        default:
+            return BEST_FATTN_KERNEL_NONE;
+    }
+
+    switch (V->type) {
+        case GGML_TYPE_F16:
+            break;
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q8_0:
+            if (K->ne[0] != 128) {
+                return BEST_FATTN_KERNEL_NONE;
+            }
+            break;
+        default:
+            return BEST_FATTN_KERNEL_NONE;
+    }
+
+    if (mask && mask->ne[2] != 1) {
+        return BEST_FATTN_KERNEL_NONE;
     }

-    const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations
-    const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
-    const bool mma_faster_for_rtx4000 = Q->ne[3] > 1 || (Q->ne[2] > 4*K->ne[2] && K->ne[1] >= 8192);
-    const bool mma_faster_for_bs1 = turing_mma_available(cc) && gqa_opt_applies && !mma_needs_data_conversion &&
-        (cc < GGML_CUDA_CC_ADA_LOVELACE || mma_faster_for_rtx4000);
     const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*warp_size) == 0;
-    if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
-        if (prec == GGML_PREC_DEFAULT) {
-            ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
-        } else {
-            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
+
+    // If Turing tensor cores available, use them except for some cases with batch size 1:
+    if (turing_mma_available(cc)) {
+        const bool gqa_opt_applies = gqa_ratio % 2 == 0 && mask; // The mma-based kernels have GQA-specific optimizations
+        const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
+        const bool mma_faster_for_rtx4000 = Q->ne[3] > 1 || (gqa_ratio > 4 && K->ne[1] >= 8192);
+        const bool mma_faster_for_bs1 = gqa_opt_applies && !mma_needs_data_conversion &&
+            (cc < GGML_CUDA_CC_ADA_LOVELACE || mma_faster_for_rtx4000);
+        if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
+            if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
+                return BEST_FATTN_KERNEL_VEC_F16;
+            }
+            return BEST_FATTN_KERNEL_VEC_F32;
         }
-        return;
+        return BEST_FATTN_KERNEL_MMA_F16;
     }

-    // The MMA implementation needs Turing or newer, use the old WMMA code for Volta:
-    if (fp16_mma_available(cc) && !turing_mma_available(cc)) {
-        ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
-        return;
+    // Use kernels specialized for small batch sizes if possible:
+    if (Q->ne[1] <= 8 && can_use_vector_kernel) {
+        if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
+            return BEST_FATTN_KERNEL_VEC_F16;
+        }
+        return BEST_FATTN_KERNEL_VEC_F32;
+    }
+
+    // For large batch sizes, use the WMMA kernel if possible:
+    if (fp16_mma_available(cc)) {
+        return BEST_FATTN_KERNEL_WMMA_F16;
+    }
+
+    // If there is no suitable kernel for tensor cores or small batch sizes, use the generic kernel for large batch sizes:
+    if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
+        return BEST_FATTN_KERNEL_TILE_F16;
     }
+    return BEST_FATTN_KERNEL_TILE_F32;
+}
+
+void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_set_device(ctx.device);
+    switch (ggml_cuda_get_best_fattn_kernel(ggml_cuda_get_device(), dst)) {
+        case BEST_FATTN_KERNEL_NONE:
+            GGML_ABORT("fatal error");
+        case BEST_FATTN_KERNEL_TILE_F32:
+            ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
+            break;
+        case BEST_FATTN_KERNEL_TILE_F16:
+            ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
+            break;
+        case BEST_FATTN_KERNEL_VEC_F32:
+            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
+            break;
+        case BEST_FATTN_KERNEL_VEC_F16:
+            ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
+            break;
+        case BEST_FATTN_KERNEL_WMMA_F16:
+            ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
+            break;
+        case BEST_FATTN_KERNEL_MMA_F16:
+            ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);
+            break;
+    }
+}

-    ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);
+bool ggml_cuda_flash_attn_ext_supported(int device, const ggml_tensor * dst) {
+    return ggml_cuda_get_best_fattn_kernel(device, dst) != BEST_FATTN_KERNEL_NONE;
 }
diff --git a/src/ggml-cuda/fattn.cuh b/src/ggml-cuda/fattn.cuh
index ad3ca7a8..78705d59 100644
--- a/src/ggml-cuda/fattn.cuh
+++ b/src/ggml-cuda/fattn.cuh
@@ -1,3 +1,5 @@
 #include "common.cuh"

 void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+bool ggml_cuda_flash_attn_ext_supported(int device, const ggml_tensor * dst);
diff --git a/src/ggml-cuda/ggml-cuda.cu b/src/ggml-cuda/ggml-cuda.cu
index 8b706752..1440f2f2 100644
--- a/src/ggml-cuda/ggml-cuda.cu
+++ b/src/ggml-cuda/ggml-cuda.cu
@@ -3499,44 +3499,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
         case GGML_OP_GATED_LINEAR_ATTN:
         case GGML_OP_RWKV_WKV7:
             return true;
-        case GGML_OP_FLASH_ATTN_EXT: {
-#ifndef FLASH_ATTN_AVAILABLE
-            return false;
-#endif // FLASH_ATTN_AVAILABLE
-            if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
-                const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
-                if (!turing_mma_available(cc)) {
-                    return false;
-                }
-                const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2];
-                return op->src[1]->ne[0] == 576 && op->src[2]->ne[0] == 512 && op->src[3] && gqa_ratio % 16 == 0;
-            }
-            // TODO: more general-purpose attention sink support [TAG_ATTN_SINKS]
-            if (op->src[4] && !fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc)
-                && op->src[0]->ne[0] != 64 && op->src[0]->ne[0] != 128) {
-                return false;
-            }
-            if (op->src[0]->ne[0] == 192) {
-                return false;
-            }
-            if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
-                return false;
-            }
-            if (op->src[0]->ne[0] == 64 && op->src[1]->type == GGML_TYPE_F16) {
-                return true;
-            }
-            if (op->src[0]->ne[0] == 128) {
-                return true;
-            }
-            if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) {
-                return true;
-            }
-            if (op->src[3] && op->src[3]->ne[2] != 1) {
-                return false;
-            }
-            return fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc) &&
-                op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
-        }
+        case GGML_OP_FLASH_ATTN_EXT:
+            return ggml_cuda_flash_attn_ext_supported(dev_ctx->device, op);
         case GGML_OP_CROSS_ENTROPY_LOSS:
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
         case GGML_OP_OPT_STEP_ADAMW:
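
Editor's note (not part of the patch): the structural point of the refactor is that the backend's
supports_op check and the kernel dispatch now both reduce to the same call,
ggml_cuda_get_best_fattn_kernel(), so they can never disagree. A minimal sketch of how external
code might consume the one new public symbol, ggml_cuda_flash_attn_ext_supported(); the function
pick_fattn_device() and the parameter n_devices are hypothetical names, not part of the patch:

    // sketch.cu (hypothetical): probe which CUDA device, if any, has a usable
    // FlashAttention kernel for a given FLASH_ATTN_EXT node.
    #include "ggml.h"
    #include "fattn.cuh" // declares ggml_cuda_flash_attn_ext_supported()

    static int pick_fattn_device(const ggml_tensor * node, int n_devices) {
        for (int dev = 0; dev < n_devices; ++dev) {
            // True iff ggml_cuda_get_best_fattn_kernel() returns something
            // other than BEST_FATTN_KERNEL_NONE for this device/op pair.
            if (ggml_cuda_flash_attn_ext_supported(dev, node)) {
                return dev;
            }
        }
        return -1; // no suitable FlashAttention kernel; fall back elsewhere
    }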