#include "fattn-common.cuh"
template<int D, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#ifndef GGML_USE_HIP
__launch_bounds__(D, 1)
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#endif // GGML_USE_HIP
static __global__ void flash_attn_vec_ext_f16(
const char * __restrict__ Q,
const char * __restrict__ K,
NO_DEVICE_CODE;
return;
}
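+    // The multi-column (ncols > 1) path is only used on AMD and MUSA; the host code launches NVIDIA with a single column, so compile it out here.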
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+ if (ncols > 1) {
+ NO_DEVICE_CODE;
+ return;
+ }
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
kqsum_shared[j][threadIdx.x] = 0.0f;
}
}
+
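+    // Shared buffer for the slope-scaled mask values of the current KV slice; zero-initialized so the no-mask case adds nothing.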
+ __shared__ half maskh_shared[ncols*D];
+#pragma unroll
+ for (int j = 0; j < ncols; ++j) {
+ maskh_shared[j*D + tid] = 0.0f;
+ }
+
__syncthreads();
// Convert Q to half2 (f16 K) or q8_1 (quantized K) and store in registers:
for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
// Calculate KQ tile and keep track of new maximum KQ values:
+ if (mask) {
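+            // Copy the slope-scaled mask values for this KV slice from global to shared memory.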
+#pragma unroll
+ for (int j = 0; j < ncols; ++j) {
+ maskh_shared[j*D + tid] = slopeh*maskh[j*ne11 + k_VKQ_0 + tid];
+ }
+
+ __syncthreads();
+
+ // When using multiple parallel sequences in llama.cpp, some KV slices can be fully masked out.
+ // In such cases, skip the KV slice.
+        // On AMD this check is skipped: __all_sync with a 32-bit mask assumes a warp size of 32, but AMD warps have 64 threads.
+#ifndef GGML_USE_HIP
+ bool skip = true;
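+            // Each lane ANDs isinf() over the half2 mask pairs it reads; __all_sync then tells whether the whole warp saw only -inf (fully masked).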
+#pragma unroll
+ for (int j = 0; j < ncols; ++j) {
+#pragma unroll
+ for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+ const int i = i0 + threadIdx.x;
+
+ const float2 tmp = __half22float2(((const half2 *) maskh_shared)[j*(D/2) + i]);
+ skip = skip && isinf(tmp.x) && isinf(tmp.y);
+ }
+ }
+            if (__all_sync(0xFFFFFFFF, skip)) {
+                // Sync before continuing: the next iteration overwrites maskh_shared, which other warps may still be reading for their skip check.
+                __syncthreads();
+                continue;
+            }
+#endif // GGML_USE_HIP
+ }
+
// For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression,
// see https://github.com/ggerganov/llama.cpp/pull/7061 .
// Therefore this variable is defined twice but only used once (so that the compiler can optimize out the unused variable).
sum = logit_softcap*tanhf(sum);
}
- sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
+ sum += maskh_shared[j*D + i_KQ];
if (ncols == 1) {
kqmax_new = ggml_cuda_hmax(kqmax_new, sum);
float logit_softcap;
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
- if (Q->ne[1] == 1) {
+ const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+
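+    // The multi-column vec kernel is compiled out on NVIDIA (see the ncols > 1 check in the kernel), so always use a single column there.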
+ if (Q->ne[1] == 1 || GGML_CUDA_CC_IS_NVIDIA(cc)) {
constexpr int cols_per_block = 1;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
#include "fattn-common.cuh"
template<int D, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#ifndef GGML_USE_HIP
__launch_bounds__(D, 1)
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#endif // GGML_USE_HIP
static __global__ void flash_attn_vec_ext_f32(
const char * __restrict__ Q,
const char * __restrict__ K,
NO_DEVICE_CODE;
return;
}
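+    // The multi-column (ncols > 1) path is only used on AMD and MUSA; the host code launches NVIDIA with a single column, so compile it out here.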
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
+ if (ncols > 1) {
+ NO_DEVICE_CODE;
+ return;
+ }
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
kqsum_shared[j][threadIdx.x] = 0.0f;
}
}
+
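+    // Shared buffer for the slope-scaled mask values of the current KV slice, converted to f32; zero-initialized so the no-mask case adds nothing.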
+ __shared__ float maskf_shared[ncols*D];
+#pragma unroll
+ for (int j = 0; j < ncols; ++j) {
+ maskf_shared[j*D + tid] = 0.0f;
+ }
+
__syncthreads();
// Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers:
for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
// Calculate KQ tile and keep track of new maximum KQ values:
+ if (mask) {
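+            // Convert the mask values for this KV slice to f32, scale by the slope, and store them in shared memory.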
+#pragma unroll
+ for (int j = 0; j < ncols; ++j) {
+ maskf_shared[j*D + tid] = slope*__half2float(maskh[j*ne11 + k_VKQ_0 + tid]);
+ }
+
+ __syncthreads();
+
+ // When using multiple parallel sequences in llama.cpp, some KV slices can be fully masked out.
+ // In such cases, skip the KV slice.
+        // On AMD this check is skipped: __all_sync with a 32-bit mask assumes a warp size of 32, but AMD warps have 64 threads.
+#ifndef GGML_USE_HIP
+ bool skip = true;
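+            // Each lane ANDs isinf() over the mask values it reads; __all_sync then tells whether the whole warp saw only -inf (fully masked).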
+#pragma unroll
+ for (int j = 0; j < ncols; ++j) {
+#pragma unroll
+ for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
+ const int i = i0 + threadIdx.x;
+
+ skip = skip && isinf(maskf_shared[j*D + i]);
+ }
+ }
+            if (__all_sync(0xFFFFFFFF, skip)) {
+                // Sync before continuing: the next iteration overwrites maskf_shared, which other warps may still be reading for their skip check.
+                __syncthreads();
+                continue;
+            }
+#endif // GGML_USE_HIP
+ }
+
float kqmax_new_arr[ncols];
#pragma unroll
for (int j = 0; j < ncols; ++j) {
sum = logit_softcap*tanhf(sum);
}
- sum += mask ? slope*__half2float(maskh[j*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
+ sum += maskf_shared[j*D + i_KQ];
kqmax_new_arr[j] = fmaxf(kqmax_new_arr[j], sum);
float logit_softcap;
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
- if (Q->ne[1] == 1) {
+ const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+
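+    // The multi-column vec kernel is compiled out on NVIDIA (see the ncols > 1 check in the kernel), so always use a single column there.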
+ if (Q->ne[1] == 1 || GGML_CUDA_CC_IS_NVIDIA(cc)) {
constexpr int cols_per_block = 1;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;