]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CUDA: use mma FA kernel for gqa > 4 on RTX 4000 (#15035)
authorJohannes Gäßler <redacted>
Sat, 2 Aug 2025 14:37:08 +0000 (16:37 +0200)
committerGitHub <redacted>
Sat, 2 Aug 2025 14:37:08 +0000 (16:37 +0200)
ggml/src/ggml-cuda/fattn.cu

index a51136f6b8aa954fa203b79c745d7e4e3ff9c10f..039c54e015ea67e69e5a7a2f5333185f393f2cec 100644 (file)
@@ -315,8 +315,9 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
 
     const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations
     const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
-    const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies &&
-        (Q->ne[3] > 1 || cc < GGML_CUDA_CC_ADA_LOVELACE) && !mma_needs_data_conversion;
+    const bool mma_faster_for_rtx4000 = Q->ne[3] > 1 || (Q->ne[2] > 4*K->ne[2] && K->ne[1] >= 8192);
+    const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && !mma_needs_data_conversion &&
+        (cc < GGML_CUDA_CC_ADA_LOVELACE || mma_faster_for_rtx4000);
     const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*warp_size) == 0;
     if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
         if (prec == GGML_PREC_DEFAULT) {