]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CUDA: skip fully masked-out KV in FA vec kernel (#13584)
authorJohannes Gäßler <redacted>
Tue, 20 May 2025 12:45:07 +0000 (14:45 +0200)
committerGitHub <redacted>
Tue, 20 May 2025 12:45:07 +0000 (14:45 +0200)
* CUDA: skip fully masked-out KV in FA vec kernel

ggml/src/ggml-cuda/fattn-vec-f16.cuh
ggml/src/ggml-cuda/fattn-vec-f32.cuh

index d96e392129848374f8eeee445a04d32717f86cd7..798a59b2778612a77f15b4ae4d64047edd45412d 100644 (file)
@@ -2,9 +2,9 @@
 #include "fattn-common.cuh"
 
 template<int D, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#ifndef GGML_USE_HIP
 __launch_bounds__(D, 1)
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#endif // GGML_USE_HIP
 static __global__ void flash_attn_vec_ext_f16(
         const char * __restrict__ Q,
         const char * __restrict__ K,
@@ -48,6 +48,12 @@ static __global__ void flash_attn_vec_ext_f16(
         NO_DEVICE_CODE;
         return;
     }
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+    if (ncols > 1) {
+        NO_DEVICE_CODE;
+        return;
+    }
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
 
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
@@ -91,6 +97,13 @@ static __global__ void flash_attn_vec_ext_f16(
             kqsum_shared[j][threadIdx.x] = 0.0f;
         }
     }
+
+    __shared__ half maskh_shared[ncols*D];
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        maskh_shared[j*D + tid] = 0.0f;
+    }
+
     __syncthreads();
 
     // Convert Q to half2 (f16 K) or q8_1 (quantized K) and store in registers:
@@ -175,6 +188,35 @@ static __global__ void flash_attn_vec_ext_f16(
     for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
         // Calculate KQ tile and keep track of new maximum KQ values:
 
+        if (mask) {
+#pragma unroll
+            for (int j = 0; j < ncols; ++j) {
+                maskh_shared[j*D + tid] = slopeh*maskh[j*ne11 + k_VKQ_0 + tid];
+            }
+
+            __syncthreads();
+
+            // When using multiple parallel sequences in llama.cpp, some KV slices can be fully masked out.
+            // In such cases, skip the KV slice.
+            // On AMD __all_sync would not work correctly because it assumes a warp size of 64.
+#ifndef GGML_USE_HIP
+            bool skip = true;
+#pragma unroll
+            for (int j = 0; j < ncols; ++j) {
+#pragma unroll
+                for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+                    const int i = i0 + threadIdx.x;
+
+                    const float2 tmp = __half22float2(((const half2 *) maskh_shared)[j*(D/2) + i]);
+                    skip = skip && isinf(tmp.x) && isinf(tmp.y);
+                }
+            }
+            if (__all_sync(0xFFFFFFFF, skip)) {
+                continue;
+            }
+#endif // GGML_USE_HIP
+        }
+
         // For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression,
         // see https://github.com/ggerganov/llama.cpp/pull/7061 .
         // Therefore this variable is defined twice but only used once (so that the compiler can optimize out the unused variable).
@@ -202,7 +244,7 @@ static __global__ void flash_attn_vec_ext_f16(
                     sum = logit_softcap*tanhf(sum);
                 }
 
-                sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
+                sum += maskh_shared[j*D + i_KQ];
 
                 if (ncols == 1) {
                     kqmax_new        = ggml_cuda_hmax(kqmax_new,        sum);
@@ -335,7 +377,9 @@ void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml
     float logit_softcap;
     memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
 
-    if (Q->ne[1] == 1) {
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+
+    if (Q->ne[1] == 1 || GGML_CUDA_CC_IS_NVIDIA(cc)) {
         constexpr int cols_per_block = 1;
         if (logit_softcap == 0.0f) {
             constexpr bool use_logit_softcap = false;
index 7064675d5ab3f5f352d7337fc981b536e0492da6..49c592ea59224ccae38b24c7e72e7697c46070d3 100644 (file)
@@ -2,9 +2,9 @@
 #include "fattn-common.cuh"
 
 template<int D, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#ifndef GGML_USE_HIP
 __launch_bounds__(D, 1)
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#endif // GGML_USE_HIP
 static __global__ void flash_attn_vec_ext_f32(
         const char * __restrict__ Q,
         const char * __restrict__ K,
@@ -60,6 +60,12 @@ static __global__ void flash_attn_vec_ext_f32(
         NO_DEVICE_CODE;
         return;
     }
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+    if (ncols > 1) {
+        NO_DEVICE_CODE;
+        return;
+    }
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
 
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
@@ -104,6 +110,13 @@ static __global__ void flash_attn_vec_ext_f32(
             kqsum_shared[j][threadIdx.x] = 0.0f;
         }
     }
+
+    __shared__ float maskf_shared[ncols*D];
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        maskf_shared[j*D + tid] = 0.0f;
+    }
+
     __syncthreads();
 
     // Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers:
@@ -181,6 +194,34 @@ static __global__ void flash_attn_vec_ext_f32(
     for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
         // Calculate KQ tile and keep track of new maximum KQ values:
 
+        if (mask) {
+#pragma unroll
+            for (int j = 0; j < ncols; ++j) {
+                maskf_shared[j*D + tid] = slope*__half2float(maskh[j*ne11 + k_VKQ_0 + tid]);
+            }
+
+            __syncthreads();
+
+            // When using multiple parallel sequences in llama.cpp, some KV slices can be fully masked out.
+            // In such cases, skip the KV slice.
+            // On AMD __all_sync would not work correctly because it assumes a warp size of 64.
+#ifndef GGML_USE_HIP
+            bool skip = true;
+#pragma unroll
+            for (int j = 0; j < ncols; ++j) {
+#pragma unroll
+                for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
+                    const int i = i0 + threadIdx.x;
+
+                    skip = skip && isinf(maskf_shared[j*D + i]);
+                }
+            }
+            if (__all_sync(0xFFFFFFFF, skip)) {
+                continue;
+            }
+#endif // GGML_USE_HIP
+        }
+
         float kqmax_new_arr[ncols];
 #pragma unroll
         for (int j = 0; j < ncols; ++j) {
@@ -204,7 +245,7 @@ static __global__ void flash_attn_vec_ext_f32(
                     sum = logit_softcap*tanhf(sum);
                 }
 
-                sum += mask ? slope*__half2float(maskh[j*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
+                sum += maskf_shared[j*D + i_KQ];
 
                 kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum);
 
@@ -326,7 +367,9 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml
     float logit_softcap;
     memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
 
-    if (Q->ne[1] == 1) {
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+
+    if (Q->ne[1] == 1 || GGML_CUDA_CC_IS_NVIDIA(cc)) {
         constexpr int cols_per_block = 1;
         if (logit_softcap == 0.0f) {
             constexpr bool use_logit_softcap = false;