]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
CPU/CUDA: Gemma 2 FlashAttention support (llama/8542)
authorJohannes Gäßler <redacted>
Sat, 24 Aug 2024 19:34:59 +0000 (21:34 +0200)
committerGeorgi Gerganov <redacted>
Tue, 27 Aug 2024 19:01:14 +0000 (22:01 +0300)
* CPU/CUDA: Gemma 2 FlashAttention support

* apply logit_softcap to scale in kernel

* disable logit softcapping tests on Metal

* remove metal check

include/ggml.h
src/ggml-cuda/fattn-common.cuh
src/ggml-cuda/fattn-tile-f16.cu
src/ggml-cuda/fattn-tile-f32.cu
src/ggml-cuda/fattn-vec-f16.cuh
src/ggml-cuda/fattn-vec-f32.cuh
src/ggml-cuda/fattn-wmma-f16.cuh
src/ggml-cuda/fattn.cu
src/ggml-metal.m
src/ggml.c
tests/test-backend-ops.cpp

index 74fc98882808b5439f6341563fb754841a994f3a..b11d047aeda7d057f74f9b606e8d2f837e44a16c 100644 (file)
@@ -1807,7 +1807,8 @@ extern "C" {
             struct ggml_tensor  * v,
             struct ggml_tensor  * mask,
             float                 scale,
-            float                 max_bias);
+            float                 max_bias,
+            float                 logit_softcap);
 
     GGML_API void ggml_flash_attn_ext_set_prec(
             struct ggml_tensor * a,
index 950fd93dfe1eecd61bf6c610ff11dcc476c48f2e..1fb5c09c3b179417c71b70a6b78808d21abfde4c 100644 (file)
@@ -22,6 +22,7 @@ typedef void (* fattn_kernel_t)(
         const float m0,
         const float m1,
         const uint32_t n_head_log2,
+        const float logit_softcap,
         const int ne00,
         const int ne01,
         const int ne02,
@@ -657,11 +658,17 @@ void launch_fattn(
     const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
     const int  shmem = 0;
 
-    float scale    = 1.0f;
-    float max_bias = 0.0f;
+    float scale         = 1.0f;
+    float max_bias      = 0.0f;
+    float logit_softcap = 0.0f;
 
-    memcpy(&scale,    (float *) KQV->op_params + 0, sizeof(float));
-    memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
+    memcpy(&scale,         (float *) KQV->op_params + 0, sizeof(float));
+    memcpy(&max_bias,      (float *) KQV->op_params + 1, sizeof(float));
+    memcpy(&logit_softcap, (float *) KQV->op_params + 2, sizeof(float));
+
+    if (logit_softcap != 0.0f) {
+        scale /= logit_softcap;
+    }
 
     const uint32_t n_head      = Q->ne[2];
     const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
@@ -675,7 +682,7 @@ void launch_fattn(
         V_data,
         mask ? ((const char *) mask->data) : nullptr,
         (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
-        scale, max_bias, m0, m1, n_head_log2,
+        scale, max_bias, m0, m1, n_head_log2, logit_softcap,
         Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
         K->ne[0], K->ne[1], K->ne[2], K->ne[3],
         mask ? mask->ne[1] : 0, mask ?  mask->nb[1] : 0,
index 1b2fd500b746ccfcec142a5d3f1a855a86bf7a5d..342f2eb665312d7dd93003a60bcd195ac032ac2e 100644 (file)
@@ -4,7 +4,7 @@
 
 #define FATTN_KQ_STRIDE_TILE_F16 64
 
-template<int D, int ncols, int nwarps, int parallel_blocks> // D == head size
+template<int D, int ncols, int nwarps, int parallel_blocks, bool use_logit_softcap> // D == head size
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(nwarps*WARP_SIZE, 1)
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
@@ -20,6 +20,7 @@ static __global__ void flash_attn_tile_ext_f16(
         const float m0,
         const float m1,
         const uint32_t n_head_log2,
+        const float logit_softcap,
         const int ne00,
         const int ne01,
         const int ne02,
@@ -44,6 +45,12 @@ static __global__ void flash_attn_tile_ext_f16(
         const int ne2,
         const int ne3) {
 #ifdef FP16_AVAILABLE
+    // Skip unused kernel variants for faster compilation:
+    if (use_logit_softcap && !(D == 128 || D == 256)) {
+        NO_DEVICE_CODE;
+        return;
+    }
+
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
     const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
@@ -154,7 +161,13 @@ static __global__ void flash_attn_tile_ext_f16(
             for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
                 const int j_KQ = j_KQ_0 + threadIdx.y;
 
-                half sum = __low2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]) + __high2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
+                half sum;
+                if (use_logit_softcap) {
+                    const float2 tmp = __half22float2(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
+                    sum = logit_softcap * tanhf(tmp.x + tmp.y);
+                } else {
+                    sum = __low2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]) + __high2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
+                }
                 sum += mask ? slopeh*maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
 
                 kqmax_new[j_KQ_0/nwarps] = ggml_cuda_hmax(kqmax_new[j_KQ_0/nwarps], sum);
@@ -270,20 +283,20 @@ static __global__ void flash_attn_tile_ext_f16(
 #endif // FP16_AVAILABLE
 }
 
-template <int cols_per_block, int parallel_blocks>
+template <int cols_per_block, int parallel_blocks, bool use_logit_softcap>
 void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * Q = dst->src[0];
     switch (Q->ne[0]) {
         case  64: {
             constexpr int      D = 64;
             constexpr int nwarps = 8;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
             launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         } break;
         case 128: {
             constexpr int      D = 128;
             constexpr int nwarps = 8;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
             launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         } break;
         default: {
@@ -296,24 +309,45 @@ void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_ten
     const ggml_tensor * KQV = dst;
     const ggml_tensor * Q   = dst->src[0];
 
-    const int32_t precision = KQV->op_params[2];
+    const int32_t precision = KQV->op_params[3];
     GGML_ASSERT(precision == GGML_PREC_DEFAULT);
 
+    float logit_softcap;
+    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
     if (Q->ne[1] <= 16) {
         constexpr int cols_per_block = 16;
         constexpr int parallel_blocks = 4;
-        launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        }
         return;
     }
 
     if (Q->ne[1] <= 32) {
         constexpr int cols_per_block = 32;
         constexpr int parallel_blocks = 4;
-        launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        }
         return;
     }
 
     constexpr int cols_per_block = 32;
     constexpr int parallel_blocks = 1;
-    launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+    if (logit_softcap == 0.0f) {
+        constexpr bool use_logit_softcap = false;
+        launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+    } else {
+        constexpr bool use_logit_softcap = true;
+        launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+    }
 }
index f3e68dbfa6a6a6e0a5b7386c130d4355a4317903..827437ca0ad1ff6efdbaef3fba2d49bf7008ccc1 100644 (file)
@@ -4,7 +4,7 @@
 
 #define FATTN_KQ_STRIDE_TILE_F32 32
 
-template<int D, int ncols, int nwarps, int parallel_blocks> // D == head size
+template<int D, int ncols, int nwarps, int parallel_blocks, bool use_logit_softcap> // D == head size
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(nwarps*WARP_SIZE, 1)
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
@@ -20,6 +20,7 @@ static __global__ void flash_attn_tile_ext_f32(
         const float m0,
         const float m1,
         const uint32_t n_head_log2,
+        const float logit_softcap,
         const int ne00,
         const int ne01,
         const int ne02,
@@ -43,6 +44,12 @@ static __global__ void flash_attn_tile_ext_f32(
         const int ne1,
         const int ne2,
         const int ne3) {
+    // Skip unused kernel variants for faster compilation:
+    if (use_logit_softcap && !(D == 128 || D == 256)) {
+        NO_DEVICE_CODE;
+        return;
+    }
+
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
     const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
@@ -151,6 +158,10 @@ static __global__ void flash_attn_tile_ext_f32(
             for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
                 const int j_KQ = j_KQ_0 + threadIdx.y;
 
+                if (use_logit_softcap) {
+                    sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] = logit_softcap * tanhf(sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
+                }
+
                 sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += mask ? slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
 
                 kqmax_new[j_KQ_0/nwarps] = fmaxf(kqmax_new[j_KQ_0/nwarps], sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
@@ -267,20 +278,20 @@ static __global__ void flash_attn_tile_ext_f32(
     }
 }
 
-template <int cols_per_block, int parallel_blocks>
+template <int cols_per_block, int parallel_blocks, bool use_logit_softcap>
 void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * Q = dst->src[0];
     switch (Q->ne[0]) {
         case  64: {
             constexpr int      D = 64;
             constexpr int nwarps = 8;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks>;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
             launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         } break;
         case 128: {
             constexpr int      D = 128;
             constexpr int nwarps = 8;
-            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks>;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
             launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         } break;
         default: {
@@ -290,23 +301,45 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
 }
 
 void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * KQV = dst;
     const ggml_tensor * Q = dst->src[0];
 
+    float logit_softcap;
+    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
     if (Q->ne[1] <= 16) {
         constexpr int cols_per_block = 16;
         constexpr int parallel_blocks = 4;
-        launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        }
         return;
     }
 
     if (Q->ne[1] <= 32) {
         constexpr int cols_per_block = 32;
         constexpr int parallel_blocks = 4;
-        launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+        }
         return;
     }
 
     constexpr int cols_per_block = 32;
     constexpr int parallel_blocks = 1;
-    launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+    if (logit_softcap == 0.0f) {
+        constexpr bool use_logit_softcap = false;
+        launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+    } else {
+        constexpr bool use_logit_softcap = true;
+        launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+    }
 }
index 02a4ad072bbd6f70d8a4b47b7dbeaf674dcb790a..448a9a9054cca112cdfd0b0c4de1e1c6152a38af 100644 (file)
@@ -1,7 +1,7 @@
 #include "common.cuh"
 #include "fattn-common.cuh"
 
-template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V> // D == head size
+template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
@@ -17,6 +17,7 @@ static __global__ void flash_attn_vec_ext_f16(
         const float m0,
         const float m1,
         const uint32_t n_head_log2,
+        const float logit_softcap,
         const int ne00,
         const int ne01,
         const int ne02,
@@ -41,6 +42,12 @@ static __global__ void flash_attn_vec_ext_f16(
         const int ne2,
         const int ne3) {
 #ifdef FP16_AVAILABLE
+    // Skip unused kernel variants for faster compilation:
+    if (use_logit_softcap && !(D == 128 || D == 256)) {
+        NO_DEVICE_CODE;
+        return;
+    }
+
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
     constexpr vec_dot_KQ_f16_t vec_dot_KQ = get_vec_dot_KQ_f16<D>(type_K);
@@ -190,6 +197,11 @@ static __global__ void flash_attn_vec_ext_f16(
             for (int j = 0; j < ncols; ++j) {
                 half sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_h2[j], Q_i32[j], Q_ds[j]);
                 sum = warp_reduce_sum(sum);
+
+                if (use_logit_softcap) {
+                    sum = logit_softcap*tanhf(sum);
+                }
+
                 sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
 
                 if (ncols == 1) {
@@ -286,10 +298,10 @@ static __global__ void flash_attn_vec_ext_f16(
 #endif // FP16_AVAILABLE
 }
 
-template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V>
+template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
 void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     constexpr int nwarps = D/WARP_SIZE;
-    fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks, type_K, type_V>;
+    fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>;
     constexpr bool need_f16_K = D != 128;
     constexpr bool need_f16_V = D != 128 && D != 64;
     launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
@@ -297,48 +309,81 @@ void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx,
 
 template <int D, ggml_type type_K, ggml_type type_V>
 void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    ggml_tensor * KQV = dst;
-    ggml_tensor * Q   = dst->src[0];
-    ggml_tensor * K   = dst->src[1];
-    ggml_tensor * V   = dst->src[2];
+    const ggml_tensor * KQV = dst;
+    const ggml_tensor * Q   = dst->src[0];
+    const ggml_tensor * K   = dst->src[1];
+    const ggml_tensor * V   = dst->src[2];
 
-    const int32_t precision = KQV->op_params[2];
+    const int32_t precision = KQV->op_params[3];
     GGML_ASSERT(precision == GGML_PREC_DEFAULT);
 
     GGML_ASSERT(K->type == type_K);
     GGML_ASSERT(V->type == type_V);
 
+    float logit_softcap;
+    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
     if (Q->ne[1] == 1) {
         constexpr int cols_per_block  = 1;
         constexpr int parallel_blocks = 4;
-        ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        }
         return;
     }
 
     if (Q->ne[1] == 2) {
         constexpr int cols_per_block  = 2;
         constexpr int parallel_blocks = 4;
-        ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        }
         return;
     }
 
     if (Q->ne[1] <= 4) {
         constexpr int cols_per_block  = 4;
         constexpr int parallel_blocks = 4;
-        ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        }
         return;
     }
 
     if (Q->ne[1] <= 8) {
         constexpr int cols_per_block  = 8;
         constexpr int parallel_blocks = 4;
-        ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        }
         return;
     }
 
     constexpr int cols_per_block  = 8;
     constexpr int parallel_blocks = 1;
-    ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+    if (logit_softcap == 0.0f) {
+        constexpr bool use_logit_softcap = false;
+        ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+    } else {
+        constexpr bool use_logit_softcap = true;
+        ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+    }
 }
 
 #define DECL_FATTN_VEC_F16_CASE(D, type_K, type_V)                          \
index 11a5e355fd40b8d59c9bd9176112505392dfdba5..bf5125902534e67d2dd001dfef129f74fcdc465d 100644 (file)
@@ -1,7 +1,7 @@
 #include "common.cuh"
 #include "fattn-common.cuh"
 
-template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V> // D == head size
+template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
@@ -17,6 +17,7 @@ static __global__ void flash_attn_vec_ext_f32(
         const float m0,
         const float m1,
         const uint32_t n_head_log2,
+        const float logit_softcap,
         const int ne00,
         const int ne01,
         const int ne02,
@@ -40,6 +41,12 @@ static __global__ void flash_attn_vec_ext_f32(
         const int ne1,
         const int ne2,
         const int ne3) {
+    // Skip unused kernel variants for faster compilation:
+    if (use_logit_softcap && !(D == 128 || D == 256)) {
+        NO_DEVICE_CODE;
+        return;
+    }
+
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
     constexpr vec_dot_KQ_f32_t vec_dot_KQ = get_vec_dot_KQ_f32<D>(type_K);
@@ -180,6 +187,11 @@ static __global__ void flash_attn_vec_ext_f32(
             for (int j = 0; j < ncols; ++j) {
                 float sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_f2[j], Q_i32[j], Q_ds[j]);
                 sum = warp_reduce_sum(sum);
+
+                if (use_logit_softcap) {
+                    sum = logit_softcap*tanhf(sum);
+                }
+
                 sum += mask ? slope*__half2float(maskh[j*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
 
                 kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum);
@@ -267,10 +279,10 @@ static __global__ void flash_attn_vec_ext_f32(
     }
 }
 
-template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V>
+template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
 void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     constexpr int nwarps = D/WARP_SIZE;
-    fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks, type_K, type_V>;
+    fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>;
     constexpr bool need_f16_K = D != 128;
     constexpr bool need_f16_V = D != 128 && D != 64;
     launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
@@ -278,44 +290,78 @@ void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx,
 
 template <int D, ggml_type type_K, ggml_type type_V>
 void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    ggml_tensor * Q   = dst->src[0];
-    ggml_tensor * K   = dst->src[1];
-    ggml_tensor * V   = dst->src[2];
+    const ggml_tensor * KQV = dst;
+    const ggml_tensor * Q   = dst->src[0];
+    const ggml_tensor * K   = dst->src[1];
+    const ggml_tensor * V   = dst->src[2];
 
     GGML_ASSERT(K->type == type_K);
     GGML_ASSERT(V->type == type_V);
 
+    float logit_softcap;
+    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
     if (Q->ne[1] == 1) {
         constexpr int cols_per_block  = 1;
         constexpr int parallel_blocks = 4;
-        ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        }
         return;
     }
 
     if (Q->ne[1] == 2) {
         constexpr int cols_per_block  = 2;
         constexpr int parallel_blocks = 4;
-        ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        }
         return;
     }
 
     if (Q->ne[1] <= 4) {
         constexpr int cols_per_block  = 4;
         constexpr int parallel_blocks = 4;
-        ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        }
         return;
     }
 
     if (Q->ne[1] <= 8) {
         constexpr int cols_per_block  = 8;
         constexpr int parallel_blocks = 4;
-        ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+        }
         return;
     }
 
     constexpr int cols_per_block  = 8;
     constexpr int parallel_blocks = 1;
-    ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+    if (logit_softcap == 0.0f) {
+        constexpr bool use_logit_softcap = false;
+        ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+    } else {
+        constexpr bool use_logit_softcap = true;
+        ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+    }
 }
 
 #define DECL_FATTN_VEC_F32_CASE(D, type_K, type_V)                          \
index ae232224260972fb23b0b2d4f8f52d693df5faab..b10d19d9327c9d3af13d939bea8fda2e6884e322 100644 (file)
@@ -6,7 +6,7 @@
 #endif // FP16_MMA_AVAILABLE
 
 // D == head size, VKQ_stride == num VKQ rows calculated in parallel:
-template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t>
+template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t, bool use_logit_softcap>
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(nwarps*WARP_SIZE, 1)
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
@@ -22,6 +22,7 @@ static __global__ void flash_attn_ext_f16(
         const float m0,
         const float m1,
         const uint32_t n_head_log2,
+        const float logit_softcap,
         const int ne00,
         const int ne01,
         const int ne02,
@@ -46,6 +47,12 @@ static __global__ void flash_attn_ext_f16(
         const int ne2,
         const int ne3) {
 #ifdef FP16_MMA_AVAILABLE
+    // Skip unused kernel variants for faster compilation:
+    if (use_logit_softcap && !(D == 128 || D == 256)) {
+        NO_DEVICE_CODE;
+        return;
+    }
+
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
     const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
@@ -85,6 +92,8 @@ static __global__ void flash_attn_ext_f16(
     const half  slopeh = __float2half(slopef);
     const half2 slope2 = make_half2(slopef, slopef);
 
+    const half2 logit_softcap_2 = make_half2(logit_softcap, logit_softcap);
+
     frag_b Q_b[D/16][ncols/frag_n];
 
     // A single buffer for temporarily holding tiles of KQ and VKQ parts:
@@ -194,6 +203,10 @@ static __global__ void flash_attn_ext_f16(
                     const int k = k0 + threadIdx.x;
 
                     KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k];
+
+                    if (use_logit_softcap) {
+                        KQ_f_tmp[k0/WARP_SIZE] = logit_softcap*tanhf(KQ_f_tmp[k0/WARP_SIZE]);
+                    }
                 }
 
                 float KQ_max_new = KQ_max_f[j0/nwarps];
@@ -237,6 +250,15 @@ static __global__ void flash_attn_ext_f16(
                     const int k = k0 + threadIdx.x;
 
                     KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k];
+
+                    if (use_logit_softcap) {
+                        // There is no dedicated tangens hyperbolicus function for half2.
+                        KQ2_tmp[k0/WARP_SIZE] = h2exp(KQ2_tmp[k0/WARP_SIZE]*make_half2(2.0f, 2.0f));
+                        KQ2_tmp[k0/WARP_SIZE] = (KQ2_tmp[k0/WARP_SIZE] - make_half2(1.0f, 1.0f))
+                                               /(KQ2_tmp[k0/WARP_SIZE] + make_half2(1.0f, 1.0f));
+
+                        KQ2_tmp[k0/WARP_SIZE] *= logit_softcap_2;
+                    }
                 }
 
                 half2 KQ_max_new = KQ_max_h2[j0/nwarps];
@@ -427,7 +449,8 @@ static_assert(get_VKQ_stride( 80, 4, 16) ==  16, "Test failed.");
 
 template <int D, int cols_per_block, typename KQ_acc_t>
 void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * Q = dst->src[0];
+    const ggml_tensor * KQV = dst;
+    const ggml_tensor * Q   = dst->src[0];
 
     constexpr int nwarps = 4;
 
@@ -435,20 +458,50 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
     const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
     const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
 
+    float logit_softcap;
+    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
     if (4*blocks_num_pb1 < 2*nsm) {
         constexpr int parallel_blocks = 4;
-        fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
+        fattn_kernel_t fattn_kernel;
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            fattn_kernel = flash_attn_ext_f16<
+                D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
+        } else {
+            constexpr bool use_logit_softcap = true;
+            fattn_kernel = flash_attn_ext_f16<
+                D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
+        }
         launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         return;
     }
     if (2*blocks_num_pb1 < 2*nsm) {
         constexpr int parallel_blocks = 2;
-        fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
+        fattn_kernel_t fattn_kernel;
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            fattn_kernel = flash_attn_ext_f16<
+                D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
+        } else {
+            constexpr bool use_logit_softcap = true;
+            fattn_kernel = flash_attn_ext_f16<
+                D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
+        }
         launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         return;
     }
     constexpr int parallel_blocks = 1;
-    fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
+    fattn_kernel_t fattn_kernel;
+    if (logit_softcap == 0.0f) {
+        constexpr bool use_logit_softcap = false;
+        fattn_kernel = flash_attn_ext_f16<
+            D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
+    } else {
+        constexpr bool use_logit_softcap = true;
+        fattn_kernel = flash_attn_ext_f16<
+            D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
+    }
     launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
 }
 
index 29f608b0ff98d0baa8c10c4f8cb7b9f85b131ff1..f87f33b3e574b65af5a8ceaff958fc98c7c69c57 100644 (file)
@@ -13,7 +13,7 @@ static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, g
     const ggml_tensor * KQV = dst;
     const ggml_tensor * Q   = dst->src[0];
 
-    const int32_t precision = KQV->op_params[2];
+    const int32_t precision = KQV->op_params[3];
 
     if (precision != GGML_PREC_DEFAULT) {
         if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
@@ -301,7 +301,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
 
     ggml_cuda_set_device(ctx.device);
     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
-    const int32_t precision = KQV->op_params[2];
+    const int32_t precision = KQV->op_params[3];
 
     // On AMD the tile kernels perform poorly, use the vec kernel instead:
     if (cc >= CC_OFFSET_AMD) {
index 7dcb69ee770d3374244eb43854af5ae279dafd0d..f6c36267c51fb38ea03ca5d56868c9e6535aaf0c 100644 (file)
@@ -817,6 +817,15 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
             if (op->src[0]->ne[0] == 256) {
                 return false;
             }
+            {
+                float logit_softcap;
+
+                memcpy(&logit_softcap, ((const float *) op->op_params) + 2, sizeof(logit_softcap));
+
+                if (logit_softcap != 0.0f) {
+                    return false;
+                }
+            }
             return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
index a9f03855725d078551e30ea1f3905c2be4840f1f..dea4c6b802769c953a753bf60bacbf241ba4a682 100644 (file)
@@ -7250,7 +7250,8 @@ struct ggml_tensor * ggml_flash_attn_ext(
         struct ggml_tensor  * v,
         struct ggml_tensor  * mask,
         float                 scale,
-        float                 max_bias) {
+        float                 max_bias,
+        float                 logit_softcap) {
     GGML_ASSERT(ggml_can_mul_mat(k, q));
     // TODO: check if vT can be multiplied by (k*qT)
 
@@ -7277,7 +7278,7 @@ struct ggml_tensor * ggml_flash_attn_ext(
     int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
 
-    float params[] = { scale, max_bias };
+    float params[] = { scale, max_bias, logit_softcap };
     ggml_set_op_params(result, params, sizeof(params));
 
     result->op   = GGML_OP_FLASH_ATTN_EXT;
@@ -7297,7 +7298,7 @@ void ggml_flash_attn_ext_set_prec(
 
     const int32_t prec_i32 = (int32_t) prec;
 
-    ggml_set_op_params_i32(a, 2, prec_i32); // scale is on first pos, max_bias on second
+    ggml_set_op_params_i32(a, 3, prec_i32); // scale is on first pos, max_bias on second
 }
 
 // ggml_flash_attn_back
@@ -15745,11 +15746,17 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
-    float scale    = 1.0f;
-    float max_bias = 0.0f;
+    float scale         = 1.0f;
+    float max_bias      = 0.0f;
+    float logit_softcap = 0.0f;
 
-    memcpy(&scale,    (float *) dst->op_params + 0, sizeof(float));
-    memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
+    memcpy(&scale,         (float *) dst->op_params + 0, sizeof(float));
+    memcpy(&max_bias,      (float *) dst->op_params + 1, sizeof(float));
+    memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float));
+
+    if (logit_softcap != 0) {
+        scale /= logit_softcap;
+    }
 
     const uint32_t n_head      = neq2;
     const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
@@ -15813,7 +15820,13 @@ static void ggml_compute_forward_flash_attn_ext_f16(
             const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
             kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1);
 
-            s = s*scale + mv; // scale KQ value and apply mask
+            s = s*scale; // scale KQ value
+
+            if (logit_softcap != 0.0f) {
+                s = logit_softcap*tanhf(s);
+            }
+
+            s += mv; // apply mask
 
             const float Mold = M;
 
@@ -15822,7 +15835,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
 
             const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
 
-            if (v->type== GGML_TYPE_F16) {
+            if (v->type == GGML_TYPE_F16) {
                 if (s > M) {
                     // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
                     M = s;
@@ -15889,7 +15902,7 @@ static void ggml_compute_forward_flash_attn_ext(
         const struct ggml_tensor * v,
         const struct ggml_tensor * mask,
         struct ggml_tensor * dst) {
-    switch (dst->op_params[2]) {
+    switch (dst->op_params[3]) {
         case GGML_PREC_DEFAULT:
         case GGML_PREC_F32:
             {
index 95037e77c180acf6acbc751242540b5f7ccd8a66..fce5419795d4e350607a57c05cd0bb0193f55384 100644 (file)
@@ -1704,19 +1704,20 @@ struct test_flash_attn_ext : public test_case {
     const bool mask; // use mask
 
     const float max_bias; // ALiBi
+    const float logit_softcap; // Gemma 2
 
     const ggml_type type_KV;
 
     std::string vars() override {
-        return VARS_TO_STR7(hs, nh, kv, nb, mask, max_bias, type_KV);
+        return VARS_TO_STR8(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV);
     }
 
     double max_nmse_err() override {
         return 5e-4;
     }
 
-    test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8, bool mask = true, float max_bias = 0.0f, ggml_type type_KV = GGML_TYPE_F16)
-        : hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias), type_KV(type_KV) {}
+    test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8, bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_type type_KV = GGML_TYPE_F16)
+        : hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), type_KV(type_KV) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         const int64_t hs_padded = GGML_PAD(hs, ggml_blck_size(type_KV));
@@ -1725,7 +1726,7 @@ struct test_flash_attn_ext : public test_case {
         ggml_tensor * k = ggml_new_tensor_4d(ctx, type_KV,       hs_padded, kv, nh, 1);
         ggml_tensor * v = ggml_new_tensor_4d(ctx, type_KV,       hs_padded, kv, nh, 1);
         ggml_tensor * m = mask ? ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, 1) : nullptr;
-        ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hs), max_bias);
+        ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hs), max_bias, logit_softcap);
         return out;
     }
 };
@@ -2512,11 +2513,14 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         for (bool mask : { true, false } ) {
             for (float max_bias : { 0.0f, 8.0f }) {
                 if (!mask && max_bias > 0.0f) continue;
-                for (int nh : { 32, }) {
-                    for (int kv : { 512, 1024, }) {
-                        for (int nb : { 1, 2, 4, 8, }) {
-                            for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
-                                test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, type_KV));
+                for (float logit_softcap : {0.0f, 10.0f}) {
+                    if (hs != 128 && logit_softcap != 0.0f) continue;
+                    for (int nh : { 32, }) {
+                        for (int kv : { 512, 1024, }) {
+                            for (int nb : { 1, 2, 4, 8, }) {
+                                for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
+                                    test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV));
+                                }
                             }
                         }
                     }