struct ggml_tensor * v,
struct ggml_tensor * mask,
float scale,
- float max_bias);
+ float max_bias,
+ float logit_softcap);
GGML_API void ggml_flash_attn_ext_set_prec(
struct ggml_tensor * a,
const float m0,
const float m1,
const uint32_t n_head_log2,
+ const float logit_softcap,
const int ne00,
const int ne01,
const int ne02,
const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
const int shmem = 0;
- float scale = 1.0f;
- float max_bias = 0.0f;
+ float scale = 1.0f;
+ float max_bias = 0.0f;
+ float logit_softcap = 0.0f;
- memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float));
- memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
+ memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float));
+ memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
+ memcpy(&logit_softcap, (float *) KQV->op_params + 2, sizeof(float));
+
+ if (logit_softcap != 0.0f) {
+ scale /= logit_softcap;
+ }
const uint32_t n_head = Q->ne[2];
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
V_data,
mask ? ((const char *) mask->data) : nullptr,
(parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
- scale, max_bias, m0, m1, n_head_log2,
+ scale, max_bias, m0, m1, n_head_log2, logit_softcap,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
#define FATTN_KQ_STRIDE_TILE_F16 64
-template<int D, int ncols, int nwarps, int parallel_blocks> // D == head size
+template<int D, int ncols, int nwarps, int parallel_blocks, bool use_logit_softcap> // D == head size
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(nwarps*WARP_SIZE, 1)
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
const float m0,
const float m1,
const uint32_t n_head_log2,
+ const float logit_softcap,
const int ne00,
const int ne01,
const int ne02,
const int ne2,
const int ne3) {
#ifdef FP16_AVAILABLE
+ // Skip unused kernel variants for faster compilation:
+ if (use_logit_softcap && !(D == 128 || D == 256)) {
+ NO_DEVICE_CODE;
+ return;
+ }
+
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
const int j_KQ = j_KQ_0 + threadIdx.y;
- half sum = __low2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]) + __high2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
+ half sum;
+ if (use_logit_softcap) {
+ const float2 tmp = __half22float2(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
+ sum = logit_softcap * tanhf(tmp.x + tmp.y);
+ } else {
+ sum = __low2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]) + __high2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
+ }
sum += mask ? slopeh*maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
kqmax_new[j_KQ_0/nwarps] = ggml_cuda_hmax(kqmax_new[j_KQ_0/nwarps], sum);
#endif // FP16_AVAILABLE
}
-template <int cols_per_block, int parallel_blocks>
+template <int cols_per_block, int parallel_blocks, bool use_logit_softcap>
void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
switch (Q->ne[0]) {
case 64: {
constexpr int D = 64;
constexpr int nwarps = 8;
- fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
+ fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
} break;
case 128: {
constexpr int D = 128;
constexpr int nwarps = 8;
- fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
+ fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
} break;
default: {
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];
- const int32_t precision = KQV->op_params[2];
+ const int32_t precision = KQV->op_params[3];
GGML_ASSERT(precision == GGML_PREC_DEFAULT);
+ float logit_softcap;
+ memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
if (Q->ne[1] <= 16) {
constexpr int cols_per_block = 16;
constexpr int parallel_blocks = 4;
- launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+ if (logit_softcap == 0.0f) {
+ constexpr bool use_logit_softcap = false;
+ launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+ } else {
+ constexpr bool use_logit_softcap = true;
+ launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+ }
return;
}
if (Q->ne[1] <= 32) {
constexpr int cols_per_block = 32;
constexpr int parallel_blocks = 4;
- launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+ if (logit_softcap == 0.0f) {
+ constexpr bool use_logit_softcap = false;
+ launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+ } else {
+ constexpr bool use_logit_softcap = true;
+ launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+ }
return;
}
constexpr int cols_per_block = 32;
constexpr int parallel_blocks = 1;
- launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+ if (logit_softcap == 0.0f) {
+ constexpr bool use_logit_softcap = false;
+ launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+ } else {
+ constexpr bool use_logit_softcap = true;
+ launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+ }
}
#define FATTN_KQ_STRIDE_TILE_F32 32
-template<int D, int ncols, int nwarps, int parallel_blocks> // D == head size
+template<int D, int ncols, int nwarps, int parallel_blocks, bool use_logit_softcap> // D == head size
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(nwarps*WARP_SIZE, 1)
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
const float m0,
const float m1,
const uint32_t n_head_log2,
+ const float logit_softcap,
const int ne00,
const int ne01,
const int ne02,
const int ne1,
const int ne2,
const int ne3) {
+ // Skip unused kernel variants for faster compilation:
+ if (use_logit_softcap && !(D == 128 || D == 256)) {
+ NO_DEVICE_CODE;
+ return;
+ }
+
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
const int j_KQ = j_KQ_0 + threadIdx.y;
+ if (use_logit_softcap) {
+ sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] = logit_softcap * tanhf(sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
+ }
+
sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += mask ? slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
kqmax_new[j_KQ_0/nwarps] = fmaxf(kqmax_new[j_KQ_0/nwarps], sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
}
}
-template <int cols_per_block, int parallel_blocks>
+template <int cols_per_block, int parallel_blocks, bool use_logit_softcap>
void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
switch (Q->ne[0]) {
case 64: {
constexpr int D = 64;
constexpr int nwarps = 8;
- fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks>;
+ fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
} break;
case 128: {
constexpr int D = 128;
constexpr int nwarps = 8;
- fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks>;
+ fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
} break;
default: {
}
void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];
+ float logit_softcap;
+ memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
if (Q->ne[1] <= 16) {
constexpr int cols_per_block = 16;
constexpr int parallel_blocks = 4;
- launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+ if (logit_softcap == 0.0f) {
+ constexpr bool use_logit_softcap = false;
+ launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+ } else {
+ constexpr bool use_logit_softcap = true;
+ launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+ }
return;
}
if (Q->ne[1] <= 32) {
constexpr int cols_per_block = 32;
constexpr int parallel_blocks = 4;
- launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+ if (logit_softcap == 0.0f) {
+ constexpr bool use_logit_softcap = false;
+ launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+ } else {
+ constexpr bool use_logit_softcap = true;
+ launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+ }
return;
}
constexpr int cols_per_block = 32;
constexpr int parallel_blocks = 1;
- launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
+ if (logit_softcap == 0.0f) {
+ constexpr bool use_logit_softcap = false;
+ launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+ } else {
+ constexpr bool use_logit_softcap = true;
+ launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
+ }
}
#include "common.cuh"
#include "fattn-common.cuh"
-template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V> // D == head size
+template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(D, 1)
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
const float m0,
const float m1,
const uint32_t n_head_log2,
+ const float logit_softcap,
const int ne00,
const int ne01,
const int ne02,
const int ne2,
const int ne3) {
#ifdef FP16_AVAILABLE
+ // Skip unused kernel variants for faster compilation:
+ if (use_logit_softcap && !(D == 128 || D == 256)) {
+ NO_DEVICE_CODE;
+ return;
+ }
+
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
constexpr vec_dot_KQ_f16_t vec_dot_KQ = get_vec_dot_KQ_f16<D>(type_K);
for (int j = 0; j < ncols; ++j) {
half sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_h2[j], Q_i32[j], Q_ds[j]);
sum = warp_reduce_sum(sum);
+
+ if (use_logit_softcap) {
+ sum = logit_softcap*tanhf(sum);
+ }
+
sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
if (ncols == 1) {
#endif // FP16_AVAILABLE
}
-template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V>
+template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
constexpr int nwarps = D/WARP_SIZE;
- fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks, type_K, type_V>;
+ fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>;
constexpr bool need_f16_K = D != 128;
constexpr bool need_f16_V = D != 128 && D != 64;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
template <int D, ggml_type type_K, ggml_type type_V>
void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
- ggml_tensor * KQV = dst;
- ggml_tensor * Q = dst->src[0];
- ggml_tensor * K = dst->src[1];
- ggml_tensor * V = dst->src[2];
+ const ggml_tensor * KQV = dst;
+ const ggml_tensor * Q = dst->src[0];
+ const ggml_tensor * K = dst->src[1];
+ const ggml_tensor * V = dst->src[2];
- const int32_t precision = KQV->op_params[2];
+ const int32_t precision = KQV->op_params[3];
GGML_ASSERT(precision == GGML_PREC_DEFAULT);
GGML_ASSERT(K->type == type_K);
GGML_ASSERT(V->type == type_V);
+ float logit_softcap;
+ memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
if (Q->ne[1] == 1) {
constexpr int cols_per_block = 1;
constexpr int parallel_blocks = 4;
- ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+ if (logit_softcap == 0.0f) {
+ constexpr bool use_logit_softcap = false;
+ ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+ } else {
+ constexpr bool use_logit_softcap = true;
+ ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+ }
return;
}
if (Q->ne[1] == 2) {
constexpr int cols_per_block = 2;
constexpr int parallel_blocks = 4;
- ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+ if (logit_softcap == 0.0f) {
+ constexpr bool use_logit_softcap = false;
+ ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+ } else {
+ constexpr bool use_logit_softcap = true;
+ ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+ }
return;
}
if (Q->ne[1] <= 4) {
constexpr int cols_per_block = 4;
constexpr int parallel_blocks = 4;
- ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+ if (logit_softcap == 0.0f) {
+ constexpr bool use_logit_softcap = false;
+ ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+ } else {
+ constexpr bool use_logit_softcap = true;
+ ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+ }
return;
}
if (Q->ne[1] <= 8) {
constexpr int cols_per_block = 8;
constexpr int parallel_blocks = 4;
- ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+ if (logit_softcap == 0.0f) {
+ constexpr bool use_logit_softcap = false;
+ ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+ } else {
+ constexpr bool use_logit_softcap = true;
+ ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+ }
return;
}
constexpr int cols_per_block = 8;
constexpr int parallel_blocks = 1;
- ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+ if (logit_softcap == 0.0f) {
+ constexpr bool use_logit_softcap = false;
+ ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+ } else {
+ constexpr bool use_logit_softcap = true;
+ ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+ }
}
#define DECL_FATTN_VEC_F16_CASE(D, type_K, type_V) \
#include "common.cuh"
#include "fattn-common.cuh"
-template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V> // D == head size
+template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(D, 1)
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
const float m0,
const float m1,
const uint32_t n_head_log2,
+ const float logit_softcap,
const int ne00,
const int ne01,
const int ne02,
const int ne1,
const int ne2,
const int ne3) {
+ // Skip unused kernel variants for faster compilation:
+ if (use_logit_softcap && !(D == 128 || D == 256)) {
+ NO_DEVICE_CODE;
+ return;
+ }
+
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
constexpr vec_dot_KQ_f32_t vec_dot_KQ = get_vec_dot_KQ_f32<D>(type_K);
for (int j = 0; j < ncols; ++j) {
float sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_f2[j], Q_i32[j], Q_ds[j]);
sum = warp_reduce_sum(sum);
+
+ if (use_logit_softcap) {
+ sum = logit_softcap*tanhf(sum);
+ }
+
sum += mask ? slope*__half2float(maskh[j*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum);
}
}
-template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V>
+template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
constexpr int nwarps = D/WARP_SIZE;
- fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks, type_K, type_V>;
+ fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>;
constexpr bool need_f16_K = D != 128;
constexpr bool need_f16_V = D != 128 && D != 64;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
template <int D, ggml_type type_K, ggml_type type_V>
void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
- ggml_tensor * Q = dst->src[0];
- ggml_tensor * K = dst->src[1];
- ggml_tensor * V = dst->src[2];
+ const ggml_tensor * KQV = dst;
+ const ggml_tensor * Q = dst->src[0];
+ const ggml_tensor * K = dst->src[1];
+ const ggml_tensor * V = dst->src[2];
GGML_ASSERT(K->type == type_K);
GGML_ASSERT(V->type == type_V);
+ float logit_softcap;
+ memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
if (Q->ne[1] == 1) {
constexpr int cols_per_block = 1;
constexpr int parallel_blocks = 4;
- ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+ if (logit_softcap == 0.0f) {
+ constexpr bool use_logit_softcap = false;
+ ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+ } else {
+ constexpr bool use_logit_softcap = true;
+ ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+ }
return;
}
if (Q->ne[1] == 2) {
constexpr int cols_per_block = 2;
constexpr int parallel_blocks = 4;
- ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+ if (logit_softcap == 0.0f) {
+ constexpr bool use_logit_softcap = false;
+ ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+ } else {
+ constexpr bool use_logit_softcap = true;
+ ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+ }
return;
}
if (Q->ne[1] <= 4) {
constexpr int cols_per_block = 4;
constexpr int parallel_blocks = 4;
- ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+ if (logit_softcap == 0.0f) {
+ constexpr bool use_logit_softcap = false;
+ ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+ } else {
+ constexpr bool use_logit_softcap = true;
+ ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+ }
return;
}
if (Q->ne[1] <= 8) {
constexpr int cols_per_block = 8;
constexpr int parallel_blocks = 4;
- ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+ if (logit_softcap == 0.0f) {
+ constexpr bool use_logit_softcap = false;
+ ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+ } else {
+ constexpr bool use_logit_softcap = true;
+ ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+ }
return;
}
constexpr int cols_per_block = 8;
constexpr int parallel_blocks = 1;
- ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V>(ctx, dst);
+ if (logit_softcap == 0.0f) {
+ constexpr bool use_logit_softcap = false;
+ ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+ } else {
+ constexpr bool use_logit_softcap = true;
+ ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
+ }
}
#define DECL_FATTN_VEC_F32_CASE(D, type_K, type_V) \
#endif // FP16_MMA_AVAILABLE
// D == head size, VKQ_stride == num VKQ rows calculated in parallel:
-template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t>
+template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t, bool use_logit_softcap>
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(nwarps*WARP_SIZE, 1)
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
const float m0,
const float m1,
const uint32_t n_head_log2,
+ const float logit_softcap,
const int ne00,
const int ne01,
const int ne02,
const int ne2,
const int ne3) {
#ifdef FP16_MMA_AVAILABLE
+ // Skip unused kernel variants for faster compilation:
+ if (use_logit_softcap && !(D == 128 || D == 256)) {
+ NO_DEVICE_CODE;
+ return;
+ }
+
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
const half slopeh = __float2half(slopef);
const half2 slope2 = make_half2(slopef, slopef);
+ const half2 logit_softcap_2 = make_half2(logit_softcap, logit_softcap);
+
frag_b Q_b[D/16][ncols/frag_n];
// A single buffer for temporarily holding tiles of KQ and VKQ parts:
const int k = k0 + threadIdx.x;
KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k];
+
+ if (use_logit_softcap) {
+ KQ_f_tmp[k0/WARP_SIZE] = logit_softcap*tanhf(KQ_f_tmp[k0/WARP_SIZE]);
+ }
}
float KQ_max_new = KQ_max_f[j0/nwarps];
const int k = k0 + threadIdx.x;
KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k];
+
+ if (use_logit_softcap) {
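+ // half2 has no dedicated hyperbolic tangent (tanh) function, so compute tanh(x) as (exp(2x) - 1)/(exp(2x) + 1).
+ // NOTE(review): exp(2x) overflows FP16 for x > ~5.5, which would give inf/inf == NaN here - confirm the scaled KQ values stay below that.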
+ KQ2_tmp[k0/WARP_SIZE] = h2exp(KQ2_tmp[k0/WARP_SIZE]*make_half2(2.0f, 2.0f));
+ KQ2_tmp[k0/WARP_SIZE] = (KQ2_tmp[k0/WARP_SIZE] - make_half2(1.0f, 1.0f))
+ /(KQ2_tmp[k0/WARP_SIZE] + make_half2(1.0f, 1.0f));
+
+ KQ2_tmp[k0/WARP_SIZE] *= logit_softcap_2;
+ }
}
half2 KQ_max_new = KQ_max_h2[j0/nwarps];
template <int D, int cols_per_block, typename KQ_acc_t>
void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
- const ggml_tensor * Q = dst->src[0];
+ const ggml_tensor * KQV = dst;
+ const ggml_tensor * Q = dst->src[0];
constexpr int nwarps = 4;
const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
+ float logit_softcap;
+ memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
if (4*blocks_num_pb1 < 2*nsm) {
constexpr int parallel_blocks = 4;
- fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
+ fattn_kernel_t fattn_kernel;
+ if (logit_softcap == 0.0f) {
+ constexpr bool use_logit_softcap = false;
+ fattn_kernel = flash_attn_ext_f16<
+ D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
+ } else {
+ constexpr bool use_logit_softcap = true;
+ fattn_kernel = flash_attn_ext_f16<
+ D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
+ }
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
return;
}
if (2*blocks_num_pb1 < 2*nsm) {
constexpr int parallel_blocks = 2;
- fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
+ fattn_kernel_t fattn_kernel;
+ if (logit_softcap == 0.0f) {
+ constexpr bool use_logit_softcap = false;
+ fattn_kernel = flash_attn_ext_f16<
+ D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
+ } else {
+ constexpr bool use_logit_softcap = true;
+ fattn_kernel = flash_attn_ext_f16<
+ D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
+ }
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
return;
}
constexpr int parallel_blocks = 1;
- fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
+ fattn_kernel_t fattn_kernel;
+ if (logit_softcap == 0.0f) {
+ constexpr bool use_logit_softcap = false;
+ fattn_kernel = flash_attn_ext_f16<
+ D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
+ } else {
+ constexpr bool use_logit_softcap = true;
+ fattn_kernel = flash_attn_ext_f16<
+ D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
+ }
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
}
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];
- const int32_t precision = KQV->op_params[2];
+ const int32_t precision = KQV->op_params[3];
if (precision != GGML_PREC_DEFAULT) {
if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
ggml_cuda_set_device(ctx.device);
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
- const int32_t precision = KQV->op_params[2];
+ const int32_t precision = KQV->op_params[3];
// On AMD the tile kernels perform poorly, use the vec kernel instead:
if (cc >= CC_OFFSET_AMD) {
if (op->src[0]->ne[0] == 256) {
return false;
}
+ {
+ float logit_softcap;
+
+ memcpy(&logit_softcap, ((const float *) op->op_params) + 2, sizeof(logit_softcap));
+
+ if (logit_softcap != 0.0f) {
+ return false;
+ }
+ }
return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID:
struct ggml_tensor * v,
struct ggml_tensor * mask,
float scale,
- float max_bias) {
+ float max_bias,
+ float logit_softcap) {
GGML_ASSERT(ggml_can_mul_mat(k, q));
// TODO: check if vT can be multiplied by (k*qT)
int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
- float params[] = { scale, max_bias };
+ float params[] = { scale, max_bias, logit_softcap };
ggml_set_op_params(result, params, sizeof(params));
result->op = GGML_OP_FLASH_ATTN_EXT;
const int32_t prec_i32 = (int32_t) prec;
- ggml_set_op_params_i32(a, 2, prec_i32); // scale is on first pos, max_bias on second
+ ggml_set_op_params_i32(a, 3, prec_i32); // scale is on first pos, max_bias on second, logit_softcap on third
}
// ggml_flash_attn_back
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
- float scale = 1.0f;
- float max_bias = 0.0f;
+ float scale = 1.0f;
+ float max_bias = 0.0f;
+ float logit_softcap = 0.0f;
- memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
- memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
+ memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
+ memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
+ memcpy(&logit_softcap, (float *) dst->op_params + 2, sizeof(float));
+
+ if (logit_softcap != 0) {
+ scale /= logit_softcap;
+ }
const uint32_t n_head = neq2;
const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1);
- s = s*scale + mv; // scale KQ value and apply mask
+ s = s*scale; // scale KQ value
+
+ if (logit_softcap != 0.0f) {
+ s = logit_softcap*tanhf(s);
+ }
+
+ s += mv; // apply mask
const float Mold = M;
const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
- if (v->type== GGML_TYPE_F16) {
+ if (v->type == GGML_TYPE_F16) {
if (s > M) {
// s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
M = s;
const struct ggml_tensor * v,
const struct ggml_tensor * mask,
struct ggml_tensor * dst) {
- switch (dst->op_params[2]) {
+ switch (dst->op_params[3]) {
case GGML_PREC_DEFAULT:
case GGML_PREC_F32:
{
const bool mask; // use mask
const float max_bias; // ALiBi
+ const float logit_softcap; // Gemma 2
const ggml_type type_KV;
std::string vars() override {
- return VARS_TO_STR7(hs, nh, kv, nb, mask, max_bias, type_KV);
+ return VARS_TO_STR8(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV);
}
double max_nmse_err() override {
return 5e-4;
}
- test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8, bool mask = true, float max_bias = 0.0f, ggml_type type_KV = GGML_TYPE_F16)
- : hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias), type_KV(type_KV) {}
+ test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8, bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_type type_KV = GGML_TYPE_F16)
+ : hs(hs), nh(nh), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), type_KV(type_KV) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
const int64_t hs_padded = GGML_PAD(hs, ggml_blck_size(type_KV));
ggml_tensor * k = ggml_new_tensor_4d(ctx, type_KV, hs_padded, kv, nh, 1);
ggml_tensor * v = ggml_new_tensor_4d(ctx, type_KV, hs_padded, kv, nh, 1);
ggml_tensor * m = mask ? ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, 1) : nullptr;
- ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hs), max_bias);
+ ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hs), max_bias, logit_softcap);
return out;
}
};
for (bool mask : { true, false } ) {
for (float max_bias : { 0.0f, 8.0f }) {
if (!mask && max_bias > 0.0f) continue;
- for (int nh : { 32, }) {
- for (int kv : { 512, 1024, }) {
- for (int nb : { 1, 2, 4, 8, }) {
- for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
- test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, type_KV));
+ for (float logit_softcap : {0.0f, 10.0f}) {
+ if (hs != 128 && logit_softcap != 0.0f) continue;
+ for (int nh : { 32, }) {
+ for (int kv : { 512, 1024, }) {
+ for (int nb : { 1, 2, 4, 8, }) {
+ for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
+ test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV));
+ }
}
}
}