]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
CUDA: skip masked KV slices for all FA kernels (llama/14924)
authorJohannes Gäßler <redacted>
Wed, 30 Jul 2025 13:46:13 +0000 (15:46 +0200)
committerGeorgi Gerganov <redacted>
Sat, 2 Aug 2025 14:51:21 +0000 (17:51 +0300)
src/ggml-cuda/common.cuh
src/ggml-cuda/fattn-common.cuh
src/ggml-cuda/fattn-mma-f16.cuh
src/ggml-cuda/fattn-tile-f16.cu
src/ggml-cuda/fattn-tile-f32.cu
src/ggml-cuda/fattn-vec-f16.cuh
src/ggml-cuda/fattn-vec-f32.cuh
src/ggml-cuda/fattn-wmma-f16.cu
src/ggml-cuda/fattn.cu

index 44daafbf78ad89d8b67e04637eaa5acfe371ba75..4560d85c4cfd3d4b0dcddfc96b637d7cf1508cfe 100644 (file)
@@ -432,6 +432,20 @@ static __global__ void reduce_rows_f32(const float * x, float * dst, const int n
     dst[row] = norm ? sum / ncols : sum;
 }
 
+template<int width = WARP_SIZE>
+static __device__ __forceinline__ int warp_reduce_all(int x) {
+#ifdef GGML_USE_HIP
+#pragma unroll
+    for (int offset = width/2; offset > 0; offset >>= 1) {
+        x = x && __shfl_xor_sync(0xffffffff, x, offset, width);
+    }
+    return x;
+#else
+    static_assert(width == WARP_SIZE, "width != WARP_SIZE not implemented");
+    return __all_sync(0xffffffff, x);
+#endif // GGML_USE_HIP
+}
+
 template<int width = WARP_SIZE>
 static __device__ __forceinline__ float warp_reduce_max(float x) {
 #pragma unroll
index 0cc74f284a15b8f921b31104e7ac195319258f7a..b6db446c6feaf2e81d2ec5c593613b6284dc5f91 100644 (file)
@@ -15,6 +15,7 @@ typedef void (* fattn_kernel_t)(
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
+        const int  * __restrict__ KV_max,
         float      * __restrict__ dst,
         float2     * __restrict__ dst_meta,
         const float scale,
@@ -500,6 +501,55 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
         nullptr;
 }
 
+template <int ncols1>
+__launch_bounds__(FATTN_KQ_STRIDE/2, 1)
+static __global__ void flash_attn_mask_to_KV_max(
+        const half2 * __restrict__ mask, int * __restrict__ KV_max, const int ne30, const int s31, const int s33) {
+    const int ne31     = gridDim.x;
+    const int tid      = threadIdx.x;
+    const int sequence = blockIdx.y;
+    const int jt       = blockIdx.x;
+
+    mask += sequence*s33 + jt*ncols1*s31;
+
+    __shared__ int buf_iw[WARP_SIZE];
+    if (tid < WARP_SIZE) {
+        buf_iw[tid] = 1;
+    }
+    __syncthreads();
+
+    int KV_max_sj = (ne30 - 1) * FATTN_KQ_STRIDE;
+    for (; KV_max_sj >= 0; KV_max_sj -= FATTN_KQ_STRIDE) {
+        int all_inf = 1;
+
+#pragma unroll
+        for (int j = 0; j < ncols1; ++j) {
+            const float2 tmp = __half22float2(mask[j*s31 + KV_max_sj/2 + tid]);
+            all_inf = all_inf && int(isinf(tmp.x)) && int(isinf(tmp.y));
+        }
+
+        all_inf = warp_reduce_all(all_inf);
+        if (tid % WARP_SIZE == 0) {
+            buf_iw[tid / WARP_SIZE] = all_inf;
+        }
+        __syncthreads();
+        all_inf = buf_iw[tid % WARP_SIZE];
+        __syncthreads();
+        all_inf = warp_reduce_all(all_inf);
+
+        if (!all_inf) {
+            KV_max_sj += FATTN_KQ_STRIDE;
+            break;
+        }
+    }
+
+    if (threadIdx.x != 0) {
+        return;
+    }
+
+    KV_max[sequence*ne31 + jt] = KV_max_sj;
+}
+
 template<int D, int ncols1, int ncols2> // D == head size
 __launch_bounds__(D, 1)
 static __global__ void flash_attn_stream_k_fixup(
@@ -711,6 +761,7 @@ void launch_fattn(
 
     ggml_cuda_pool_alloc<half>   K_f16(pool);
     ggml_cuda_pool_alloc<half>   V_f16(pool);
+    ggml_cuda_pool_alloc<int>    KV_max(pool);
     ggml_cuda_pool_alloc<float>  dst_tmp(pool);
     ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
 
@@ -779,11 +830,30 @@ void launch_fattn(
         V_data = (char *) V_f16.ptr;
     }
 
-    int parallel_blocks = 1;
-
     const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
     const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
 
+    // Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped.
+    // Only worth the overhead if there is at lease one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or
+    //     multiple sequences of possibly different lengths.
+    if (mask && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) {
+        const int s31 = mask->nb[1] / sizeof(half2);
+        const int s33 = mask->nb[3] / sizeof(half2);
+
+        const dim3 blocks_num_KV_max(ntiles_x, Q->ne[3], 1);
+        const dim3 block_dim_KV_max(FATTN_KQ_STRIDE/2, 1, 1);
+
+        const int ne_KV_max = blocks_num_KV_max.x*blocks_num_KV_max.y;
+        const int iter_k = K->ne[1] / FATTN_KQ_STRIDE;
+
+        KV_max.alloc(ne_KV_max);
+        flash_attn_mask_to_KV_max<ncols1><<<blocks_num_KV_max, block_dim_KV_max, 0, main_stream>>>
+            ((const half2 *) mask->data, KV_max.ptr, iter_k, s31, s33);
+        CUDA_CHECK(cudaGetLastError());
+    }
+
+    int parallel_blocks = 1;
+
     const dim3 block_dim(warp_size, nwarps, 1);
     int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
     CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
@@ -870,6 +940,7 @@ void launch_fattn(
         K_data,
         V_data,
         mask ? ((const char *) mask->data) : nullptr,
+        KV_max.ptr,
         !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
         scale, max_bias, m0, m1, n_head_log2, logit_softcap,
         Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3],
index 8e847d361b455b886c60b3c2688797d08c835bc8..a86b95428f6ff6baedad28f09b4cd0374450e346 100644 (file)
@@ -392,7 +392,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
     }
 }
 
-template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup, bool last_iter>
+template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles,
+    bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup, bool last_iter>
 static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         const float2 * const __restrict__ Q_f2,
         const half2  * const __restrict__ K_h2,
@@ -922,7 +923,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
     }
 
     // Iterate over ne11 == previous tokens:
-    for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) {
+    int kb0 = kb0_start;
+    for (; kb0 < kb0_stop-1; ++kb0) {
         constexpr bool last_iter = false;
         flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
             (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
@@ -932,7 +934,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         constexpr bool last_iter = true;
         flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
             (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
-             ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
+             ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
     }
 
     // With multi-stage loading there is no __syncthreads at the end of the iter,
@@ -1204,6 +1206,7 @@ static __global__ void flash_attn_ext_f16(
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
+        const int  * __restrict__ KV_max,
         float      * __restrict__ dst,
         float2     * __restrict__ dst_meta,
         const float scale,
@@ -1280,7 +1283,11 @@ static __global__ void flash_attn_ext_f16(
         const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;
 
         const int kb0_start_kernel = kb0_start * kb_niter;
-        const int kb0_stop_kernel  = kb0_stop  * kb_niter;
+        int       kb0_stop_kernel  = kb0_stop  * kb_niter;
+
+        if (KV_max) {
+            kb0_stop_kernel = min(kb0_stop_kernel, KV_max[sequence*iter_j + jt] / c::nbatch_fa);
+        }
 
         constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
         if (kb0_start == 0) {
@@ -1321,7 +1328,11 @@ static __global__ void flash_attn_ext_f16(
     const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;
 
     const int kb0_start_kernel = kb0_start * kb_niter;
-    const int kb0_stop_kernel  = kb0_stop  * kb_niter;
+    int       kb0_stop_kernel  = kb0_stop  * kb_niter;
+
+    if (KV_max) {
+        kb0_stop_kernel = min(kb0_stop_kernel, KV_max[sequence*iter_j + jt] / c::nbatch_fa);
+    }
 
     constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
     constexpr bool needs_fixup = false;
index 678288c13e3e87e92053dfda0b25af9ce525ab6d..9d0b24ae7ec73fbb125b102c4ebc4959393c3de6 100644 (file)
@@ -13,6 +13,7 @@ static __global__ void flash_attn_tile_ext_f16(
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
+        const int  * __restrict__ KV_max,
         float      * __restrict__ dst,
         float2     * __restrict__ dst_meta,
         const float scale,
@@ -90,7 +91,8 @@ static __global__ void flash_attn_tile_ext_f16(
 
     __syncthreads();
 
-    for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F16; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F16) {
+    const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
+    for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F16; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F16) {
         // Calculate KQ tile and keep track of new maximum KQ values:
 
         half kqmax_new[ncols/nwarps];
index bc283d9a7e4d74db95b0f86db63333f7fda90d88..be72f76fb6538698b8c26caf623089a128a509cc 100644 (file)
@@ -13,6 +13,7 @@ static __global__ void flash_attn_tile_ext_f32(
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
+        const int  * __restrict__ KV_max,
         float      * __restrict__ dst,
         float2     * __restrict__ dst_meta,
         const float scale,
@@ -99,7 +100,8 @@ static __global__ void flash_attn_tile_ext_f32(
 
     __syncthreads();
 
-    for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F32; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F32) {
+    const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
+    for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F32; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F32) {
         // Calculate KQ tile and keep track of new maximum KQ values:
 
         float kqmax_new[ncols/nwarps];
index afef815ceaf1f5ff937bced722bd4ca0e2985e28..a2df2f66be0c41602af9651e3146201defdf581f 100644 (file)
@@ -16,6 +16,7 @@ static __global__ void flash_attn_vec_ext_f16(
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
+        const int  * __restrict__ KV_max,
         float      * __restrict__ dst,
         float2     * __restrict__ dst_meta,
         const float scale,
@@ -177,10 +178,11 @@ static __global__ void flash_attn_vec_ext_f16(
 
     half2 VKQ[ncols] = {{0.0f, 0.0f}};
 
+    const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
     K     += blockIdx.y*D * nb11;
     V     += blockIdx.y*D * nb21;
     maskh += blockIdx.y*D;
-    for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D,
+    for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*D,
              // Increment pointers after each loop:
              K += gridDim.y*D*nb11, V += gridDim.y*D*nb21, maskh += gridDim.y*D) {
 
@@ -191,29 +193,7 @@ static __global__ void flash_attn_vec_ext_f16(
             for (int j = 0; j < ncols; ++j) {
                 maskh_shared[j*D + tid] = slopeh*maskh[j*ne11 + tid];
             }
-
             __syncthreads();
-
-            // When using multiple parallel sequences in llama.cpp, some KV slices can be fully masked out.
-            // In such cases, skip the KV slice.
-            // On AMD __all_sync would not work correctly because it assumes a warp size of 64.
-#ifndef GGML_USE_HIP
-            bool skip = true;
-#pragma unroll
-            for (int j = 0; j < ncols; ++j) {
-#pragma unroll
-                for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
-                    const int i = i0 + threadIdx.x;
-
-                    const float2 tmp = __half22float2(((const half2 *) maskh_shared)[j*(D/2) + i]);
-                    skip = skip && isinf(tmp.x) && isinf(tmp.y);
-                }
-            }
-            if (__all_sync(0xFFFFFFFF, skip)) {
-                __syncthreads();
-                continue;
-            }
-#endif // GGML_USE_HIP
         }
 
         // For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression,
index 3595e29693ac57e0a3cdd35a25ccb36f88e97296..9ab0fc133b7a201ddf93d3667b8d0525628f249b 100644 (file)
@@ -16,6 +16,7 @@ static __global__ void flash_attn_vec_ext_f32(
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
+        const int  * __restrict__ KV_max,
         float      * __restrict__ dst,
         float2     * __restrict__ dst_meta,
         const float scale,
@@ -183,10 +184,11 @@ static __global__ void flash_attn_vec_ext_f32(
 
     float VKQ[ncols] = {0.0f};
 
+    const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
     K     += blockIdx.y*D * nb11;
     V     += blockIdx.y*D * nb21;
     maskh += blockIdx.y*D;
-    for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D,
+    for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*D,
              // Increment pointers after each loop:
              K += gridDim.y*D*nb11, V += gridDim.y*D*nb21, maskh += gridDim.y*D) {
 
@@ -197,28 +199,7 @@ static __global__ void flash_attn_vec_ext_f32(
             for (int j = 0; j < ncols; ++j) {
                 maskf_shared[j*D + tid] = slope*__half2float(maskh[j*ne11 + tid]);
             }
-
             __syncthreads();
-
-            // When using multiple parallel sequences in llama.cpp, some KV slices can be fully masked out.
-            // In such cases, skip the KV slice.
-            // On AMD __all_sync would not work correctly because it assumes a warp size of 64.
-#ifndef GGML_USE_HIP
-            bool skip = true;
-#pragma unroll
-            for (int j = 0; j < ncols; ++j) {
-#pragma unroll
-                for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
-                    const int i = i0 + threadIdx.x;
-
-                    skip = skip && isinf(maskf_shared[j*D + i]);
-                }
-            }
-            if (__all_sync(0xFFFFFFFF, skip)) {
-                __syncthreads();
-                continue;
-            }
-#endif // GGML_USE_HIP
         }
 
         float kqmax_new_arr[ncols];
index f4393ec5711d925e95a434ecf6b2a9bd026efd67..40554204b62f3a4da6c5ccfa3119c2c3178a93ad 100644 (file)
@@ -29,6 +29,7 @@ static __global__ void flash_attn_ext_f16(
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
+        const int  * __restrict__ KV_max,
         float      * __restrict__ dst,
         float2     * __restrict__ dst_meta,
         const float scale,
@@ -165,7 +166,8 @@ static __global__ void flash_attn_ext_f16(
     __syncthreads();
 
     // Iterate over ne11 == previous tokens:
-    for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE) {
+    const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
+    for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE) {
         // Calculate tile of KQ:
 #pragma unroll
         for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) {
index d9f1613051d3a14085f5498de0b05a79b9260845..a51136f6b8aa954fa203b79c745d7e4e3ff9c10f 100644 (file)
@@ -315,7 +315,8 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
 
     const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations
     const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
-    const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < GGML_CUDA_CC_ADA_LOVELACE && !mma_needs_data_conversion;
+    const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies &&
+        (Q->ne[3] > 1 || cc < GGML_CUDA_CC_ADA_LOVELACE) && !mma_needs_data_conversion;
     const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*warp_size) == 0;
     if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
         if (prec == GGML_PREC_DEFAULT) {