]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
CUDA: refactor FA support/selection code (llama/15454)
authorJohannes Gäßler <redacted>
Wed, 20 Aug 2025 21:14:14 +0000 (23:14 +0200)
committerGeorgi Gerganov <redacted>
Fri, 5 Sep 2025 09:53:59 +0000 (12:53 +0300)
src/ggml-cuda/fattn-common.cuh
src/ggml-cuda/fattn.cu
src/ggml-cuda/fattn.cuh
src/ggml-cuda/ggml-cuda.cu

index d4ed938391b478f7c1560cb82ce48ff1f37e4eef..b69f57d659a266dbd626adf93640df63b672c086 100644 (file)
@@ -704,28 +704,6 @@ static __global__ void flash_attn_combine_results(
     dst[tid] = VKQ_numerator / VKQ_denominator;
 }
 
-[[noreturn]]
-static void on_no_fattn_vec_case(const int D) {
-    if (D == 64) {
-        fprintf(stderr, "Unsupported KV type combination for head_size 64.\n");
-        fprintf(stderr, "By default only f16 KV cache is supported.\n");
-        fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for V cache quantization support.\n");
-        GGML_ABORT("fatal error");
-    } else if (D == 128) {
-        fprintf(stderr, "Unsupported KV type combination for head_size 128.\n");
-        fprintf(stderr, "Supported combinations:\n");
-        fprintf(stderr, "  - K == q4_0, V == q4_0,  4.50 BPV\n");
-        fprintf(stderr, "  - K == q8_0, V == q8_0,  8.50 BPV\n");
-        fprintf(stderr, "  - K == f16,  V == f16,  16.00 BPV\n");
-        fprintf(stderr, "Compile with GGML_CUDA_FA_ALL_QUANTS for all combinations of q4_0, q4_1, q5_0, q5_1, q8_0, and f16.\n");
-        GGML_ABORT("fatal error");
-    } else {
-        fprintf(stderr, "Unsupported KV type combination for head_size %d.\n", D);
-        fprintf(stderr, "Only f16 is supported.\n");
-        GGML_ABORT("fatal error");
-    }
-}
-
 template <int DV, int ncols1, int ncols2>
 void launch_fattn(
     ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared,
index 22e90d0e7b31611da2c55e2849f91609e0ee49bc..48834272660e51ea7e87d49ff036be85c249f6fa 100644 (file)
@@ -190,7 +190,7 @@ static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, gg
     FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
 #endif // GGML_CUDA_FA_ALL_QUANTS
 
-    on_no_fattn_vec_case(Q->ne[0]);
+    GGML_ABORT("fatal error");
 }
 
 #define FATTN_VEC_F32_CASE(D, type_K, type_V)                               \
@@ -265,74 +265,184 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
     FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
 #endif // GGML_CUDA_FA_ALL_QUANTS
 
-    on_no_fattn_vec_case(Q->ne[0]);
+    GGML_ABORT("fatal error");
 }
 
-void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+// Best FlashAttention kernel for a specific GPU:
+enum best_fattn_kernel {
+    BEST_FATTN_KERNEL_NONE     =   0,
+    BEST_FATTN_KERNEL_TILE_F32 = 200,
+    BEST_FATTN_KERNEL_TILE_F16 = 210,
+    BEST_FATTN_KERNEL_VEC_F32  = 100,
+    BEST_FATTN_KERNEL_VEC_F16  = 110,
+    BEST_FATTN_KERNEL_WMMA_F16 = 300,
+    BEST_FATTN_KERNEL_MMA_F16  = 400,
+};
+
+static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const ggml_tensor * dst) {
+#ifndef FLASH_ATTN_AVAILABLE
+    GGML_UNUSED(device); GGML_UNUSED(dst);
+    return BEST_FATTN_KERNEL_NONE;
+#endif// FLASH_ATTN_AVAILABLE
+
     const ggml_tensor * KQV   = dst;
     const ggml_tensor * Q     = dst->src[0];
     const ggml_tensor * K     = dst->src[1];
     const ggml_tensor * V     = dst->src[2];
     const ggml_tensor * mask  = dst->src[3];
 
-    ggml_cuda_set_device(ctx.device);
-    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
-    const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
+    const int gqa_ratio = Q->ne[2] / K->ne[2];
+    GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
+
+    const int cc = ggml_cuda_info().devices[device].cc;
+    const int warp_size = ggml_cuda_info().devices[device].warp_size;
     const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
 
-#if defined(GGML_HIP_ROCWMMA_FATTN)
-    if (GGML_CUDA_CC_IS_AMD(cc) && fp16_mma_available(cc)) {
-        ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
-        return;
+    switch (K->ne[0]) {
+        case  64:
+        case 128:
+        case 256:
+            if (V->ne[0] != K->ne[0]) {
+                return BEST_FATTN_KERNEL_NONE;
+            }
+            break;
+        case  80:
+        case  96:
+        case 112:
+            if (V->ne[0] != K->ne[0]) {
+                return BEST_FATTN_KERNEL_NONE;
+            }
+            if (!fp16_mma_available(cc) && !turing_mma_available(cc)) {
+                return BEST_FATTN_KERNEL_NONE;
+            }
+            break;
+        case 576:
+            if (V->ne[0] != 512) {
+                return BEST_FATTN_KERNEL_NONE;
+            }
+            if (!turing_mma_available(cc) || gqa_ratio % 16 != 0) {
+                return BEST_FATTN_KERNEL_NONE;
+            }
+            break;
+        default:
+            return BEST_FATTN_KERNEL_NONE;
     }
-#endif // defined(GGML_HIP_ROCWMMA_FATTN)
 
-    if (!fast_fp16_available(cc)) {
-        if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
-            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
-        } else {
-            ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
-        }
-        return;
+#ifndef GGML_CUDA_FA_ALL_QUANTS
+    if (K->type != V->type) {
+        return BEST_FATTN_KERNEL_NONE;
     }
+#endif // GGML_CUDA_FA_ALL_QUANTS
 
-    if (!fp16_mma_available(cc)) {
-        if (prec == GGML_PREC_DEFAULT) {
-            if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
-                ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
-            } else {
-                ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
+    switch (K->type) {
+        case GGML_TYPE_F16:
+            break;
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+#ifndef GGML_CUDA_FA_ALL_QUANTS
+            return BEST_FATTN_KERNEL_NONE;
+#endif // GGML_CUDA_FA_ALL_QUANTS
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q8_0:
+#ifdef GGML_CUDA_FA_ALL_QUANTS
+            if (K->ne[0] != 128 && K->ne[0] != 64) {
+                return BEST_FATTN_KERNEL_NONE;
             }
-        } else {
-            if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
-                ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
-            } else {
-                ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
+#else
+            if (K->ne[0] != 128) {
+                return BEST_FATTN_KERNEL_NONE;
             }
-        }
-        return;
+#endif // GGML_CUDA_FA_ALL_QUANTS
+            break;
+        default:
+            return BEST_FATTN_KERNEL_NONE;
+    }
+
+    switch (V->type) {
+        case GGML_TYPE_F16:
+            break;
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q8_0:
+            if (K->ne[0] != 128) {
+                return BEST_FATTN_KERNEL_NONE;
+            }
+            break;
+        default:
+            return BEST_FATTN_KERNEL_NONE;
+    }
+
+    if (mask && mask->ne[2] != 1) {
+        return BEST_FATTN_KERNEL_NONE;
     }
 
-    const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations
-    const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
-    const bool mma_faster_for_rtx4000 = Q->ne[3] > 1 || (Q->ne[2] > 4*K->ne[2] && K->ne[1] >= 8192);
-    const bool mma_faster_for_bs1 = turing_mma_available(cc) && gqa_opt_applies && !mma_needs_data_conversion &&
-        (cc < GGML_CUDA_CC_ADA_LOVELACE || mma_faster_for_rtx4000);
     const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*warp_size) == 0;
-    if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
-        if (prec == GGML_PREC_DEFAULT) {
-            ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
-        } else {
-            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
+
+    // If Turing tensor cores available, use them except for some cases with batch size 1:
+    if (turing_mma_available(cc)) {
+        const bool gqa_opt_applies = gqa_ratio % 2 == 0 && mask; // The mma-based kernels have GQA-specific optimizations
+        const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
+        const bool mma_faster_for_rtx4000 = Q->ne[3] > 1 || (gqa_ratio > 4 && K->ne[1] >= 8192);
+        const bool mma_faster_for_bs1 = gqa_opt_applies && !mma_needs_data_conversion &&
+            (cc < GGML_CUDA_CC_ADA_LOVELACE || mma_faster_for_rtx4000);
+        if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
+            if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
+                return BEST_FATTN_KERNEL_VEC_F16;
+            }
+            return BEST_FATTN_KERNEL_VEC_F32;
         }
-        return;
+        return BEST_FATTN_KERNEL_MMA_F16;
     }
 
-    // The MMA implementation needs Turing or newer, use the old WMMA code for Volta:
-    if (fp16_mma_available(cc) && !turing_mma_available(cc)) {
-        ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
-        return;
+    // Use kernels specializes for small batch sizes if possible:
+    if (Q->ne[1] <= 8 && can_use_vector_kernel) {
+        if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
+            return BEST_FATTN_KERNEL_VEC_F16;
+        }
+        return BEST_FATTN_KERNEL_VEC_F32;
+    }
+
+    // For large batch sizes, use the WMMA kernel if possible:
+    if (fp16_mma_available(cc)) {
+        return BEST_FATTN_KERNEL_WMMA_F16;
+    }
+
+    // If there is no suitable kernel for tensor cores or small batch sizes, use the generic kernel for large batch sizes:
+    if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
+        return BEST_FATTN_KERNEL_TILE_F16;
     }
+    return BEST_FATTN_KERNEL_TILE_F32;
+}
+
+void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_set_device(ctx.device);
+    switch (ggml_cuda_get_best_fattn_kernel(ggml_cuda_get_device(), dst)) {
+        case BEST_FATTN_KERNEL_NONE:
+            GGML_ABORT("fatal error");
+        case BEST_FATTN_KERNEL_TILE_F32:
+            ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
+            break;
+        case BEST_FATTN_KERNEL_TILE_F16:
+            ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
+            break;
+        case BEST_FATTN_KERNEL_VEC_F32:
+            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
+            break;
+        case BEST_FATTN_KERNEL_VEC_F16:
+            ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
+            break;
+        case BEST_FATTN_KERNEL_WMMA_F16:
+            ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
+            break;
+        case BEST_FATTN_KERNEL_MMA_F16:
+            ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);
+            break;
+    }
+}
 
-    ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);
+bool ggml_cuda_flash_attn_ext_supported(int device, const ggml_tensor * dst) {
+    return ggml_cuda_get_best_fattn_kernel(device, dst) != BEST_FATTN_KERNEL_NONE;
 }
index ad3ca7a8d8e4d657385d373795c2e38f60d16450..78705d59951c1bf34f0a5fec3e85daca6a228c26 100644 (file)
@@ -1,3 +1,5 @@
 #include "common.cuh"
 
 void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+bool ggml_cuda_flash_attn_ext_supported(int device, const ggml_tensor * dst);
index 8b706752bc7a0a4f2fc3bd268da51f53a5c40f34..1440f2f2e94755bdf417e43284c8d437f5b6fb47 100644 (file)
@@ -3499,44 +3499,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_GATED_LINEAR_ATTN:
         case GGML_OP_RWKV_WKV7:
             return true;
-        case GGML_OP_FLASH_ATTN_EXT: {
-#ifndef FLASH_ATTN_AVAILABLE
-            return false;
-#endif // FLASH_ATTN_AVAILABLE
-            if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
-                const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
-                if (!turing_mma_available(cc)) {
-                    return false;
-                }
-                const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2];
-                return op->src[1]->ne[0] == 576 && op->src[2]->ne[0] == 512 && op->src[3] && gqa_ratio % 16 == 0;
-            }
-            // TODO: more general-purpose attention sink support [TAG_ATTN_SINKS]
-            if (op->src[4] && !fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc)
-                    && op->src[0]->ne[0] != 64 && op->src[0]->ne[0] != 128) {
-                return false;
-            }
-            if (op->src[0]->ne[0] == 192) {
-                return false;
-            }
-            if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
-                return false;
-            }
-            if (op->src[0]->ne[0] ==  64 && op->src[1]->type == GGML_TYPE_F16) {
-                return true;
-            }
-            if (op->src[0]->ne[0] == 128) {
-                return true;
-            }
-            if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) {
-                return true;
-            }
-            if (op->src[3] && op->src[3]->ne[2] != 1) {
-                return false;
-            }
-            return fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc) &&
-                op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
-        }
+        case GGML_OP_FLASH_ATTN_EXT:
+            return ggml_cuda_flash_attn_ext_supported(dev_ctx->device, op);
         case GGML_OP_CROSS_ENTROPY_LOSS:
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
         case GGML_OP_OPT_STEP_ADAMW: