CUDA: enable FA for FP32 KV cache (#16546)
author    Johannes Gäßler <redacted>
          Tue, 14 Oct 2025 12:22:47 +0000 (14:22 +0200)
committer GitHub <redacted>
          Tue, 14 Oct 2025 12:22:47 +0000 (14:22 +0200)
ggml/src/ggml-cuda/fattn-vec.cuh
ggml/src/ggml-cuda/fattn.cu

diff --git a/ggml/src/ggml-cuda/fattn-vec.cuh b/ggml/src/ggml-cuda/fattn-vec.cuh
index 89ab0f1638bf7e5d071656182d985e7d46927b7b..e1838fddedc6dbf8a575cf861c6b530264891f30 100644
--- a/ggml/src/ggml-cuda/fattn-vec.cuh
+++ b/ggml/src/ggml-cuda/fattn-vec.cuh
@@ -516,8 +516,8 @@ void ggml_cuda_flash_attn_ext_vec_case_impl(ggml_backend_cuda_context & ctx, ggm
     const int nthreads = ggml_cuda_fattn_vec_get_nthreads_host(cc);
     const int nwarps   = nthreads / WARP_SIZE;
     fattn_kernel_t fattn_kernel = flash_attn_ext_vec<D, cols_per_block, type_K, type_V, use_logit_softcap>;
-    constexpr bool need_f16_K = false;
-    constexpr bool need_f16_V = false;
+    const bool need_f16_K = type_K == GGML_TYPE_F16;
+    const bool need_f16_V = type_V == GGML_TYPE_F16;
     constexpr size_t nbytes_shared = 0;
     launch_fattn<D, cols_per_block, 1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
 }
@@ -526,11 +526,6 @@ template <int D, ggml_type type_K, ggml_type type_V>
 void ggml_cuda_flash_attn_ext_vec_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * KQV = dst;
     const ggml_tensor * Q   = dst->src[0];
-    const ggml_tensor * K   = dst->src[1];
-    const ggml_tensor * V   = dst->src[2];
-
-    GGML_ASSERT(K->type == type_K);
-    GGML_ASSERT(V->type == type_V);
 
     float logit_softcap;
     memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
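
A minimal sketch (not part of the commit) of why the two GGML_ASSERTs above could be dropped: an FP32 cache tensor may now be routed to the FP16 kernel instantiation, so K->type == type_K no longer always holds. The relaxed invariant that the new dispatch still guarantees could be written as below; the helper name is hypothetical, and it assumes launch_fattn converts K/V to FP16 whenever need_f16_K/need_f16_V is set and the incoming tensor is not already FP16.

    #include "ggml.h"

    // Illustrative only: after the new dispatch, either the cache type matches the
    // instantiation exactly, or the cache is FP32 and the instantiation is the FP16
    // one. In the latter case need_f16_K/need_f16_V are true and launch_fattn is
    // assumed to convert the tensor to FP16 before the kernel runs.
    static bool fattn_vec_kv_type_compatible(ggml_type cache_type, ggml_type type_KV) {
        return cache_type == type_KV ||
               (cache_type == GGML_TYPE_F32 && type_KV == GGML_TYPE_F16);
    }

With this helper, the removed asserts could have been restated as GGML_ASSERT(fattn_vec_kv_type_compatible(K->type, type_K)) plus the analogous check for V->type.
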
diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu
index fe970adaecef3a523cd6409fc4da36f3389a77c9..7dee032c291373626910bab6719dbde5c325b265 100644
--- a/ggml/src/ggml-cuda/fattn.cu
+++ b/ggml/src/ggml-cuda/fattn.cu
@@ -116,11 +116,15 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
     }
 }
 
-#define FATTN_VEC_CASE(D, type_K, type_V)                                \
-    if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) { \
-        ggml_cuda_flash_attn_ext_vec_case<D, type_K, type_V>(ctx, dst);  \
-        return;                                                          \
-    }                                                                    \
+#define FATTN_VEC_CASE(D, type_K, type_V)                                                                        \
+    {                                                                                                            \
+        const bool type_K_okay = K->type == (type_K) || (K->type == GGML_TYPE_F32 && (type_K) == GGML_TYPE_F16); \
+        const bool type_V_okay = V->type == (type_V) || (V->type == GGML_TYPE_F32 && (type_V) == GGML_TYPE_F16); \
+        if (Q->ne[0] == (D) && type_K_okay && type_V_okay) {                                                     \
+            ggml_cuda_flash_attn_ext_vec_case<D, type_K, type_V>(ctx, dst);                                      \
+            return;                                                                                              \
+        }                                                                                                        \
+    }                                                                                                            \
 
 #define FATTN_VEC_CASES_ALL_D(type_K, type_V) \
     FATTN_VEC_CASE( 64, type_K, type_V)       \
@@ -247,6 +251,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
 #endif // GGML_CUDA_FA_ALL_QUANTS
 
     switch (K->type) {
+        case GGML_TYPE_F32:
         case GGML_TYPE_F16:
             break;
         case GGML_TYPE_Q4_1:
@@ -272,7 +277,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
     // If Turing tensor cores available, use them:
     if (turing_mma_available(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40) {
         if (can_use_vector_kernel) {
-            if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) {
+            if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
                 if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) {
                     return BEST_FATTN_KERNEL_VEC;
                 }
@@ -305,7 +310,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
 
     // If there are no tensor cores available, use the generic tile kernel:
     if (can_use_vector_kernel) {
-        if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) {
+        if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
             if (Q->ne[1] == 1) {
                 if (!gqa_opt_applies) {
                     return BEST_FATTN_KERNEL_VEC;
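
For reference, a hedged restatement of the updated kernel-selection condition: once the switch over K->type admits GGML_TYPE_F32, the vector-kernel checks only have to exclude quantized caches, so FP16 and FP32 K/V both qualify. The standalone predicate below is a sketch; the function name is hypothetical and the real checks live inline in ggml_cuda_get_best_fattn_kernel.

    #include "ggml.h"

    // Illustrative only: the K/V type condition gating the VEC kernel paths after
    // this commit. Quantized caches still require the dedicated quantized
    // instantiations; F16 and (now) F32 caches pass, with F32 data handled through
    // the need_f16_K/need_f16_V mechanism shown in fattn-vec.cuh above.
    static bool fattn_vec_kv_not_quantized(const ggml_tensor * K, const ggml_tensor * V) {
        return !ggml_is_quantized(K->type) && !ggml_is_quantized(V->type);
    }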