]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CUDA: faster Deepseek FA, add Turing support (#13435)
authorJohannes Gäßler <redacted>
Wed, 14 May 2025 14:08:20 +0000 (16:08 +0200)
committerGitHub <redacted>
Wed, 14 May 2025 14:08:20 +0000 (16:08 +0200)
ggml/src/ggml-cuda/fattn-common.cuh
ggml/src/ggml-cuda/fattn-mma-f16.cuh
ggml/src/ggml-cuda/fattn.cu
ggml/src/ggml-cuda/ggml-cuda.cu

index b7180d5955c29053ce54dcff1f74f1a596560254..a4fbd823638fab0db558c214cf48c1b971cac68b 100644 (file)
@@ -678,10 +678,14 @@ void launch_fattn(
 ) {
     constexpr int ncols = ncols1 * ncols2;
 
+    const bool is_mla = DV == 512; // TODO better parameterization
+
     const ggml_tensor * Q = dst->src[0];
     const ggml_tensor * K = dst->src[1];
     const ggml_tensor * V = dst->src[2];
 
+    GGML_ASSERT(V || is_mla);
+
     const ggml_tensor * mask = dst->src[3];
 
     ggml_tensor * KQV = dst;
@@ -689,6 +693,10 @@ void launch_fattn(
     GGML_ASSERT(Q->type == GGML_TYPE_F32);
     GGML_ASSERT(KQV->type == GGML_TYPE_F32);
 
+    GGML_ASSERT(      Q->nb[0] == ggml_element_size(Q));
+    GGML_ASSERT(      K->nb[0] == ggml_element_size(K));
+    GGML_ASSERT(!V || V->nb[0] == ggml_element_size(V));
+
     GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
     GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
         "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
@@ -713,10 +721,10 @@ void launch_fattn(
     size_t nb12 = K->nb[2];
     size_t nb13 = K->nb[3];
 
-    const char * V_data = (const char *) V->data;
-    size_t nb21 = V->nb[1];
-    size_t nb22 = V->nb[2];
-    size_t nb23 = V->nb[3];
+    const char * V_data = V ? (const char *) V->data : nullptr;
+    size_t nb21 = V ? V->nb[1] : nb11;
+    size_t nb22 = V ? V->nb[2] : nb12;
+    size_t nb23 = V ? V->nb[3] : nb13;
 
     if (need_f16_K && K->type != GGML_TYPE_F16) {
         GGML_ASSERT(ggml_is_contiguously_allocated(K));
@@ -733,7 +741,7 @@ void launch_fattn(
         nb13 = nb13*bs*sizeof(half)/ts;
     }
 
-    if (need_f16_V && V->type != GGML_TYPE_F16) {
+    if (V && need_f16_V && V->type != GGML_TYPE_F16) {
         GGML_ASSERT(ggml_is_contiguously_allocated(V));
         V_f16.alloc(ggml_nelements(V));
         to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
index 491780abd40626d4518a7c7d514654704c3c957d..be0329d0e0c0953fb76077aafc3d8c7e753a5235 100644 (file)
@@ -33,9 +33,30 @@ struct fattn_mma_f16_config< 64,  64> {
     static constexpr int  nwarps_max     = 4;
     static constexpr bool Q_in_reg       = true;
     static constexpr int  nstages_target = 2;
-    static constexpr int  nbatch_K2      = 32;
-    static constexpr int  nbatch_V2      = 32;
-    static constexpr int  nbatch_combine = 32;
+
+    static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
+        return 32;
+    }
+
+    static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
+        return 32;
+    }
+
+    static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
+        return 32;
+    }
+
+    static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
+        return 32;
+    }
+
+    static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
+        return 32;
+    }
+
+    static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
+        return 32;
+    }
 };
 
 template <>
@@ -44,9 +65,30 @@ struct fattn_mma_f16_config< 80,  80> {
     static constexpr int  nwarps_max     = 4;
     static constexpr bool Q_in_reg       = true;
     static constexpr int  nstages_target = 2;
-    static constexpr int  nbatch_K2      = 40;
-    static constexpr int  nbatch_V2      = 40;
-    static constexpr int  nbatch_combine = 40;
+
+    static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
+        return 40;
+    }
+
+    static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
+        return 40;
+    }
+
+    static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
+        return 40;
+    }
+
+    static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
+        return 40;
+    }
+
+    static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
+        return 40;
+    }
+
+    static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
+        return 40;
+    }
 };
 
 template <>
@@ -55,9 +97,30 @@ struct fattn_mma_f16_config< 96,  96> {
     static constexpr int  nwarps_max     = 4;
     static constexpr bool Q_in_reg       = true;
     static constexpr int  nstages_target = 2;
-    static constexpr int  nbatch_K2      = 48;
-    static constexpr int  nbatch_V2      = 48;
-    static constexpr int  nbatch_combine = 48;
+
+    static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
+        return 48;
+    }
+
+    static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
+        return 48;
+    }
+
+    static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
+        return 48;
+    }
+
+    static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
+        return 48;
+    }
+
+    static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
+        return 48;
+    }
+
+    static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
+        return 48;
+    }
 };
 
 template <>
@@ -66,9 +129,30 @@ struct fattn_mma_f16_config<112, 112> {
     static constexpr int  nwarps_max     = 4;
     static constexpr bool Q_in_reg       = true;
     static constexpr int  nstages_target = 2;
-    static constexpr int  nbatch_K2      = 56;
-    static constexpr int  nbatch_V2      = 56;
-    static constexpr int  nbatch_combine = 56;
+
+    static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
+        return 56;
+    }
+
+    static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
+        return 56;
+    }
+
+    static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
+        return 56;
+    }
+
+    static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
+        return 56;
+    }
+
+    static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
+        return 56;
+    }
+
+    static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
+        return 56;
+    }
 };
 
 template <>
@@ -77,9 +161,30 @@ struct fattn_mma_f16_config<128, 128> {
     static constexpr int  nwarps_max     = 4;
     static constexpr bool Q_in_reg       = true;
     static constexpr int  nstages_target = 2;
-    static constexpr int  nbatch_K2      = 64;
-    static constexpr int  nbatch_V2      = 64;
-    static constexpr int  nbatch_combine = 64;
+
+    static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
+        return 64;
+    }
+
+    static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
+        return 64;
+    }
+
+    static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
+        return 64;
+    }
+
+    static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
+        return 64;
+    }
+
+    static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
+        return 64;
+    }
+
+    static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
+        return 64;
+    }
 };
 
 template <>
@@ -88,9 +193,38 @@ struct fattn_mma_f16_config<256, 256> {
     static constexpr int  nwarps_max     = 4;
     static constexpr bool Q_in_reg       = true;
     static constexpr int  nstages_target = 2;
-    static constexpr int  nbatch_K2      = 128;
-    static constexpr int  nbatch_V2      = 128;
-    static constexpr int  nbatch_combine = 128;
+
+    static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
+        return 128;
+    }
+
+    static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
+        return 128;
+    }
+
+    static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
+        return 128;
+    }
+
+    static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
+        return 128;
+    }
+
+    static int get_nbatch_combine_host(const int cc, const int ncols) {
+        if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) {
+            return ncols <= 16 ? 128 : 64;
+        }
+        return 64;
+    }
+
+    static constexpr __device__ int get_nbatch_combine_device(int ncols) {
+#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
+        return ncols <= 16 ? 128 : 64;
+#else
+        GGML_UNUSED(ncols);
+        return 128;
+#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
+    }
 };
 
 template <>
@@ -99,9 +233,44 @@ struct fattn_mma_f16_config<576, 512> {
     static constexpr int  nwarps_max     = 8;
     static constexpr bool Q_in_reg       = false;
     static constexpr int  nstages_target = 1;
-    static constexpr int  nbatch_K2      = 160;
-    static constexpr int  nbatch_V2      = 128;
-    static constexpr int  nbatch_combine = 128;
+
+    static int get_nbatch_K2_host(const int cc, const int ncols) {
+        if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) {
+            return ncols <= 16 ? 96 : 160;
+        }
+        return ncols <= 16 ? 288 : 160;
+    }
+
+    static constexpr __device__ int get_nbatch_K2_device(int ncols) {
+#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
+        return ncols <= 16 ? 96 : 160;
+#else
+        return ncols <= 16 ? 288 : 160;
+#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
+    }
+
+    static int get_nbatch_V2_host(const int cc, const int ncols) {
+        if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) {
+            return ncols <= 16 ? 64 : 128;
+        }
+        return ncols <= 16 ? 256 : 128;
+    }
+
+    static constexpr __device__ int get_nbatch_V2_device(int ncols) {
+#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
+        return ncols <= 16 ? 64 : 128;
+#else
+        return ncols <= 16 ? 256 : 128;
+#endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
+    }
+
+    static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
+        return 128;
+    }
+
+    static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
+        return 128;
+    }
 };
 
 // ------------------------------------------------------------------------------------------------------------------
@@ -120,7 +289,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
 
         const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV);
 
-        auto load = [&] __device__ (const int n) {
+        auto load = [&] __device__ (auto n) {
             const int stride_k = WARP_SIZE >> n;
             const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
             const int k0_stop  =                             chunks_per_row - chunks_per_row % (1*stride_k);
@@ -223,7 +392,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
     }
 }
 
-template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool needs_fixup, bool is_fixup, bool last_iter>
+template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup, bool last_iter>
 static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         const float2 * const __restrict__ Q_f2,
         const half2  * const __restrict__ K_h2,
@@ -261,10 +430,15 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     constexpr int cols_per_warp   = ntiles * tile_B::I;
     constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
     constexpr int np              = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
+    constexpr int ncols           = ncols1 * ncols2;
+    constexpr int nbatch_K2       = c::get_nbatch_K2_device(ncols);
+    constexpr int nbatch_V2       = c::get_nbatch_V2_device(ncols);
 
-    constexpr int stride_tile_Q = DKQ/2        + 4;
-    constexpr int stride_tile_K = c::nbatch_K2 + 4;
-    constexpr int stride_tile_V = c::nbatch_V2 + 4;
+    constexpr int stride_tile_Q = DKQ/2     + 4;
+    constexpr int stride_tile_K = nbatch_K2 + 4;
+
+    static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA");
+    constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4;
 
     const int k_VKQ_0 = kb0 * c::nbatch_fa;
     tile_C_KQ KQ_C[c::nbatch_fa/(np*tile_C_KQ::I) * ntiles];
@@ -275,12 +449,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     tile_C_KQ_16  * KQ_C_16  = (tile_C_KQ_16  *) KQ_C;
 
     if constexpr (nstages > 1) {
-        static_assert(c::nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading");
+        static_assert(!mla, "multi-stage loading not implemented for MLA");
+        static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading");
         constexpr bool use_cp_async = true;
         cp_async_wait_all();
         __syncthreads();
         flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
-            (V_h2 + k_VKQ_0*stride_V, tile_V, c::nbatch_V2, stride_V);
+            (V_h2 + k_VKQ_0*stride_V, tile_V, nbatch_V2, stride_V);
     } else {
         constexpr bool use_cp_async = nstages == 1;
         if (ncols2 > 1 || mask_h2) {
@@ -289,8 +464,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     }
 
 #pragma unroll
-    for (int k0_start = 0; k0_start < DKQ/2; k0_start += c::nbatch_K2) {
-        const int k0_stop = k0_start + c::nbatch_K2 < DKQ/2 ? k0_start + c::nbatch_K2 : DKQ/2;
+    for (int k0_start = 0; k0_start < DKQ/2; k0_start += nbatch_K2) {
+        const int k0_stop = k0_start + nbatch_K2 < DKQ/2 ? k0_start + nbatch_K2 : DKQ/2;
         const int k0_diff = k0_stop - k0_start;
 
         if (nstages <= 1) {
@@ -537,16 +712,21 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
                     (mask_h2 + (k_VKQ_0 + c::nbatch_fa)/2, tile_mask, stride_mask);
             }
             flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
-                (K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, c::nbatch_K2, stride_K);
+                (K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K);
         }
     }
 
+
+    // For MLA K and V have the same data.
+    // Therefore, iterate over V in reverse and re-use the data if possible.
+    static_assert(!mla || nstages <= 1, "combination of MLA and multi-stage loading not implemented");
+    constexpr int reusable_cutoff = mla ? (DKQ - 1) - (DKQ - 1) % (2*nbatch_K2) - (DKQ - DV) : DV;
 #pragma unroll
-    for (int i0_start = 0; i0_start < DV; i0_start += 2*c::nbatch_V2) {
-        const int i0_stop = i0_start + 2*c::nbatch_V2 < DV ? i0_start + 2*c::nbatch_V2 : DV;
-        const int i0_diff = i0_stop - i0_start;
+    for (int i0_stop = DV; i0_stop > 0; i0_stop -= 2*nbatch_V2) {
+        const int i0_start = i0_stop - 2*nbatch_V2 > 0 ? i0_stop - 2*nbatch_V2 : 0;
+        const int i0_diff  = i0_stop - i0_start;
 
-        if (nstages <= 1) {
+        if (nstages <= 1 && i0_start < reusable_cutoff) {
             constexpr bool use_cp_async = nstages == 1;
             flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
                 (V_h2 + k_VKQ_0*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
@@ -555,6 +735,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
             }
             __syncthreads();
         }
+        const half2 * tile_V_i = i0_start < reusable_cutoff ? tile_V : tile_V + (i0_start - reusable_cutoff)/2;
 
         // Calculate VKQ tile:
 #pragma unroll
@@ -565,7 +746,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
                 const int k0 = k00 + (threadIdx.y % np)*tile_A::J;
 
                 tile_A A;
-                load_ldmatrix_trans(A, tile_V + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
+                load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
                 if (ntiles == 1) {
                     mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]);
                 } else {
@@ -596,7 +777,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
 #endif // NEW_MMA_AVAILABLE
 }
 
-template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool needs_fixup, bool is_fixup>
+template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup>
 static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         const float2 * const __restrict__ Q_f2,
         const half2  * const __restrict__ K_h2,
@@ -632,13 +813,16 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
     constexpr int cols_per_warp   = ntiles * tile_B::I;
     constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
     constexpr int np              = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
+    constexpr int nbatch_K2       = c::get_nbatch_K2_device(ncols);
+    constexpr int nbatch_V2       = c::get_nbatch_V2_device(ncols);
 
     static_assert(nwarps * (cols_per_warp/ncols2) % ncols1 == 0, "bad nwarps");
 
-    constexpr int stride_tile_Q = DKQ/2        + 4;
-    constexpr int stride_tile_K = c::nbatch_K2 + 4;
-    constexpr int stride_tile_V = c::nbatch_V2 + 4;
+    constexpr int stride_tile_Q = DKQ/2     + 4;
+    constexpr int stride_tile_K = nbatch_K2 + 4;
 
+    static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA");
+    constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4;
     constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V;
 
     extern __shared__ half2 tile_Q[];
@@ -726,26 +910,26 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
 
     // Preload mask and K data for first iteration when using cp_async with multiple stages:
     if constexpr (nstages > 1) {
-        static_assert(c::nbatch_K2 == DKQ/2, "batching not implemented for multi-stage pipeline");
+        static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi-stage pipeline");
         constexpr bool use_cp_async = true;
         if (ncols2 > 1 || mask_h2) {
             flash_attn_ext_f16_load_mask<ncols1, nwarps, c::nbatch_fa, use_cp_async>
                 (mask_h2 + kb0_start*c::nbatch_fa/2, tile_mask, stride_mask);
         }
         flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
-            (K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, c::nbatch_K2, stride_K);
+            (K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K);
     }
 
     // Iterate over ne11 == previous tokens:
     for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) {
         constexpr bool last_iter = false;
-        flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup, last_iter>
+        flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
             (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
              ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
     }
     { // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
         constexpr bool last_iter = true;
-        flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup, last_iter>
+        flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
             (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
              ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
     }
@@ -774,7 +958,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
     // It's also faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
     // So also write VKQ accumulators to shared memory in column-major format if np == 1.
 
-    constexpr int nbatch_combine = c::Q_in_reg ? DV/2 : DV/4;
+    constexpr int nbatch_combine = c::get_nbatch_combine_device(ncols);
     constexpr int tile_stride    = nbatch_combine + 4;
     static_assert((DV/2) % nbatch_combine == 0, "bad nbatch_combine");
 
@@ -1012,7 +1196,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
 #endif // NEW_MMA_AVAILABLE
 }
 
-template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap>
+template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla>
 __launch_bounds__(nwarps*WARP_SIZE, 1)
 static __global__ void flash_attn_ext_f16(
         const char * __restrict__ Q,
@@ -1057,6 +1241,14 @@ static __global__ void flash_attn_ext_f16(
         NO_DEVICE_CODE;
         return;
     }
+#if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
+    if (ncols1*ncols2 > 32) {
+        NO_DEVICE_CODE;
+        return;
+    }
+#endif __CUDA_ARCH__ == GGML_CUDA_CC_TURING
+
+    static_assert(!mla || DKQ >= DV, "MLA needs DKQ >= DV");
 
     typedef fattn_mma_f16_config<DKQ, DV> c;
 
@@ -1067,9 +1259,10 @@ static __global__ void flash_attn_ext_f16(
     const int stride_Q1   = nb01 / sizeof(float2);
     const int stride_Q2   = nb02 / sizeof(float2);
     const int stride_K    = nb11 / sizeof(half2);
-    const int stride_V    = nb21 / sizeof(half2);
     const int stride_mask = nb31 / sizeof(half2);
 
+    const int stride_V = mla ? stride_K : nb21 / sizeof(half2);
+
     const int iter_k = ne11 / FATTN_KQ_STRIDE;
     const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
 
@@ -1092,10 +1285,11 @@ static __global__ void flash_attn_ext_f16(
 
         const float2 * Q_f2    = (const float2 *) (Q + nb02* channel*ncols2);
         const half2  * K_h2    = (const half2  *) (K + nb12*(channel*ncols2 / gqa_ratio));
-        const half2  * V_h2    = (const half2  *) (V + nb22*(channel*ncols2 / gqa_ratio));
         const half2  * mask_h2 = ncols2 > 1 || mask ? (const half2  *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
         float2       * dstk    = ((float2 *) dst) + channel*(ncols2 * DV/2);
 
+        const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
+
         const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
 
         const int kb0_start_kernel = kb0_start * kb_niter;
@@ -1104,12 +1298,12 @@ static __global__ void flash_attn_ext_f16(
         constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
         if (kb0_start == 0) {
             constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
-            flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup>
+            flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
                 (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
                  ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
         } else {
             constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
-            flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup>
+            flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
                 (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
                  ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
         }
@@ -1130,10 +1324,11 @@ static __global__ void flash_attn_ext_f16(
 
     const float2 * Q_f2    = (const float2 *) (Q + nb02* channel*ncols2);
     const half2  * K_h2    = (const half2  *) (K + nb12*(channel*ncols2 / gqa_ratio));
-    const half2  * V_h2    = (const half2  *) (V + nb22*(channel*ncols2 / gqa_ratio)); // K and V have same shape
     const half2  * mask_h2 = ncols2 > 1 || mask ? (const half2  *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
     float2       * dstk    = ((float2 *) dst) + channel*(ncols2 * DV/2);
 
+    const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
+
     const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
 
     const int kb0_start_kernel = kb0_start * kb_niter;
@@ -1141,7 +1336,7 @@ static __global__ void flash_attn_ext_f16(
 
     constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
     constexpr bool needs_fixup = false;
-    flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, needs_fixup, is_fixup>
+    flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
         (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
          ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
 #else
@@ -1167,10 +1362,6 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
 
     typedef fattn_mma_f16_config<DKQ, DV> c;
 
-    constexpr int nbatch_K2      = c::nbatch_K2      < 1 ? DKQ/2 : c::nbatch_K2;
-    constexpr int nbatch_V2      = c::nbatch_V2      < 1 ? DV /2 : c::nbatch_V2;
-    constexpr int nbatch_combine = c::nbatch_combine < 1 ? DV /2 : c::nbatch_combine;
-
     const int nstages = cp_async_available(cc) ? c::nstages_target : 0;
 
     constexpr int ncols         = ncols1 * ncols2;
@@ -1180,15 +1371,21 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
     constexpr int nwarps_max_y  = c::nbatch_fa / tile_A::I;
     constexpr int nwarps        = nwarps_max_x*nwarps_max_y <= c::nwarps_max ? nwarps_max_x*nwarps_max_y : c::nwarps_max;
 
+    constexpr bool mla = DKQ == 576;
+
+    const int nbatch_K2      = c::get_nbatch_K2_host     (cc, ncols);
+    const int nbatch_V2      = c::get_nbatch_K2_host     (cc, ncols);
+    const int nbatch_combine = c::get_nbatch_combine_host(cc, ncols);
+
     static_assert(DKQ   % tile_B::J     == 0, "bad DKQ");
     static_assert(DV    % tile_A::J     == 0, "bad DV");
     static_assert(ncols % cols_per_warp == 0, "bad ncols");
 
-    const size_t nbytes_shared_KV_1stage = c::nbatch_fa         * std::max(c::nbatch_K2 + 4,  c::nbatch_V2 + 4) * sizeof(half2);
-    const size_t nbytes_shared_KV_2stage = c::nbatch_fa         *         (c::nbatch_K2 + 4 + c::nbatch_V2 + 4) * sizeof(half2);
-    const size_t nbytes_shared_Q         = ncols                * (DKQ/2 + 4)                                   * sizeof(half2);
-    const size_t nbytes_shared_mask      = ncols1               * (c::nbatch_fa/2 + 4)                          * sizeof(half2);
-    const size_t nbytes_shared_combine   = nwarps*cols_per_warp * (nbatch_combine + 4)                          * sizeof(half2);
+    const size_t nbytes_shared_KV_1stage = c::nbatch_fa         * std::max(nbatch_K2 + 4,  nbatch_V2 + 4) * sizeof(half2);
+    const size_t nbytes_shared_KV_2stage = c::nbatch_fa         *         (nbatch_K2 + 4 + nbatch_V2 + 4) * sizeof(half2);
+    const size_t nbytes_shared_Q         = ncols                * (DKQ/2 + 4)                             * sizeof(half2);
+    const size_t nbytes_shared_mask      = ncols1               * (c::nbatch_fa/2 + 4)                    * sizeof(half2);
+    const size_t nbytes_shared_combine   = nwarps*cols_per_warp * (nbatch_combine + 4)                    * sizeof(half2);
 
     const size_t nbytes_shared_KV = nstages <= 1 ? nbytes_shared_KV_1stage : nbytes_shared_KV_2stage;
 
@@ -1202,7 +1399,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
     fattn_kernel_t fattn_kernel;
     if (logit_softcap == 0.0f) {
         constexpr bool use_logit_softcap = false;
-        fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap>;
+        fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla>;
 
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
         static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
@@ -1213,7 +1410,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
     } else {
         constexpr bool use_logit_softcap = true;
-        fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap>;
+        fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla>;
 
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
         static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
index 9c5c803d02bc7b4d4a842b50303160bb2733eec4..6bc0096cc65e6229caf18879b07571fec1346f23 100644 (file)
@@ -10,6 +10,7 @@
 
 template <int DKQ, int DV, int ncols2>
 static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     const ggml_tensor * Q = dst->src[0];
 
     if constexpr (ncols2 <= 8) {
@@ -24,7 +25,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_con
         return;
     }
 
-    if (Q->ne[1] <= 32/ncols2) {
+    if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING || Q->ne[1] <= 32/ncols2) {
         ggml_cuda_flash_attn_ext_mma_f16_case<DKQ, DV, 32/ncols2, ncols2>(ctx, dst);
         return;
     }
index b4b85abcda9e38cc06921061c38a985912ef9e89..02dc8c12dbd8c89b6a759963e7f518c5a56d89c4 100644 (file)
@@ -3222,7 +3222,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
 #endif // FLASH_ATTN_AVAILABLE
             if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
                 const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
-                if (!new_mma_available(cc) || cc < GGML_CUDA_CC_AMPERE) {
+                if (!new_mma_available(cc)) {
                     return false;
                 }
                 const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2];