]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml: CUDA: add head size 72 for flash-attn (#16962)
authortheo77186 <redacted>
Mon, 3 Nov 2025 13:29:11 +0000 (14:29 +0100)
committerGitHub <redacted>
Mon, 3 Nov 2025 13:29:11 +0000 (14:29 +0100)
ggml/src/ggml-cuda/fattn-tile.cu
ggml/src/ggml-cuda/fattn-tile.cuh
ggml/src/ggml-cuda/fattn.cu
ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu [new file with mode: 0644]
ggml/src/ggml-cuda/template-instances/generate_cu_files.py

index 3a5806d9091d763d344817498d378d018de71baa..3fcb09b7a2ba358c7fead295789a50bf00594d5e 100644 (file)
@@ -14,6 +14,10 @@ void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor
             GGML_ASSERT(V->ne[0] == K->ne[0]);
             ggml_cuda_flash_attn_ext_tile_case< 64,  64>(ctx, dst);
         } break;
+        case  72: {
+            GGML_ASSERT(V->ne[0] == K->ne[0]);
+            ggml_cuda_flash_attn_ext_tile_case< 72,  72>(ctx, dst);
+        } break;
         case  80: {
             GGML_ASSERT(V->ne[0] == K->ne[0]);
             ggml_cuda_flash_attn_ext_tile_case< 80,  80>(ctx, dst);
index 2b60b3bb13563d33da26cc63e1e6fceca683da89..c358aa1e87ef0095a4ffc608897dfce1523fc13a 100644 (file)
@@ -6,7 +6,7 @@
 // nbatch_K == number of K columns to load in parallel for KQ calculation
 
 // TODO optimize kernel parameters for FP16 NVIDIA (P100)
-// TODO optimize kernel parameters for head sizes 40, 80, 96, 112
+// TODO optimize kernel parameters for head sizes 40, 72, 80, 96, 112
 
 // The ROCm compiler cannot handle templating in __launch_bounds__.
 // As a workaround, define a macro to package the kernel parameters as uint32_t:
@@ -32,6 +32,12 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
     GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64, 16, 256, 2,  64,  64)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64, 32, 256, 2,  64,  64)
 
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72,  2,  64, 2,  64,  72)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72,  4, 128, 2,  64,  72)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72,  8, 256, 2,  64,  72)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72, 16, 256, 2,  64,  72)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72, 32, 256, 2,  64,  72)
+
     GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  2,  64, 2,  64,  40)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  4, 128, 2,  64,  40)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  8, 256, 2,  64,  40)
@@ -80,6 +86,12 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
     GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64, 16, 128, 3,  64,  64)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64, 32, 256, 2,  64,  64)
 
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72,  2,  64, 2,  32,  72)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72,  4, 128, 2,  32,  72)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72,  8, 256, 2,  32,  72)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72, 16, 256, 2,  32,  72)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72, 32, 256, 2,  32,  72)
+
     GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  2,  64, 2,  32,  40)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  4, 128, 2,  32,  40)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  8, 256, 2,  32,  40)
@@ -130,6 +142,13 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
     GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64, 32, 256, 2,  64,  64)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64, 64, 256, 2,  64,  64)
 
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72,  2,  64, 2,  32,  72)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72,  4, 128, 2,  32,  72)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72,  8, 256, 2,  32,  72)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72, 16, 256, 2,  32,  72)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72, 32, 256, 2,  32,  72)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72, 64, 256, 2,  32,  72)
+
     GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  2,  64, 2,  32,  40)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  4, 128, 2,  32,  40)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  8, 256, 2,  32,  40)
@@ -185,6 +204,13 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
     GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64, 32, 128, 4,  64,  64)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64,  64, 64, 128, 5,  64,  64)
 
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72,  2,  64, 2,  32,  72)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72,  4, 128, 2,  32,  72)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72,  8, 256, 2,  32,  72)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72, 16, 256, 2,  32,  72)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72, 32, 256, 2,  32,  72)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE( 72,  72, 64, 256, 2,  32,  72)
+
     GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  2,  64, 2,  32,  40)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  4, 128, 2,  32,  40)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80,  80,  8, 256, 2,  32,  40)
@@ -723,7 +749,7 @@ static __global__ void flash_attn_tile(
 
     if (
 #ifdef GGML_USE_WMMA_FATTN
-            (ncols2 != 1 && DV != 40 && DV != 512) ||
+            (ncols2 != 1 && DV != 40 && DV != 72 && DV != 512) ||
 #endif // GGML_USE_WMMA_FATTN
             (use_logit_softcap && !(DV == 128 || DV == 256))
     ) {
@@ -1198,6 +1224,7 @@ void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor
 
 extern DECL_FATTN_TILE_CASE( 40,  40);
 extern DECL_FATTN_TILE_CASE( 64,  64);
+extern DECL_FATTN_TILE_CASE( 72,  72);
 extern DECL_FATTN_TILE_CASE( 80,  80);
 extern DECL_FATTN_TILE_CASE( 96,  96);
 extern DECL_FATTN_TILE_CASE(112, 112);
index 7dee032c291373626910bab6719dbde5c325b265..82405991cea6e5f02aa2958d2510f1be161278d6 100644 (file)
@@ -223,6 +223,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
     switch (K->ne[0]) {
         case  40:
         case  64:
+        case  72:
         case  80:
         case  96:
         case 128:
@@ -275,7 +276,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
     const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0;
 
     // If Turing tensor cores available, use them:
-    if (turing_mma_available(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40) {
+    if (turing_mma_available(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72) {
         if (can_use_vector_kernel) {
             if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
                 if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) {
@@ -301,7 +302,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
     }
 
     // Use the WMMA kernel if possible:
-    if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 576) {
+    if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 576) {
         if (can_use_vector_kernel && Q->ne[1] <= 2) {
             return BEST_FATTN_KERNEL_VEC;
         }
diff --git a/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu
new file mode 100644 (file)
index 0000000..8f9d531
--- /dev/null
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-tile.cuh"
+
+DECL_FATTN_TILE_CASE(72, 72);
index 81a986f38cacff056f6cd5e432e7565e923f69f7..a5602da02bb08e179d2242ee0cd60ca392ebd72f 100755 (executable)
@@ -3,7 +3,7 @@
 from glob import glob
 import os
 
-HEAD_SIZES_KQ = [40, 64, 80, 96, 112, 128, 256, 576]
+HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 576]
 
 TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0"]
 
@@ -81,6 +81,8 @@ for ncols in [8, 16, 32, 64]:
             for head_size_kq in HEAD_SIZES_KQ:
                 if head_size_kq == 40:
                     continue
+                if head_size_kq == 72:
+                    continue
                 if head_size_kq != 576 and ncols2 == 16:
                     continue
                 if head_size_kq == 576 and ncols2 != 16: