CUDA: fix Volta FlashAttention logic (llama/11615)
author    Johannes Gäßler <redacted>
Mon, 3 Feb 2025 12:25:56 +0000 (13:25 +0100)
committer Georgi Gerganov <redacted>
Mon, 3 Feb 2025 20:00:57 +0000 (22:00 +0200)
ggml/src/ggml-cuda/fattn-wmma-f16.cu
ggml/src/ggml-cuda/fattn.cu

index 1054ff95d2296b222fab5bdccbca5a40352052d8..45702ad651fe67e906d35832b86863220c81e191 100644
--- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu
+++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu
@@ -561,7 +561,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_ten
                     ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
                     break;
                 // case 256:
-                //     ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
+                //     ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);
                 //     break;
                 default:
                     GGML_ABORT("fatal error");
index b1e66d470832c79c2827dcc5dcf52a064c21e9a3..b0cf152f52cf1bb5449342b1a3c95ffe22e082c3 100644
--- a/ggml/src/ggml-cuda/fattn.cu
+++ b/ggml/src/ggml-cuda/fattn.cu
@@ -235,7 +235,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
         return;
     }
 
-    if (!new_mma_available(cc)) {
+    if (!fp16_mma_available(cc)) {
         if (prec == GGML_PREC_DEFAULT) {
             if (Q->ne[1] <= 8) {
                 ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
@@ -265,6 +265,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     // The MMA implementation needs Turing or newer, use the old WMMA code for Volta:
     if (cc == GGML_CUDA_CC_VOLTA) {
         ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
+        return;
     }
 
     ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);
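
Net effect of the fattn.cu hunks: the vector fallback is now taken only when FP16 tensor-core MMA is unavailable at all, Volta is routed to the older WMMA kernels, and the added return keeps the new MMA path (Turing or newer) from also running on Volta. A minimal standalone sketch, with stubbed kernel launches and assumed compute-capability values in place of the real ggml-cuda definitions, of how the selection reads after this commit:

    // Minimal standalone sketch (not the actual ggml-cuda code) of the dispatch
    // logic after this commit. Kernel launches are stubbed with prints and the
    // compute-capability constants are assumed values for illustration only.
    #include <cstdio>

    #define GGML_CUDA_CC_VOLTA  700 // assumed value
    #define GGML_CUDA_CC_TURING 750 // assumed value

    // Volta and newer have FP16 tensor cores usable via WMMA.
    static bool fp16_mma_available(const int cc) { return cc >= GGML_CUDA_CC_VOLTA; }

    static void flash_attn_ext_vec_f16 (void) { std::puts("vec kernel (no tensor cores)"); }
    static void flash_attn_ext_wmma_f16(void) { std::puts("WMMA kernel (Volta)"); }
    static void flash_attn_ext_mma_f16 (void) { std::puts("MMA kernel (Turing or newer)"); }

    static void flash_attn_dispatch(const int cc) {
        // Post-fix: only GPUs without FP16 tensor cores take the vector fallback.
        // The old check was !new_mma_available(cc) (Turing or newer), which sent
        // Volta down this fallback path as well.
        if (!fp16_mma_available(cc)) {
            flash_attn_ext_vec_f16();
            return;
        }

        // The MMA implementation needs Turing or newer, use the old WMMA code for Volta.
        if (cc == GGML_CUDA_CC_VOLTA) {
            flash_attn_ext_wmma_f16();
            return; // the added return keeps the MMA kernel below from also running on Volta
        }

        flash_attn_ext_mma_f16();
    }

    int main(void) {
        flash_attn_dispatch(610); // Pascal: vector fallback
        flash_attn_dispatch(700); // Volta:  WMMA
        flash_attn_dispatch(890); // Ada:    MMA
    }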