]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
CUDA: add gqa_ratio 4 for GLM 4.7 flash (llama/18953)
authorAman Gupta <redacted>
Thu, 22 Jan 2026 10:51:53 +0000 (18:51 +0800)
committerGeorgi Gerganov <redacted>
Fri, 30 Jan 2026 11:49:29 +0000 (13:49 +0200)
src/ggml-cuda/fattn-mma-f16.cuh
src/ggml-cuda/fattn-tile.cuh
src/ggml-cuda/fattn.cu
src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu
src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu
src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu
src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu
src/ggml-cuda/template-instances/generate_cu_files.py

index e53bbc0502ce2d7b2bbec0b2b51556fcc530c9fc..8cca89c2bfa7844cd61105e2d92fff55f6939e47 100644 (file)
@@ -432,7 +432,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     constexpr int  ncols           = ncols1 * ncols2;
     constexpr int  cols_per_warp   = T_B_KQ::I;
     constexpr int  cols_per_thread = get_cols_per_thread();
-    constexpr int  np              = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
+    constexpr int  np              = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
     constexpr int  nbatch_fa       = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
     constexpr int  nbatch_K2       = ggml_cuda_fattn_mma_get_nbatch_K2(DKQ, DV, ncols);
     constexpr int  nbatch_V2       = ggml_cuda_fattn_mma_get_nbatch_V2(DKQ, DV, ncols);
@@ -510,7 +510,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
                 }
             }
         } else {
-            static_assert(cols_per_warp != 8, "cols_per_warp == 8 not implemented");
 #pragma unroll
             for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) {
                 load_ldmatrix(Q_B[0], tile_Q + (threadIdx.y / np)*(T_B_KQ::I*stride_tile_Q) + k_KQ_0, stride_tile_Q);
@@ -522,14 +521,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
                     T_A_KQ K_A;
                     load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
 
-                    // Wide version of KQ_C is column-major
+                    if constexpr (cols_per_warp == 8) {
+                        mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
+                    } else {
+                        // Wide version of KQ_C is column-major
 #if defined(AMD_WMMA_AVAILABLE)
-                    // RDNA matrix C is column-major.
-                    mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
+                        // RDNA matrix C is column-major.
+                        mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
 #else
-                    // swap A and B for CUDA.
-                    mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
+                        // swap A and B for CUDA.
+                        mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
 #endif // defined(AMD_WMMA_AVAILABLE)
+                    }
                 }
             }
         }
@@ -953,7 +956,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
 
     constexpr int  cols_per_warp   = T_B_KQ::I;
     constexpr int  cols_per_thread = get_cols_per_thread();
-    constexpr int  np              = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
+    constexpr int  np              = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
     constexpr int  nbatch_fa       = ggml_cuda_fattn_mma_get_nbatch_fa     (DKQ, DV, ncols);
     constexpr int  nbatch_K2       = ggml_cuda_fattn_mma_get_nbatch_K2     (DKQ, DV, ncols);
     constexpr int  nbatch_V2       = ggml_cuda_fattn_mma_get_nbatch_V2     (DKQ, DV, ncols);
@@ -1484,6 +1487,13 @@ static __global__ void flash_attn_ext_f16(
         NO_DEVICE_CODE;
         return;
     }
+#ifdef VOLTA_MMA_AVAILABLE
+    if (ncols1*ncols2 < 32) {
+        NO_DEVICE_CODE;
+        return;
+    }
+#endif // VOLTA_MMA_AVAILABLE
+
 #if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
     if (ncols1*ncols2 > 32) {
         NO_DEVICE_CODE;
@@ -1728,3 +1738,8 @@ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256,  64)
 extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
 extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);
 extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
+
+// For GLM 4.7 Flash
+extern DECL_FATTN_MMA_F16_CASE(576, 512,  4,  4);
+extern DECL_FATTN_MMA_F16_CASE(576, 512,  8,  4);
+extern DECL_FATTN_MMA_F16_CASE(576, 512, 16,  4);
index f055da8e2be9364dae264dfdd67011d8bcd3c4f7..b6db5822818bd848d82f6c56845c55d2b2fe3dad 100644 (file)
@@ -68,6 +68,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2,  64,  64)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2,  64,  64)
 
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512,  4, 128, 2,  64,  64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512,  8, 256, 2,  64,  64)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2,  64,  64)
 
     return 0;
@@ -122,6 +124,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nv
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2,  32, 128)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2,  32,  64)
 
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512,  4, 128, 2,  32,  64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512,  8, 256, 2,  32,  64)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2,  32,  64)
 
     return 0;
@@ -183,6 +187,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2,  32, 128)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2,  32, 128)
 
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512,  4, 128, 2,  64,  64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512,  8, 256, 2,  64,  64)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2,  64,  64)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128,  64)
 
@@ -245,6 +251,8 @@ static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_am
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5,  32, 256)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3,  64, 128)
 
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512,  4, 128, 2,  64,  64)
+    GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512,  8, 256, 2,  64,  64)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4,  64,  64)
     GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 128,  64)
 
@@ -1187,6 +1195,10 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm
             launch_fattn_tile_switch_ncols1<DKQ, DV, 16, use_logit_softcap>(ctx, dst);
             return;
         }
+        if (use_gqa_opt && gqa_ratio % 4 == 0) {
+            launch_fattn_tile_switch_ncols1<DKQ, DV, 4, use_logit_softcap>(ctx, dst);
+            return;
+        }
     }
 
     if constexpr (DV <= 256) {
index 598cda7daa05d3b31e382b906db12b47760d4e46..80c3bfbc271ad381d7f6dcacf2502bcff28502ec 100644 (file)
@@ -121,8 +121,12 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg
 
             GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
             const int gqa_ratio = Q->ne[2] / K->ne[2];
-            GGML_ASSERT(gqa_ratio % 16 == 0);
-            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
+            GGML_ASSERT(gqa_ratio % 4 == 0);
+            if (gqa_ratio % 16 == 0) {
+                ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);
+            } else {
+                ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512,  4>(ctx, dst);
+            }
         } break;
         default:
             GGML_ABORT("fatal error");
@@ -262,7 +266,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
             if (V->ne[0] != 512) {
                 return BEST_FATTN_KERNEL_NONE;
             }
-            if (!gqa_opt_applies || gqa_ratio % 16 != 0) {
+            if (!gqa_opt_applies || gqa_ratio % 4 != 0) {
                 return BEST_FATTN_KERNEL_NONE;
             }
             break;
index 2074e954a32f081819a66456a38936cdcad7c814..517993cb0681f575ee78744734158d014ca0c7b5 100644 (file)
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 16, 4);
 DECL_FATTN_MMA_F16_CASE(112, 112, 16, 4);
 DECL_FATTN_MMA_F16_CASE(128, 128, 16, 4);
 DECL_FATTN_MMA_F16_CASE(256, 256, 16, 4);
+DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
index 24c64cf000fecd47638b70228d93028a7fd744db..97b19c67ade9377784ceb48a31e0dbccfb02d1d1 100644 (file)
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 2, 4);
 DECL_FATTN_MMA_F16_CASE(112, 112, 2, 4);
 DECL_FATTN_MMA_F16_CASE(128, 128, 2, 4);
 DECL_FATTN_MMA_F16_CASE(256, 256, 2, 4);
+DECL_FATTN_MMA_F16_CASE(576, 512, 2, 4);
index 1ada657f194c4dd0e75fe95eb18e9e8b634a036d..989626dfa5e02e8ef3781d9a9ee189926db40ee9 100644 (file)
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 4, 4);
 DECL_FATTN_MMA_F16_CASE(112, 112, 4, 4);
 DECL_FATTN_MMA_F16_CASE(128, 128, 4, 4);
 DECL_FATTN_MMA_F16_CASE(256, 256, 4, 4);
+DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
index 86d4ffae27c289571f4ab7a49a687e60bfab38b1..173de7aac7d5744c2e003ff8a886b7043d10c901 100644 (file)
@@ -8,3 +8,4 @@ DECL_FATTN_MMA_F16_CASE(96, 96, 8, 4);
 DECL_FATTN_MMA_F16_CASE(112, 112, 8, 4);
 DECL_FATTN_MMA_F16_CASE(128, 128, 8, 4);
 DECL_FATTN_MMA_F16_CASE(256, 256, 8, 4);
+DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
index a5602da02bb08e179d2242ee0cd60ca392ebd72f..10be71ab5761895048cab845cfe26d7eabe4eae4 100755 (executable)
@@ -85,7 +85,7 @@ for ncols in [8, 16, 32, 64]:
                     continue
                 if head_size_kq != 576 and ncols2 == 16:
                     continue
-                if head_size_kq == 576 and ncols2 != 16:
+                if head_size_kq == 576 and ncols2 not in (4, 16):
                     continue
                 head_size_v = head_size_kq if head_size_kq != 576 else 512
                 f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size_kq=head_size_kq, head_size_v=head_size_v))