]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
CUDA: app option to compile without FlashAttention (llama/12025)
authorJohannes Gäßler <redacted>
Sat, 22 Feb 2025 19:44:34 +0000 (20:44 +0100)
committerGeorgi Gerganov <redacted>
Thu, 27 Feb 2025 06:55:36 +0000 (08:55 +0200)
12 files changed:
ggml/CMakeLists.txt
ggml/src/ggml-cuda/CMakeLists.txt
ggml/src/ggml-cuda/common.cuh
ggml/src/ggml-cuda/fattn-mma-f16.cuh
ggml/src/ggml-cuda/fattn-tile-f16.cu
ggml/src/ggml-cuda/fattn-tile-f32.cu
ggml/src/ggml-cuda/fattn-vec-f16.cuh
ggml/src/ggml-cuda/fattn-vec-f32.cuh
ggml/src/ggml-cuda/fattn-wmma-f16.cu
ggml/src/ggml-cuda/ggml-cuda.cu
ggml/src/ggml-hip/CMakeLists.txt
ggml/src/ggml-musa/CMakeLists.txt

index fc5eac151b90cbd6ef7e04c4c80d0840a3b515af..12afe0f25a87bb63d60bfcc97f33e539b6c5cf48 100644 (file)
@@ -151,6 +151,7 @@ set   (GGML_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING
                                             "ggml: max. batch size for using peer access")
 option(GGML_CUDA_NO_PEER_COPY               "ggml: do not use peer to peer copies"            OFF)
 option(GGML_CUDA_NO_VMM                     "ggml: do not try to use CUDA VMM"                OFF)
+option(GGML_CUDA_FA                         "ggml: compile ggml FlashAttention CUDA kernels"  ON)
 option(GGML_CUDA_FA_ALL_QUANTS              "ggml: compile all quants for FlashAttention"     OFF)
 option(GGML_CUDA_GRAPHS                     "ggml: use CUDA graphs (llama.cpp only)"          ${GGML_CUDA_GRAPHS_DEFAULT})
 
index e63ede2fbe3ffab172f635d28506050c905477f0..96bd5a0be297663d882eb1a22feb5a79f5180910 100644 (file)
@@ -69,6 +69,10 @@ if (CUDAToolkit_FOUND)
         add_compile_definitions(GGML_CUDA_NO_VMM)
     endif()
 
+    if (NOT GGML_CUDA_FA)
+        add_compile_definitions(GGML_CUDA_NO_FA)
+    endif()
+
     if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16)
         add_compile_definitions(GGML_CUDA_F16)
     endif()
index 7e99838c09261ec7069dcbe3a1deba690ab75735..adf0d3ecb566cc78bfd29726b1cc929ed0a66bf9 100644 (file)
@@ -204,9 +204,9 @@ typedef float2 dfloat2;
 #define CP_ASYNC_AVAILABLE
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 
-#if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
+#if !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
 #define FLASH_ATTN_AVAILABLE
-#endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
+#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
 
 static bool fp16_available(const int cc) {
     return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL;
index b2e0db9a2cc254ecf318da319d4e76bfec446f77..718ee5402dccd16e8a80ae8a8c57ed90827cef20 100644 (file)
@@ -839,10 +839,7 @@ static __global__ void flash_attn_ext_f16(
         const int ne1,
         const int ne2,
         const int ne3) {
-#ifndef NEW_MMA_AVAILABLE
-    NO_DEVICE_CODE;
-    return;
-#endif // NEW_MMA_AVAILABLE
+#if defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
 
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(D == 128 || D == 256)) {
@@ -933,6 +930,9 @@ static __global__ void flash_attn_ext_f16(
     flash_attn_ext_f16_process_tile<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup>
         (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
          ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
+#else
+    NO_DEVICE_CODE;
+#endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
 }
 
 template <int D, int ncols1, int ncols2>
index b8b415effb7e181c02d2a204e3b5b7eb6d4c9c61..ef3569fab27892ef9ef4cf07208dfee5d6c6136c 100644 (file)
@@ -44,12 +44,7 @@ static __global__ void flash_attn_tile_ext_f16(
         const int ne1,
         const int ne2,
         const int ne3) {
-#ifdef FP16_AVAILABLE
-
-#ifndef FLASH_ATTN_AVAILABLE
-    NO_DEVICE_CODE;
-    return;
-#endif // FLASH_ATTN_AVAILABLE
+#if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
 
     // Skip unused kernel variants for faster compilation:
 #ifdef FP16_MMA_AVAILABLE
@@ -290,7 +285,7 @@ static __global__ void flash_attn_tile_ext_f16(
     }
 #else
    NO_DEVICE_CODE;
-#endif // FP16_AVAILABLE
+#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
 }
 
 template <int cols_per_block, int parallel_blocks, bool use_logit_softcap>
index 4352a284464764800f0528589a610a6012661238..04b69c83be0488bcae89e6583c12cd3856409f44 100644 (file)
@@ -44,10 +44,7 @@ static __global__ void flash_attn_tile_ext_f32(
         const int ne1,
         const int ne2,
         const int ne3) {
-#ifndef FLASH_ATTN_AVAILABLE
-    NO_DEVICE_CODE;
-    return;
-#endif // FLASH_ATTN_AVAILABLE
+#ifdef FLASH_ATTN_AVAILABLE
 
     // Skip unused kernel variants for faster compilation:
 #ifdef FP16_MMA_AVAILABLE
@@ -285,6 +282,9 @@ static __global__ void flash_attn_tile_ext_f32(
             dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
         }
     }
+#else
+    NO_DEVICE_CODE;
+#endif // FLASH_ATTN_AVAILABLE
 }
 
 template <int cols_per_block, int parallel_blocks, bool use_logit_softcap>
index e758a0f6ec276eca1e6f831e09e499ee88e9ef99..b7686c1ec3d47ef8cacf05f5487857f03cd51ed9 100644 (file)
@@ -41,12 +41,7 @@ static __global__ void flash_attn_vec_ext_f16(
         const int ne1,
         const int ne2,
         const int ne3) {
-#ifdef FP16_AVAILABLE
-
-#ifndef FLASH_ATTN_AVAILABLE
-    NO_DEVICE_CODE;
-    return;
-#endif // FLASH_ATTN_AVAILABLE
+#if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
 
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(D == 128 || D == 256)) {
@@ -300,7 +295,7 @@ static __global__ void flash_attn_vec_ext_f16(
     }
 #else
    NO_DEVICE_CODE;
-#endif // FP16_AVAILABLE
+#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
 }
 
 template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
index 134144a383ffaaa31bf5ba7acf3643d880d73d59..c1d2dd8d19f4d8c5d53351ab4a5235d34bc878e8 100644 (file)
@@ -41,10 +41,7 @@ static __global__ void flash_attn_vec_ext_f32(
         const int ne1,
         const int ne2,
         const int ne3) {
-#ifndef FLASH_ATTN_AVAILABLE
-    NO_DEVICE_CODE;
-    return;
-#endif // FLASH_ATTN_AVAILABLE
+#ifdef FLASH_ATTN_AVAILABLE
 
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(D == 128 || D == 256)) {
@@ -281,6 +278,9 @@ static __global__ void flash_attn_vec_ext_f32(
     if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
         dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
     }
+#else
+    NO_DEVICE_CODE;
+#endif // FLASH_ATTN_AVAILABLE
 }
 
 template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
index de38470abec456ae3a23a91a13e90ff3ee1b8cd6..8828652fb5e7f1c7437650c27adafacae332e917 100644 (file)
@@ -51,7 +51,7 @@ static __global__ void flash_attn_ext_f16(
         const int ne1,
         const int ne2,
         const int ne3) {
-#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+#if defined(FLASH_ATTN_AVAILABLE) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(D == 128 || D == 256)) {
         NO_DEVICE_CODE;
@@ -425,7 +425,7 @@ static __global__ void flash_attn_ext_f16(
     }
 #else
    NO_DEVICE_CODE;
-#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+#endif // defined(FLASH_ATTN_AVAILABLE) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
 }
 
 constexpr int get_max_power_of_2(int x) {
index f685423215ba1f5d8bd8d32218750ee11b3a427f..ebb2ccae04065505b73e013fcbde57ba1a26ee94 100644 (file)
@@ -3203,7 +3203,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_FLASH_ATTN_EXT: {
 #ifndef FLASH_ATTN_AVAILABLE
             return false;
-#endif
+#endif // FLASH_ATTN_AVAILABLE
             if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
                 return false;
             }
index f4a4683639fab51897b66beca7bb155d4da5ac53..4a0384dd47654ad3e507640ca423fedd7218898c 100644 (file)
@@ -107,6 +107,10 @@ if (GGML_HIP_NO_VMM)
     add_compile_definitions(GGML_HIP_NO_VMM)
 endif()
 
+if (NOT GGML_CUDA_FA)
+    add_compile_definitions(GGML_CUDA_NO_FA)
+endif()
+
 if (CXX_IS_HIPCC)
     set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE CXX)
     target_link_libraries(ggml-hip PRIVATE hip::device)
index 1bfc07c5d717a8dfe74829ee68a6359ce2de999d..2c75abf61d67242e35b1528351c124dc6c2c084e 100644 (file)
@@ -83,6 +83,10 @@ if (MUSAToolkit_FOUND)
         add_compile_definitions(GGML_CUDA_NO_VMM)
     endif()
 
+    if (NOT GGML_CUDA_FA)
+        add_compile_definitions(GGML_CUDA_NO_FA)
+    endif()
+
     if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16)
         add_compile_definitions(GGML_CUDA_F16)
     endif()