]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
CUDA: Add Flash Attention Support for Head Dimension 512 (llama/20998)
authorAnav Prasad <redacted>
Wed, 1 Apr 2026 07:07:24 +0000 (07:07 +0000)
committerGeorgi Gerganov <redacted>
Wed, 1 Apr 2026 13:00:26 +0000 (16:00 +0300)
* flash attention support for head dimension 512 added

* FA D=512 - match 576 configs, limit ncols2, revert vec cap

* fix HIP tile kernel build for D=512

* fix HIP tile kernel occupancy for D=512 on AMD

* Apply suggestions from code review

Co-authored-by: Johannes Gäßler <redacted>
* fix tile FA compilation

---------

Co-authored-by: Johannes Gäßler <redacted>
14 files changed:
src/ggml-cuda/fattn-mma-f16.cuh
src/ggml-cuda/fattn-tile.cu
src/ggml-cuda/fattn-tile.cuh
src/ggml-cuda/fattn.cu
src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu
src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu
src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu
src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu
src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu
src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu
src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu
src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu
src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu [new file with mode: 0644]
src/ggml-cuda/template-instances/generate_cu_files.py

index fff70c8eb89fbe42db8a8c2ea0693d667c4945bd..b613ae61fb8919fbd6ac2ec55ff927d2007aeff3 100644 (file)
@@ -66,6 +66,11 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
     GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2,  32, 128, 128, 128, 2, true);
     GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2,  32, 128, 128, 128, 2, true);
 
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512,  8,  64, 4,  32, 256, 256, 128, 1, false);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16,  64, 4,  32, 256, 256, 128, 1, false);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2,  32, 128, 128, 128, 1, false);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1,  32, 128, 128, 128, 1, false);
+
     GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512,  8,  64, 4,  32, 288, 256, 128, 1, false);
     GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16,  64, 4,  32, 288, 256, 128, 1, false);
     GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2,  32, 160, 128, 128, 1, false);
@@ -80,6 +85,11 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
     GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2,  64, 128, 128,  64, 2, true);
     GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2,  64, 128, 128,  64, 2, true);
 
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512,  8,  64, 4,  32,  96,  64, 128, 1, false);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16,  64, 4,  32,  96,  64, 128, 1, false);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2,  32, 128, 128, 128, 1, false);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1,  32, 128, 128, 128, 1, false);
+
     GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512,  8,  64, 4,  32,  96,  64, 128, 1, false);
     GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16,  64, 4,  32,  96,  64, 128, 1, false);
     GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2,  32, 160, 128, 128, 1, false);
@@ -89,6 +99,11 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
 }
 
 static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_volta(const int DKQ, const int DV, const int ncols) {
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512,  8,  64, 4,  32, 256, 256,  64, 1, false);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16,  64, 4,  32, 256, 256,  64, 1, false);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2,  32, 128, 128,  64, 1, false);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1,  32, 128, 128,  64, 1, false);
+
     GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512,  8,  64, 4,  32, 288, 256,  64, 1, false);
     GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16,  64, 4,  32, 288, 256,  64, 1, false);
     GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2,  32, 160, 128,  64, 1, false);
@@ -103,6 +118,10 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
     GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2,  64, 128, 128,  64, 2, true);
     GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2,  64, 128, 128,  64, 2, true);
 
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16,  64, 4,  32, 128, 128, 128, 1, false);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2,  32, 128, 128, 128, 1, false);
+    GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1,  32, 128, 128, 128, 1, false);
+
     GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16,  64, 4,  32,  96,  64, 128, 1, false);
     GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2,  32, 160, 128, 128, 1, false);
     GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1,  32, 160, 128, 128, 1, false);
@@ -1552,7 +1571,7 @@ static __global__ void flash_attn_ext_f16(
 #if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE))
 
     // Skip unused kernel variants for faster compilation:
-    if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) {
+    if (use_logit_softcap && !(DKQ == 128 || DKQ == 256 || DKQ == 512)) {
         NO_DEVICE_CODE;
         return;
     }
@@ -1815,6 +1834,15 @@ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112,  64)
 DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128,  64)
 DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256,  64)
 
+extern DECL_FATTN_MMA_F16_CASE(512, 512,  2,  4);
+extern DECL_FATTN_MMA_F16_CASE(512, 512,  4,  4);
+extern DECL_FATTN_MMA_F16_CASE(512, 512,  8,  4);
+extern DECL_FATTN_MMA_F16_CASE(512, 512, 16,  4);
+extern DECL_FATTN_MMA_F16_CASE(512, 512,  1,  8);
+extern DECL_FATTN_MMA_F16_CASE(512, 512,  2,  8);
+extern DECL_FATTN_MMA_F16_CASE(512, 512,  4,  8);
+extern DECL_FATTN_MMA_F16_CASE(512, 512,  8,  8);
+
 // The number of viable configurations for Deepseek is very limited:
 extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
 extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);
index 3fcb09b7a2ba358c7fead295789a50bf00594d5e..25b16e83cacd9753cc0100634f74901f7b946c93 100644 (file)
@@ -38,6 +38,10 @@ void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor
             GGML_ASSERT(V->ne[0] == K->ne[0]);
             ggml_cuda_flash_attn_ext_tile_case<256, 256>(ctx, dst);
         } break;
+        case 512: {
+            GGML_ASSERT(V->ne[0] == K->ne[0]);
+            ggml_cuda_flash_attn_ext_tile_case<512, 512>(ctx, dst);
+        } break;
         case 576: {
             GGML_ASSERT(V->ne[0] == 512);
             ggml_cuda_flash_attn_ext_tile_case<576, 512>(ctx, dst);
index f3fa80ab23d09b3f8f634c1cb29326e1dfac2bbd..26721cc4c7de249bd20ec66ae73aeb47ca89980f 100644 (file)
@@ -68,6 +68,10 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2,  64,  64)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2,  64,  64)
 
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512,  4, 128, 2,  64,  64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512,  8, 256, 2,  64,  64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2,  64,  64)
+
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512,  4, 128, 2,  64,  64)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512,  8, 256, 2,  64,  64)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2,  64,  64)
@@ -124,6 +128,10 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2,  32, 128)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2,  32,  64)
 
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512,  4, 128, 2,  32,  64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512,  8, 256, 2,  32,  64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2,  32,  64)
+
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512,  4, 128, 2,  32,  64)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512,  8, 256, 2,  32,  64)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2,  32,  64)
@@ -187,6 +195,11 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2,  32, 128)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2,  32, 128)
 
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512,  4, 128, 2,  64,  64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512,  8, 256, 2,  64,  64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 2,  64,  64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 32, 512, 1, 128,  64)
+
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512,  4, 128, 2,  64,  64)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512,  8, 256, 2,  64,  64)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2,  64,  64)
@@ -251,6 +264,11 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5,  32, 256)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3,  64, 128)
 
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512,  4, 128, 2,  64,  64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512,  8, 256, 2,  64,  64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 16, 256, 4,  64,  64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(512, 512, 32, 256, 2, 128,  64)
+
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512,  4, 128, 2,  64,  64)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512,  8, 256, 2,  64,  64)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4,  64,  64)
@@ -767,7 +785,7 @@ static __global__ void flash_attn_tile(
 #ifdef GGML_USE_WMMA_FATTN
             (ncols2 != 1 && DV != 40 && DV != 72 && DV != 512) ||
 #endif // GGML_USE_WMMA_FATTN
-            (use_logit_softcap && !(DV == 128 || DV == 256))
+            (use_logit_softcap && !(DV == 128 || DV == 256 || DV == 512))
     ) {
         GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
             max_bias, m0, m1, n_head_log2, logit_softcap,
@@ -1192,7 +1210,7 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
     const int gqa_limit = nvidia && gqa_ratio <= 4 && DV <= 256 ? 16 : INT_MAX;
     const bool use_gqa_opt = mask && max_bias == 0.0f && Q->ne[1] <= gqa_limit && K->ne[1] % FATTN_KQ_STRIDE == 0;
 
-    if constexpr (DV == 512) {
+    if constexpr (DKQ == 576) {
         if (use_gqa_opt && gqa_ratio % 16 == 0) {
             launch_fattn_tile_switch_ncols1<DKQ, DV, 16, use_logit_softcap>(ctx, dst);
             return;
@@ -1203,7 +1221,7 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
         }
     }
 
-    if constexpr (DV <= 256) {
+    if constexpr (DKQ <= 512) {
         if (use_gqa_opt && gqa_ratio % 8 == 0) {
             launch_fattn_tile_switch_ncols1<DKQ, DV, 8, use_logit_softcap>(ctx, dst);
             return;
@@ -1214,13 +1232,15 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
             return;
         }
 
-        if (use_gqa_opt && gqa_ratio % 2 == 0) {
-            launch_fattn_tile_switch_ncols1<DKQ, DV, 2, use_logit_softcap>(ctx, dst);
+        if constexpr (DV <= 256) {
+            if (use_gqa_opt && gqa_ratio % 2 == 0) {
+                launch_fattn_tile_switch_ncols1<DKQ, DV, 2, use_logit_softcap>(ctx, dst);
+                return;
+            }
+
+            launch_fattn_tile_switch_ncols1<DKQ, DV, 1, use_logit_softcap>(ctx, dst);
             return;
         }
-
-        launch_fattn_tile_switch_ncols1<DKQ, DV, 1, use_logit_softcap>(ctx, dst);
-        return;
     }
     GGML_ABORT("fatal error");
 }
@@ -1255,4 +1275,5 @@ extern DECL_FATTN_TILE_CASE( 96,  96);
 extern DECL_FATTN_TILE_CASE(112, 112);
 extern DECL_FATTN_TILE_CASE(128, 128);
 extern DECL_FATTN_TILE_CASE(256, 256);
+extern DECL_FATTN_TILE_CASE(512, 512);
 extern DECL_FATTN_TILE_CASE(576, 512);
index a25a890db6dafca57f7b9fef0d119ee40400b295..a21c53610483cad03ee00a0678eb594d477becff 100644 (file)
@@ -135,6 +135,10 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
             GGML_ASSERT(V->ne[0] == 256);
             ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<256, 256>(ctx, dst);
             break;
+        case 512:
+            GGML_ASSERT(V->ne[0] == 512);
+            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<512, 512>(ctx, dst);
+            break;
         case 576: {
             // For Deepseek, go straight to the ncols1 switch to avoid compiling unnecessary kernels.
             GGML_ASSERT(V->ne[0] == 512);
@@ -336,7 +340,8 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
         case 128:
         case 112:
         case 256:
-            if (V->ne[0] != K->ne[0]) {
+        case 512:
+            if (!gqa_opt_applies) {
                 return BEST_FATTN_KERNEL_NONE;
             }
             break;
@@ -424,7 +429,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
     }
 
     // Use the WMMA kernel if possible:
-    if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 576) {
+    if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 512 && Q->ne[0] != 576) {
         if (can_use_vector_kernel && Q->ne[1] <= 2) {
             return BEST_FATTN_KERNEL_VEC;
         }
@@ -457,7 +462,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
     }
 
     // Use MFMA flash attention for CDNA (MI100+):
-    if (amd_mfma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 256 && Q->ne[0] != 576) {
+    if (amd_mfma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 256 && Q->ne[0] != 512 && Q->ne[0] != 576) {
         const int64_t eff_nq = Q->ne[1] * (gqa_opt_applies ? gqa_ratio : 1);
         // MMA vs tile crossover benchmarked on MI300X @ d32768:
         //   hsk=64  (gqa=4): MMA wins at eff >= 128 (+11%)
index dc16829021f9019b032eb65004925e516c3491a0..22d383173f36213d5456b5a28474979ae1a7578d 100644 (file)
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 1, 8);
 DECL_FATTN_MMA_F16_CASE(112, 112, 1, 8);
 DECL_FATTN_MMA_F16_CASE(128, 128, 1, 8);
 DECL_FATTN_MMA_F16_CASE(256, 256, 1, 8);
+DECL_FATTN_MMA_F16_CASE(512, 512, 1, 8);
index 517993cb0681f575ee78744734158d014ca0c7b5..d2415bfa957f31075f4611250fe7d43c4a4a470d 100644 (file)
@@ -8,4 +8,5 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 16, 4);
 DECL_FATTN_MMA_F16_CASE(112, 112, 16, 4);
 DECL_FATTN_MMA_F16_CASE(128, 128, 16, 4);
 DECL_FATTN_MMA_F16_CASE(256, 256, 16, 4);
+DECL_FATTN_MMA_F16_CASE(512, 512, 16, 4);
 DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
index 97b19c67ade9377784ceb48a31e0dbccfb02d1d1..8eec1d74e293e4e77ba0c18aa4995eb90e1fe238 100644 (file)
@@ -8,4 +8,5 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 2, 4);
 DECL_FATTN_MMA_F16_CASE(112, 112, 2, 4);
 DECL_FATTN_MMA_F16_CASE(128, 128, 2, 4);
 DECL_FATTN_MMA_F16_CASE(256, 256, 2, 4);
+DECL_FATTN_MMA_F16_CASE(512, 512, 2, 4);
 DECL_FATTN_MMA_F16_CASE(576, 512, 2, 4);
index 163b1d939e49dedfa88bde15681da5410db6a7e3..84b674cd05a6f2664678fd757d6b82dffbe56049 100644 (file)
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 2, 8);
 DECL_FATTN_MMA_F16_CASE(112, 112, 2, 8);
 DECL_FATTN_MMA_F16_CASE(128, 128, 2, 8);
 DECL_FATTN_MMA_F16_CASE(256, 256, 2, 8);
+DECL_FATTN_MMA_F16_CASE(512, 512, 2, 8);
index 989626dfa5e02e8ef3781d9a9ee189926db40ee9..3475dfea08a1349a9813b2b448dc4b8fc6402062 100644 (file)
@@ -8,4 +8,5 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 4, 4);
 DECL_FATTN_MMA_F16_CASE(112, 112, 4, 4);
 DECL_FATTN_MMA_F16_CASE(128, 128, 4, 4);
 DECL_FATTN_MMA_F16_CASE(256, 256, 4, 4);
+DECL_FATTN_MMA_F16_CASE(512, 512, 4, 4);
 DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
index bad296b4141e0ce124113ecc3c7dd3f7643682b5..5906398db9123753462ca05c9c9c95b4dae28cbe 100644 (file)
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 4, 8);
 DECL_FATTN_MMA_F16_CASE(112, 112, 4, 8);
 DECL_FATTN_MMA_F16_CASE(128, 128, 4, 8);
 DECL_FATTN_MMA_F16_CASE(256, 256, 4, 8);
+DECL_FATTN_MMA_F16_CASE(512, 512, 4, 8);
index 173de7aac7d5744c2e003ff8a886b7043d10c901..684cd25ce0decb0bd7deb9b9fe9ebc0eb4d672e7 100644 (file)
@@ -8,4 +8,5 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 8, 4);
 DECL_FATTN_MMA_F16_CASE(112, 112, 8, 4);
 DECL_FATTN_MMA_F16_CASE(128, 128, 8, 4);
 DECL_FATTN_MMA_F16_CASE(256, 256, 8, 4);
+DECL_FATTN_MMA_F16_CASE(512, 512, 8, 4);
 DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
index 680a13ca6de5826f8ff741c1fc9ee2d1c12756e4..4bc60d62f91042fdc2e57b7b8a65fc9f95dd347d 100644 (file)
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 8, 8);
 DECL_FATTN_MMA_F16_CASE(112, 112, 8, 8);
 DECL_FATTN_MMA_F16_CASE(128, 128, 8, 8);
 DECL_FATTN_MMA_F16_CASE(256, 256, 8, 8);
+DECL_FATTN_MMA_F16_CASE(512, 512, 8, 8);
diff --git a/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu b/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu
new file mode 100644 (file)
index 0000000..7c61d8d
--- /dev/null
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-tile.cuh"
+
+DECL_FATTN_TILE_CASE(512, 512);
index 3b5ab12fc408648c0a9363382ebb94be87d95b79..b7b5832293ec5ae043bfa93d69e60ab7c49926fa 100755 (executable)
@@ -3,7 +3,7 @@
 from glob import glob
 import os
 
-HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 576]
+HEAD_SIZES_KQ = [40, 64, 72, 80, 96, 112, 128, 256, 512, 576]
 
 TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_BF16"]
 
@@ -83,6 +83,8 @@ for ncols in [8, 16, 32, 64]:
                     continue
                 if head_size_kq == 72:
                     continue
+                if head_size_kq == 512 and ncols2 not in (4, 8):
+                    continue
                 if head_size_kq != 576 and ncols2 in (16, 32):
                     continue
                 if head_size_kq == 576 and ncols2 not in (4, 16, 32):