]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
CUDA: deduplicate FlashAttention code (llama/7352)
authorJohannes Gäßler <redacted>
Sat, 18 May 2024 10:36:25 +0000 (12:36 +0200)
committerGeorgi Gerganov <redacted>
Tue, 28 May 2024 11:41:08 +0000 (14:41 +0300)
src/ggml-cuda/common.cuh
src/ggml-cuda/fattn-common.cuh
src/ggml-cuda/fattn-tile-f16.cu
src/ggml-cuda/fattn-tile-f32.cu
src/ggml-cuda/fattn-vec-f16.cu
src/ggml-cuda/fattn-vec-f32.cu
src/ggml-cuda/fattn.cu
src/ggml-cuda/softmax.cu

index 784792ba0dfcda65382bd1df4d466b7cd5e398fe..8f6fd71cfea35eb22ea948281fee0e1acbbd9c4f 100644 (file)
@@ -477,6 +477,17 @@ static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -
 
 typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);
 
+static __device__ __forceinline__ float get_alibi_slope(
+    const float max_bias, const uint32_t h, const uint32_t n_head_log2, const float m0, const float m1
+) {
+    if (max_bias <= 0.0f) {
+        return 1.0f;
+    }
+    const float base = h < n_head_log2 ? m0 : m1;
+    const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+    return powf(base, exph);
+}
 
 //////////////////////
 
index 33f640691ad7a0b4037e6ae37b1a7135bcfb2bca..1dd519bdee7f1c18a027b55a8aff98c6736f5d42 100644 (file)
@@ -1,7 +1,44 @@
+#include "common.cuh"
+
+#include <cstdint>
+
 #define FATTN_KQ_STRIDE       256
 #define HALF_MAX_HALF         __float2half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
 #define SOFTMAX_FTZ_THRESHOLD -20.0f                   // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.
 
+typedef void (* fattn_kernel_t)(
+        const char * __restrict__ Q,
+        const char * __restrict__ K,
+        const char * __restrict__ V,
+        const char * __restrict__ mask,
+        float      * __restrict__ dst,
+        float2     * __restrict__ dst_meta,
+        const float scale,
+        const float max_bias,
+        const float m0,
+        const float m1,
+        const uint32_t n_head_log2,
+        const int ne00,
+        const int ne01,
+        const int ne02,
+        const int ne03,
+        const int ne10,
+        const int ne11,
+        const int ne12,
+        const int ne13,
+        const int ne31,
+        const int nb31,
+        const int nb01,
+        const int nb02,
+        const int nb03,
+        const int nb11,
+        const int nb12,
+        const int nb13,
+        const int ne0,
+        const int ne1,
+        const int ne2,
+        const int ne3);
+
 template<int D, int parallel_blocks> // D == head size
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
@@ -45,3 +82,81 @@ static __global__ void flash_attn_combine_results(
 
     dst[blockIdx.y*D + tid] = VKQ_numerator / VKQ_denominator;
 }
+
+template <int D, int parallel_blocks>
+void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, int nwarps, int cols_per_block) {
+    const ggml_tensor * Q = dst->src[0];
+    const ggml_tensor * K = dst->src[1];
+    const ggml_tensor * V = dst->src[2];
+
+    const ggml_tensor * mask = dst->src[3];
+
+    ggml_tensor * KQV = dst;
+
+    GGML_ASSERT(Q->type == GGML_TYPE_F32);
+    GGML_ASSERT(K->type == GGML_TYPE_F16);
+    GGML_ASSERT(V->type == GGML_TYPE_F16);
+    GGML_ASSERT(KQV->type == GGML_TYPE_F32);
+
+    GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
+    GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
+                                "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
+
+    GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
+
+    ggml_cuda_pool & pool = ctx.pool();
+    cudaStream_t main_stream = ctx.stream();
+
+    ggml_cuda_pool_alloc<float>  dst_tmp(pool);
+    ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
+
+    if (parallel_blocks > 1) {
+        dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
+        dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
+    }
+
+    const dim3 block_dim(WARP_SIZE, nwarps, 1);
+    const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
+    const int  shmem = 0;
+
+    float scale    = 1.0f;
+    float max_bias = 0.0f;
+
+    memcpy(&scale,    (float *) KQV->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
+
+    const uint32_t n_head      = Q->ne[2];
+    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+    fattn_kernel<<<blocks_num, block_dim, shmem, main_stream>>>(
+        (const char *) Q->data,
+        (const char *) K->data,
+        (const char *) V->data,
+        mask ? ((const char *) mask->data) : nullptr,
+        (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
+        scale, max_bias, m0, m1, n_head_log2,
+        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
+        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
+        mask ? mask->ne[1] : 0, mask ?  mask->nb[1] : 0,
+        Q->nb[1], Q->nb[2], Q->nb[3],
+        K->nb[1], K->nb[2], K->nb[3],
+        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
+    );
+    CUDA_CHECK(cudaGetLastError());
+
+    if ((parallel_blocks) == 1) {
+        return;
+    }
+
+    const dim3 block_dim_combine(D, 1, 1);
+    const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
+    const int  shmem_combine = 0;
+
+    flash_attn_combine_results<D, parallel_blocks>
+        <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
+        (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
+    CUDA_CHECK(cudaGetLastError());
+}
index d2a1077ed44cb600059c524abdc62f0d81565a0a..4a07ac6adad717870711400b20871a6066f76f73 100644 (file)
@@ -54,17 +54,8 @@ static __global__ void flash_attn_tile_ext_f16(
 
     const int stride_KV2 = nb11 / sizeof(half2);
 
-    half slopeh = __float2half(1.0f);
-
-    // ALiBi
-    if (max_bias > 0.0f) {
-        const uint32_t h = blockIdx.y;
-
-        const float base = h < n_head_log2 ? m0 : m1;
-        const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
-
-        slopeh = __float2half(powf(base, exph));
-    }
+    const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
+    const half  slopeh = __float2half(slopef);
 
     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
 
@@ -272,124 +263,50 @@ static __global__ void flash_attn_tile_ext_f16(
 #endif // FP16_AVAILABLE
 }
 
-template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_tile_f16(
-        const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
-        ggml_cuda_pool & pool, cudaStream_t main_stream
-) {
-    ggml_cuda_pool_alloc<float>  dst_tmp(pool);
-    ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
-
-    if (parallel_blocks > 1) {
-        dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
-        dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
-    }
-
-    constexpr int  nwarps = 8;
-    const     dim3 block_dim(WARP_SIZE, nwarps, 1);
-    const     dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
-    const     int  shmem = 0;
-
-    float scale    = 1.0f;
-    float max_bias = 0.0f;
-
-    memcpy(&scale,    (float *) KQV->op_params + 0, sizeof(float));
-    memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
-
-    const uint32_t n_head      = Q->ne[2];
-    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
-
-    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
-    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
-
-    flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>
-        <<<blocks_num, block_dim, shmem, main_stream>>> (
-                (const char *) Q->data,
-                (const char *) K->data,
-                (const char *) V->data,
-                mask ? ((const char *) mask->data) : nullptr,
-                parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
-                scale, max_bias, m0, m1, n_head_log2,
-                Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-                K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-                mask ? mask->ne[1] : 0, mask ?  mask->nb[1] : 0,
-                Q->nb[1], Q->nb[2], Q->nb[3],
-                K->nb[1], K->nb[2], K->nb[3],
-                KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-                );
-    CUDA_CHECK(cudaGetLastError());
-
-    if (parallel_blocks == 1) {
-        return;
+template <int cols_per_block, int parallel_blocks>
+void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * Q = dst->src[0];
+    switch (Q->ne[0]) {
+        case  64: {
+            constexpr int      D = 64;
+            constexpr int nwarps = 8;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+        } break;
+        case 128: {
+            constexpr int      D = 128;
+            constexpr int nwarps = 8;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+        } break;
+        default: {
+            GGML_ASSERT(false && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
+        } break;
     }
-
-    const dim3 block_dim_combine(D, 1, 1);
-    const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
-    const int  shmem_combine = 0;
-
-    flash_attn_combine_results<D, parallel_blocks>
-        <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
-        (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
-    CUDA_CHECK(cudaGetLastError());
 }
 
 void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * Q = dst->src[0];
-    const ggml_tensor * K = dst->src[1];
-    const ggml_tensor * V = dst->src[2];
-
-    const ggml_tensor * mask = dst->src[3];
-
-    ggml_tensor * KQV = dst;
+    const ggml_tensor * KQV = dst;
+    const ggml_tensor * Q   = dst->src[0];
 
     const int32_t precision = KQV->op_params[2];
     GGML_ASSERT(precision == GGML_PREC_DEFAULT);
-    GGML_ASSERT(Q->ne[0] == 64 || Q->ne[0] == 128 && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
 
     if (Q->ne[1] <= 16) {
         constexpr int cols_per_block = 16;
         constexpr int parallel_blocks = 4;
-        switch (Q->ne[0]) {
-            case 64:
-                launch_fattn_tile_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            case 128:
-                launch_fattn_tile_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            default:
-                GGML_ASSERT(false);
-                break;
-        }
+        launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
         return;
     }
 
     if (Q->ne[1] <= 32) {
         constexpr int cols_per_block = 32;
         constexpr int parallel_blocks = 4;
-        switch (Q->ne[0]) {
-            case 64:
-                launch_fattn_tile_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            case 128:
-                launch_fattn_tile_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            default:
-                GGML_ASSERT(false);
-                break;
-        }
+        launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
         return;
     }
 
     constexpr int cols_per_block = 32;
     constexpr int parallel_blocks = 1;
-    switch (Q->ne[0]) {
-        case 64:
-            launch_fattn_tile_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-            break;
-        case 128:
-            launch_fattn_tile_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-            break;
-        default:
-            GGML_ASSERT(false);
-            break;
-    }
+    launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
 }
index 176895edd34f215488264f931adb5f50fe70c2b7..130e7cbdbe10dc44159bca2709ede4f2975ba3b6 100644 (file)
@@ -53,17 +53,7 @@ static __global__ void flash_attn_tile_ext_f32(
 
     const int stride_KV2 = nb11 / sizeof(half2);
 
-    float slope = 1.0f;
-
-    // ALiBi
-    if (max_bias > 0.0f) {
-        const uint32_t h = blockIdx.y;
-
-        const float base = h < n_head_log2 ? m0 : m1;
-        const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
-
-        slope = powf(base, exph);
-    }
+    const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
 
     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
 
@@ -270,124 +260,50 @@ static __global__ void flash_attn_tile_ext_f32(
     }
 }
 
-template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_tile_f32(
-        const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
-        ggml_cuda_pool & pool, cudaStream_t main_stream
-) {
-    ggml_cuda_pool_alloc<float>  dst_tmp(pool);
-    ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
-
-    if (parallel_blocks > 1) {
-        dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
-        dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
-    }
-
-    constexpr int  nwarps = 8;
-    const     dim3 block_dim(WARP_SIZE, nwarps, 1);
-    const     dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
-    const     int  shmem = 0;
-
-    float scale    = 1.0f;
-    float max_bias = 0.0f;
-
-    memcpy(&scale,    (float *) KQV->op_params + 0, sizeof(float));
-    memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
-
-    const uint32_t n_head      = Q->ne[2];
-    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
-
-    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
-    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
-
-    flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks>
-        <<<blocks_num, block_dim, shmem, main_stream>>> (
-                (const char *) Q->data,
-                (const char *) K->data,
-                (const char *) V->data,
-                mask ? ((const char *) mask->data) : nullptr,
-                parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
-                scale, max_bias, m0, m1, n_head_log2,
-                Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-                K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-                mask ? mask->ne[1] : 0, mask ?  mask->nb[1] : 0,
-                Q->nb[1], Q->nb[2], Q->nb[3],
-                K->nb[1], K->nb[2], K->nb[3],
-                KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-                );
-    CUDA_CHECK(cudaGetLastError());
-
-    if (parallel_blocks == 1) {
-        return;
+template <int cols_per_block, int parallel_blocks>
+void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * Q = dst->src[0];
+    switch (Q->ne[0]) {
+        case  64: {
+            constexpr int      D = 64;
+            constexpr int nwarps = 8;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks>;
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+        } break;
+        case 128: {
+            constexpr int      D = 128;
+            constexpr int nwarps = 8;
+            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks>;
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+        } break;
+        default: {
+            GGML_ASSERT(false && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
+        } break;
     }
-
-    const dim3 block_dim_combine(D, 1, 1);
-    const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
-    const int  shmem_combine = 0;
-
-    flash_attn_combine_results<D, parallel_blocks>
-        <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
-        (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
-    CUDA_CHECK(cudaGetLastError());
 }
 
 void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * Q = dst->src[0];
-    const ggml_tensor * K = dst->src[1];
-    const ggml_tensor * V = dst->src[2];
-
-    const ggml_tensor * mask = dst->src[3];
-
-    ggml_tensor * KQV = dst;
+    const ggml_tensor * KQV = dst;
+    const ggml_tensor * Q   = dst->src[0];
 
     const int32_t precision = KQV->op_params[2];
     GGML_ASSERT(precision == GGML_PREC_DEFAULT);
-    GGML_ASSERT(Q->ne[0] == 64 || Q->ne[0] == 128 && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
 
     if (Q->ne[1] <= 16) {
         constexpr int cols_per_block = 16;
         constexpr int parallel_blocks = 4;
-        switch (Q->ne[0]) {
-            case 64:
-                launch_fattn_tile_f32< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            case 128:
-                launch_fattn_tile_f32<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            default:
-                GGML_ASSERT(false);
-                break;
-        }
+        launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
         return;
     }
 
     if (Q->ne[1] <= 32) {
         constexpr int cols_per_block = 32;
         constexpr int parallel_blocks = 4;
-        switch (Q->ne[0]) {
-            case 64:
-                launch_fattn_tile_f32< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            case 128:
-                launch_fattn_tile_f32<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            default:
-                GGML_ASSERT(false);
-                break;
-        }
+        launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
         return;
     }
 
     constexpr int cols_per_block = 32;
     constexpr int parallel_blocks = 1;
-    switch (Q->ne[0]) {
-        case 64:
-            launch_fattn_tile_f32< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-            break;
-        case 128:
-            launch_fattn_tile_f32<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-            break;
-        default:
-            GGML_ASSERT(false);
-            break;
-    }
+    launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
 }
index a18be5ddcc3d34f37d9b321ad4dbf5a0660869f3..54e1ac5d1605010a03190dc02abb394b753ecfc0 100644 (file)
@@ -53,17 +53,8 @@ static __global__ void flash_attn_vec_ext_f16(
     const int stride_KV  = nb11 / sizeof(half);
     const int stride_KV2 = nb11 / sizeof(half2);
 
-    half  slopeh = __float2half(1.0f);
-
-    // ALiBi
-    if (max_bias > 0.0f) {
-        const uint32_t h = blockIdx.y;
-
-        const float base = h < n_head_log2 ? m0 : m1;
-        const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
-
-        slopeh = __float2half(powf(base, exph));
-    }
+    const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
+    const half  slopeh = __float2half(slopef);
 
     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
     constexpr int nwarps = D / WARP_SIZE;
@@ -232,196 +223,104 @@ static __global__ void flash_attn_vec_ext_f16(
         dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
     }
 
-    if (parallel_blocks != 1 && threadIdx.x < ncols) {
-        dst_meta[(ic0 + threadIdx.x)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[threadIdx.x], kqsum[threadIdx.x]);
+    if (parallel_blocks != 1 && tid < ncols) {
+        dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
     }
 #else
    NO_DEVICE_CODE;
 #endif // FP16_AVAILABLE
 }
 
-template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_vec_f16(
-        const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
-        ggml_cuda_pool & pool, cudaStream_t main_stream
-) {
-    ggml_cuda_pool_alloc<float>  dst_tmp(pool);
-    ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
-
-    if (parallel_blocks > 1) {
-        dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
-        dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
-    }
-
-    constexpr int  nwarps = (D + WARP_SIZE - 1) / WARP_SIZE;
-    const     dim3 block_dim(WARP_SIZE, nwarps, 1);
-    const     dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
-    const     int  shmem = 0;
-
-    float scale    = 1.0f;
-    float max_bias = 0.0f;
-
-    memcpy(&scale,    (float *) KQV->op_params + 0, sizeof(float));
-    memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
-
-    const uint32_t n_head      = Q->ne[2];
-    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
-
-    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
-    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
-
-    flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks>
-        <<<blocks_num, block_dim, shmem, main_stream>>> (
-                (const char *) Q->data,
-                (const char *) K->data,
-                (const char *) V->data,
-                mask ? ((const char *) mask->data) : nullptr,
-                parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
-                scale, max_bias, m0, m1, n_head_log2,
-                Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-                K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-                mask ? mask->ne[1] : 0, mask ?  mask->nb[1] : 0,
-                Q->nb[1], Q->nb[2], Q->nb[3],
-                K->nb[1], K->nb[2], K->nb[3],
-                KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-                );
-    CUDA_CHECK(cudaGetLastError());
-
-    if (parallel_blocks == 1) {
-        return;
-    }
-
-    const dim3 block_dim_combine(D, 1, 1);
-    const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
-    const int  shmem_combine = 0;
-
-    flash_attn_combine_results<D, parallel_blocks>
-        <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
-        (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
-    CUDA_CHECK(cudaGetLastError());
-}
-
 void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * Q = dst->src[0];
-    const ggml_tensor * K = dst->src[1];
-    const ggml_tensor * V = dst->src[2];
-
-    const ggml_tensor * mask = dst->src[3];
-
     ggml_tensor * KQV = dst;
+    ggml_tensor * Q   = dst->src[0];
 
     const int32_t precision = KQV->op_params[2];
     GGML_ASSERT(precision == GGML_PREC_DEFAULT);
 
-    constexpr int cols_per_block = 1;
+    constexpr int cols_per_block  = 1;
     constexpr int parallel_blocks = 4;
     switch (Q->ne[0]) {
-        case 64:
-            launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-            break;
-        case 128:
-            launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-            break;
-        case 256:
-            launch_fattn_vec_f16<256, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-            break;
+        case  64: {
+            constexpr int      D = 64;
+            constexpr int nwarps = D/WARP_SIZE;
+            fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks>;
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+        } break;
+        case 128: {
+            constexpr int      D = 128;
+            constexpr int nwarps = D/WARP_SIZE;
+            fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks>;
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+        } break;
+        case 256: {
+            constexpr int      D = 256;
+            constexpr int nwarps = D/WARP_SIZE;
+            fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks>;
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+        } break;
         default:
             GGML_ASSERT(false);
             break;
     }
 }
 
-void ggml_cuda_flash_attn_ext_vec_f16_no_mma(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+template <int cols_per_block, int parallel_blocks>
+void launch_fattn_vec_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * Q = dst->src[0];
-    const ggml_tensor * K = dst->src[1];
-    const ggml_tensor * V = dst->src[2];
-
-    const ggml_tensor * mask = dst->src[3];
+    switch (Q->ne[0]) {
+        case  64: {
+            constexpr int      D = 64;
+            constexpr int nwarps = D/WARP_SIZE;
+            fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks>;
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+        } break;
+        case 128: {
+            constexpr int      D = 128;
+            constexpr int nwarps = D/WARP_SIZE;
+            fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks>;
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+        } break;
+        default: {
+            GGML_ASSERT(false && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
+        } break;
+    }
+}
 
-    ggml_tensor * KQV = dst;
+void ggml_cuda_flash_attn_ext_vec_f16_no_mma(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * KQV = dst;
+    const ggml_tensor * Q   = dst->src[0];
 
     const int32_t precision = KQV->op_params[2];
     GGML_ASSERT(precision == GGML_PREC_DEFAULT);
-    GGML_ASSERT(Q->ne[0] == 64 || Q->ne[0] == 128 && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
 
     if (Q->ne[1] == 1) {
-        constexpr int cols_per_block = 1;
-        constexpr int parallel_blocks = 4;
-        switch (Q->ne[0]) {
-            case 64:
-                launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            case 128:
-                launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            default:
-                GGML_ASSERT(false);
-                break;
-        }
+        ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
         return;
     }
 
     if (Q->ne[1] == 2) {
-        constexpr int cols_per_block = 2;
+        constexpr int cols_per_block  = 2;
         constexpr int parallel_blocks = 4;
-        switch (Q->ne[0]) {
-            case 64:
-                launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            case 128:
-                launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            default:
-                GGML_ASSERT(false);
-                break;
-        }
+        launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
         return;
     }
 
     if (Q->ne[1] <= 4) {
-        constexpr int cols_per_block = 4;
+        constexpr int cols_per_block  = 4;
         constexpr int parallel_blocks = 4;
-        switch (Q->ne[0]) {
-            case 64:
-                launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            case 128:
-                launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            default:
-                GGML_ASSERT(false);
-                break;
-        }
+        launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
         return;
     }
 
     if (Q->ne[1] <= 8) {
-        constexpr int cols_per_block = 8;
+        constexpr int cols_per_block  = 8;
         constexpr int parallel_blocks = 4;
-        switch (Q->ne[0]) {
-            case 64:
-                launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            case 128:
-                launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            default:
-                GGML_ASSERT(false);
-                break;
-        }
+        launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
         return;
     }
 
-    constexpr int cols_per_block = 8;
+    constexpr int cols_per_block  = 8;
     constexpr int parallel_blocks = 1;
-    switch (Q->ne[0]) {
-        case 64:
-            launch_fattn_vec_f16< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-            break;
-        case 128:
-            launch_fattn_vec_f16<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-            break;
-        default:
-            GGML_ASSERT(false);
-            break;
-    }
+    launch_fattn_vec_f16_64_128<cols_per_block, parallel_blocks>(ctx, dst);
 }
index 91fcdc8c3542532c6618e4e283e215dbd20a3bb4..5bcabd0928451e8ebbaee831e5b7a7cfd131afa5 100644 (file)
@@ -52,17 +52,7 @@ static __global__ void flash_attn_vec_ext_f32(
     const int stride_KV  = nb11 / sizeof(half);
     const int stride_KV2 = nb11 / sizeof(half2);
 
-    float slope = 1.0f;
-
-    // ALiBi
-    if (max_bias > 0.0f) {
-        const uint32_t h = blockIdx.y;
-
-        const float base = h < n_head_log2 ? m0 : m1;
-        const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
-
-        slope = powf(base, exph);
-    }
+    const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
 
     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
     constexpr int nwarps = D / WARP_SIZE;
@@ -221,161 +211,65 @@ static __global__ void flash_attn_vec_ext_f32(
         dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
     }
 
-    if (parallel_blocks != 1 && threadIdx.x < ncols) {
-        dst_meta[(ic0 + threadIdx.x)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[threadIdx.x], kqsum[threadIdx.x]);
+    if (parallel_blocks != 1 && tid < ncols) {
+        dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
     }
 }
 
-template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_vec_f32(
-        const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
-        ggml_cuda_pool & pool, cudaStream_t main_stream
-) {
-    ggml_cuda_pool_alloc<float>  dst_tmp(pool);
-    ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
-
-    if (parallel_blocks > 1) {
-        dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
-        dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
-    }
-
-    constexpr int  nwarps = (D + WARP_SIZE - 1) / WARP_SIZE;
-    const     dim3 block_dim(WARP_SIZE, nwarps, 1);
-    const     dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
-    const     int  shmem = 0;
-
-    float scale    = 1.0f;
-    float max_bias = 0.0f;
-
-    memcpy(&scale,    (float *) KQV->op_params + 0, sizeof(float));
-    memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
-
-    const uint32_t n_head      = Q->ne[2];
-    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
-
-    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
-    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
-
-    flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks>
-        <<<blocks_num, block_dim, shmem, main_stream>>> (
-                (const char *) Q->data,
-                (const char *) K->data,
-                (const char *) V->data,
-                mask ? ((const char *) mask->data) : nullptr,
-                parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
-                scale, max_bias, m0, m1, n_head_log2,
-                Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-                K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-                mask ? mask->ne[1] : 0, mask ?  mask->nb[1] : 0,
-                Q->nb[1], Q->nb[2], Q->nb[3],
-                K->nb[1], K->nb[2], K->nb[3],
-                KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-                );
-    CUDA_CHECK(cudaGetLastError());
-
-    if (parallel_blocks == 1) {
-        return;
+template <int cols_per_block, int parallel_blocks>
+void launch_fattn_vec_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * Q = dst->src[0];
+    switch (Q->ne[0]) {
+        case  64: {
+            constexpr int      D = 64;
+            constexpr int nwarps = D/WARP_SIZE;
+            fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks>;
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+        } break;
+        case 128: {
+            constexpr int      D = 128;
+            constexpr int nwarps = D/WARP_SIZE;
+            fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks>;
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+        } break;
+        default: {
+            GGML_ASSERT(false && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
+        } break;
     }
-
-    const dim3 block_dim_combine(D, 1, 1);
-    const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
-    const int  shmem_combine = 0;
-
-    flash_attn_combine_results<D, parallel_blocks>
-        <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
-        (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
-    CUDA_CHECK(cudaGetLastError());
 }
 
 void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * Q = dst->src[0];
-    const ggml_tensor * K = dst->src[1];
-    const ggml_tensor * V = dst->src[2];
-
-    const ggml_tensor * mask = dst->src[3];
-
-    ggml_tensor * KQV = dst;
-
-    GGML_ASSERT(Q->ne[0] == 64 || Q->ne[0] == 128 && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
 
     if (Q->ne[1] == 1) {
-        constexpr int cols_per_block = 1;
+        constexpr int cols_per_block  = 1;
         constexpr int parallel_blocks = 4;
-        switch (Q->ne[0]) {
-            case 64:
-                launch_fattn_vec_f32< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            case 128:
-                launch_fattn_vec_f32<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            default:
-                GGML_ASSERT(false);
-                break;
-        }
+        launch_fattn_vec_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
         return;
     }
 
     if (Q->ne[1] == 2) {
-        constexpr int cols_per_block = 2;
+        constexpr int cols_per_block  = 2;
         constexpr int parallel_blocks = 4;
-        switch (Q->ne[0]) {
-            case 64:
-                launch_fattn_vec_f32< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            case 128:
-                launch_fattn_vec_f32<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            default:
-                GGML_ASSERT(false);
-                break;
-        }
+        launch_fattn_vec_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
         return;
     }
 
     if (Q->ne[1] <= 4) {
-        constexpr int cols_per_block = 4;
+        constexpr int cols_per_block  = 4;
         constexpr int parallel_blocks = 4;
-        switch (Q->ne[0]) {
-            case 64:
-                launch_fattn_vec_f32< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            case 128:
-                launch_fattn_vec_f32<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            default:
-                GGML_ASSERT(false);
-                break;
-        }
+        launch_fattn_vec_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
         return;
     }
 
     if (Q->ne[1] <= 8) {
-        constexpr int cols_per_block = 8;
+        constexpr int cols_per_block  = 8;
         constexpr int parallel_blocks = 4;
-        switch (Q->ne[0]) {
-            case 64:
-                launch_fattn_vec_f32< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            case 128:
-                launch_fattn_vec_f32<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-                break;
-            default:
-                GGML_ASSERT(false);
-                break;
-        }
+        launch_fattn_vec_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
         return;
     }
 
-    constexpr int cols_per_block = 8;
+    constexpr int cols_per_block  = 8;
     constexpr int parallel_blocks = 1;
-    switch (Q->ne[0]) {
-        case 64:
-            launch_fattn_vec_f32< 64, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-            break;
-        case 128:
-            launch_fattn_vec_f32<128, cols_per_block, parallel_blocks>(Q, K, V, KQV, mask, ctx.pool(), ctx.stream());
-            break;
-        default:
-            GGML_ASSERT(false);
-            break;
-    }
+    launch_fattn_vec_f32_64_128<cols_per_block, parallel_blocks>(ctx, dst);
 }
index a1918e2587ddb7f903f29889561221efaec8e3e9..af7c95232ddf39e5e41ce676abaa5ca37e21332b 100644 (file)
@@ -85,19 +85,9 @@ static __global__ void flash_attn_ext_f16(
     const int stride_Q  = nb01 / sizeof(float);
     const int stride_KV = nb11 / sizeof(half);
 
-    half  slopeh = __float2half(1.0f);
-    half2 slope2 = make_half2(1.0f, 1.0f);
-
-    // ALiBi
-    if (max_bias > 0.0f) {
-        const uint32_t h = blockIdx.y;
-
-        const float base = h < n_head_log2 ? m0 : m1;
-        const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
-
-        slopeh = __float2half(powf(base, exph));
-        slope2 = make_half2(slopeh, slopeh);
-    }
+    const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
+    const half  slopeh = __float2half(slopef);
+    const half2 slope2 = make_half2(slopef, slopef);
 
     frag_b Q_b[D/16][ncols/frag_n];
 
@@ -439,108 +429,37 @@ static_assert(get_VKQ_stride( 80, 1, 16) ==  16, "Test failed.");
 static_assert(get_VKQ_stride( 80, 2, 16) ==  16, "Test failed.");
 static_assert(get_VKQ_stride( 80, 4, 16) ==  16, "Test failed.");
 
-template <int D, int cols_per_block, int nwarps, int parallel_blocks, typename KQ_acc_t> void launch_fattn_f16_impl(
-        const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
-        ggml_cuda_pool & pool, cudaStream_t main_stream
-) {
-    ggml_cuda_pool_alloc<float>  dst_tmp(pool);
-    ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
-
-    if (parallel_blocks > 1) {
-        dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
-        dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
-    }
-
-    constexpr int  frag_m = (cols_per_block) == 8 && (D) % 32 == 0 ? 32 : 16;
-    const     dim3 block_dim(WARP_SIZE, nwarps, 1);
-    const     dim3 blocks_num(parallel_blocks*(Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]);
-    const     int  shmem = 0;
-
-    float scale    = 1.0f;
-    float max_bias = 0.0f;
-
-    memcpy(&scale,    (float *) KQV->op_params + 0, sizeof(float));
-    memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
-
-    const uint32_t n_head      = Q->ne[2];
-    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
-
-    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
-    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
-
-    flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>
-        <<<blocks_num, block_dim, shmem, main_stream>>> (
-                (const char *) Q->data,
-                (const char *) K->data,
-                (const char *) V->data,
-                mask ? ((const char *) mask->data) : nullptr,
-                (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
-                scale, max_bias, m0, m1, n_head_log2,
-                Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-                K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-                mask ? mask->ne[1] : 0, mask ?  mask->nb[1] : 0,
-                Q->nb[1], Q->nb[2], Q->nb[3],
-                K->nb[1], K->nb[2], K->nb[3],
-                KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-                );
-    CUDA_CHECK(cudaGetLastError());
-
-    if ((parallel_blocks) == 1) {
-        return;
-    }
-
-    const dim3 block_dim_combine(D, 1, 1);
-    const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
-    const int  shmem_combine = 0;
-
-    flash_attn_combine_results<D, parallel_blocks>
-        <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
-        (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
-    CUDA_CHECK(cudaGetLastError());
-}
+template <int D, int cols_per_block, int nwarps, typename KQ_acc_t>
+void launch_fattn_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * Q = dst->src[0];
 
-template <int D, int cols_per_block, int nwarps, typename KQ_acc_t> void launch_fattn_f16(
-        const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV, const ggml_tensor * mask,
-        const int nsm, ggml_cuda_pool & pool, cudaStream_t main_stream
-) {
+    constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16;
     const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
+    const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
 
     if (4*blocks_num_pb1 < 2*nsm) {
-        launch_fattn_f16_impl<D, cols_per_block, nwarps, 4, KQ_acc_t>(Q, K, V, KQV, mask, pool, main_stream);
+        constexpr int parallel_blocks = 4;
+        fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
+        launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
         return;
     }
     if (2*blocks_num_pb1 < 2*nsm) {
-        launch_fattn_f16_impl<D, cols_per_block, nwarps, 2, KQ_acc_t>(Q, K, V, KQV, mask, pool, main_stream);
+        constexpr int parallel_blocks = 2;
+        fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
+        launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
         return;
     }
-    launch_fattn_f16_impl<D, cols_per_block, nwarps, 1, KQ_acc_t>(Q, K, V, KQV, mask, pool, main_stream);
+    constexpr int parallel_blocks = 1;
+    fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
+    launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
 }
 
 void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * Q = dst->src[0];
-    const ggml_tensor * K = dst->src[1];
-    const ggml_tensor * V = dst->src[2];
-
-    const ggml_tensor * mask = dst->src[3];
-
-    ggml_tensor * KQV = dst;
-
-    GGML_ASSERT(Q->type == GGML_TYPE_F32);
-    GGML_ASSERT(K->type == GGML_TYPE_F16);
-    GGML_ASSERT(V->type == GGML_TYPE_F16);
-    GGML_ASSERT(KQV->type == GGML_TYPE_F32);
-
-    GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
-    GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
-                                "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
-
-    GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
+    const ggml_tensor * KQV = dst;
+    const ggml_tensor * Q   = dst->src[0];
 
     ggml_cuda_set_device(ctx.device);
-
-    const int cc  = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
-    const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
-
+    const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     const int32_t precision = KQV->op_params[2];
 
     // On AMD the tile kernels perform poorly, use the vec kernel instead:
@@ -582,22 +501,22 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
             constexpr int nwarps         =  4;
             switch (Q->ne[0]) {
                 case 64:
-                    launch_fattn_f16< 64, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                    launch_fattn_f16< 64, cols_per_block, nwarps, float>(ctx, dst);
                     break;
                 case 80:
-                    launch_fattn_f16< 80, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                    launch_fattn_f16< 80, cols_per_block, nwarps, float>(ctx, dst);
                     break;
                 case 96:
-                    launch_fattn_f16< 96, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                    launch_fattn_f16< 96, cols_per_block, nwarps, float>(ctx, dst);
                     break;
                 case 112:
-                    launch_fattn_f16<112, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                    launch_fattn_f16<112, cols_per_block, nwarps, float>(ctx, dst);
                     break;
                 case 128:
-                    launch_fattn_f16<128, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                    launch_fattn_f16<128, cols_per_block, nwarps, float>(ctx, dst);
                     break;
                 case 256:
-                    launch_fattn_f16<256, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                    launch_fattn_f16<256, cols_per_block, nwarps, float>(ctx, dst);
                     break;
                 default:
                     GGML_ASSERT(false);
@@ -608,22 +527,22 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
             constexpr int nwarps         =  4;
             switch (Q->ne[0]) {
                 case 64:
-                    launch_fattn_f16< 64, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                    launch_fattn_f16< 64, cols_per_block, nwarps, float>(ctx, dst);
                     break;
                 case 80:
-                    launch_fattn_f16< 80, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                    launch_fattn_f16< 80, cols_per_block, nwarps, float>(ctx, dst);
                     break;
                 case 96:
-                    launch_fattn_f16< 96, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                    launch_fattn_f16< 96, cols_per_block, nwarps, float>(ctx, dst);
                     break;
                 case 112:
-                    launch_fattn_f16<112, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                    launch_fattn_f16<112, cols_per_block, nwarps, float>(ctx, dst);
                     break;
                 case 128:
-                    launch_fattn_f16<128, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                    launch_fattn_f16<128, cols_per_block, nwarps, float>(ctx, dst);
                     break;
                 // case 256:
-                //     launch_fattn_f16<256, cols_per_block, nwarps, float>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                //     launch_fattn_f16<256, cols_per_block, nwarps, float>(ctx, dst);
                 //     break;
                 default:
                     GGML_ASSERT(false);
@@ -643,16 +562,16 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
         constexpr int nwarps         = 4;
         switch (Q->ne[0]) {
             case 64:
-                launch_fattn_f16< 64, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16< 64, cols_per_block, nwarps, half>(ctx, dst);
                 break;
             case 96:
-                launch_fattn_f16< 96, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16< 96, cols_per_block, nwarps, half>(ctx, dst);
                 break;
             case 128:
-                launch_fattn_f16<128, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16<128, cols_per_block, nwarps, half>(ctx, dst);
                 break;
             case 256:
-                launch_fattn_f16<256, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16<256, cols_per_block, nwarps, half>(ctx, dst);
                 break;
             default:
                 GGML_ASSERT(false);
@@ -666,22 +585,22 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
         constexpr int nwarps         =  4;
         switch (Q->ne[0]) {
             case 64:
-                launch_fattn_f16< 64, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16< 64, cols_per_block, nwarps, half>(ctx, dst);
                 break;
             case 80:
-                launch_fattn_f16< 80, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16< 80, cols_per_block, nwarps, half>(ctx, dst);
                 break;
             case 96:
-                launch_fattn_f16< 96, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16< 96, cols_per_block, nwarps, half>(ctx, dst);
                 break;
             case 112:
-                launch_fattn_f16<112, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16<112, cols_per_block, nwarps, half>(ctx, dst);
                 break;
             case 128:
-                launch_fattn_f16<128, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16<128, cols_per_block, nwarps, half>(ctx, dst);
                 break;
             case 256:
-                launch_fattn_f16<256, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+                launch_fattn_f16<256, cols_per_block, nwarps, half>(ctx, dst);
                 break;
             default:
                 GGML_ASSERT(false);
@@ -694,22 +613,22 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     constexpr int nwarps         =  4;
     switch (Q->ne[0]) {
         case 64:
-            launch_fattn_f16< 64, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+            launch_fattn_f16< 64, cols_per_block, nwarps, half>(ctx, dst);
             break;
         case 80:
-            launch_fattn_f16< 80, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+            launch_fattn_f16< 80, cols_per_block, nwarps, half>(ctx, dst);
             break;
         case 96:
-            launch_fattn_f16< 96, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+            launch_fattn_f16< 96, cols_per_block, nwarps, half>(ctx, dst);
             break;
         case 112:
-            launch_fattn_f16<112, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+            launch_fattn_f16<112, cols_per_block, nwarps, half>(ctx, dst);
             break;
         case 128:
-            launch_fattn_f16<128, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+            launch_fattn_f16<128, cols_per_block, nwarps, half>(ctx, dst);
             break;
         case 256:
-            launch_fattn_f16<256, cols_per_block, nwarps, half>(Q, K, V, KQV, mask, nsm, ctx.pool(), ctx.stream());
+            launch_fattn_f16<256, cols_per_block, nwarps, half>(ctx, dst);
             break;
         default:
             GGML_ASSERT(false);
index ca85285a3f4583e80d7b487a4d32bf191c332b43..ce64f2f2ce28b1095f7497a96471107c5d4a87c4 100644 (file)
@@ -1,3 +1,4 @@
+#include "common.cuh"
 #include "softmax.cuh"
 
 template <typename T>
@@ -23,17 +24,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, float * dst
     const int warp_id = threadIdx.x / WARP_SIZE;
     const int lane_id = threadIdx.x % WARP_SIZE;
 
-    float slope = 1.0f;
-
-    // ALiBi
-    if (max_bias > 0.0f) {
-        const int h = rowx/nrows_y; // head index
-
-        const float base = h < n_head_log2 ? m0 : m1;
-        const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
-
-        slope = powf(base, exph);
-    }
+    const float slope = get_alibi_slope(max_bias, rowx/nrows_y, n_head_log2, m0, m1);
 
     extern __shared__ float data_soft_max_f32[];
     float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication