]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
CUDA: fix Pascal FA, deq. KV to FP16 for batch > 8 (llama/7681)
authorJohannes Gäßler <redacted>
Sat, 1 Jun 2024 13:47:04 +0000 (15:47 +0200)
committerGeorgi Gerganov <redacted>
Sat, 15 Jun 2024 19:05:47 +0000 (22:05 +0300)
src/ggml-cuda/fattn-common.cuh
src/ggml-cuda/fattn-tile-f16.cu
src/ggml-cuda/fattn-tile-f32.cu
src/ggml-cuda/fattn-vec-f16.cuh
src/ggml-cuda/fattn-vec-f32.cuh
src/ggml-cuda/fattn-wmma-f16.cuh
src/ggml-cuda/fattn.cu

index 4bf03a49fb8675706c2726415ae4ac798e009330..c00f8606a5c850302e156bb462fddb5c113fee18 100644 (file)
@@ -1,6 +1,7 @@
 #pragma once
 
 #include "common.cuh"
+#include "convert.cuh"
 #include "vecdotq.cuh"
 
 #include <cstdint>
@@ -53,7 +54,7 @@ typedef float (*vec_dot_KQ_f32_t)(
 template<typename T, int D>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
-#if __CUDA_ARCH__ > MIN_CC_DP4A
+#if __CUDA_ARCH__ >= MIN_CC_DP4A
 
     const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c;
     GGML_UNUSED(Q_v);
@@ -95,13 +96,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
     GGML_UNUSED(Q_q8);
     GGML_UNUSED(Q_ds_v);
     NO_DEVICE_CODE;
-#endif  // __CUDA_ARCH__ > MIN_CC_DP4A
+#endif  // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
 template<typename T, int D>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
-#if __CUDA_ARCH__ > MIN_CC_DP4A
+#if __CUDA_ARCH__ >= MIN_CC_DP4A
 
     const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c;
     GGML_UNUSED(Q_v);
@@ -147,13 +148,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
     GGML_UNUSED(Q_q8);
     GGML_UNUSED(Q_ds_v);
     NO_DEVICE_CODE;
-#endif  // __CUDA_ARCH__ > MIN_CC_DP4A
+#endif  // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
 template<typename T, int D>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
-#if __CUDA_ARCH__ > MIN_CC_DP4A
+#if __CUDA_ARCH__ >= MIN_CC_DP4A
 
     const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c;
     GGML_UNUSED(Q_v);
@@ -202,13 +203,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
     GGML_UNUSED(Q_q8);
     GGML_UNUSED(Q_ds_v);
     NO_DEVICE_CODE;
-#endif  // __CUDA_ARCH__ > MIN_CC_DP4A
+#endif  // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
 template<typename T, int D>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
-#if __CUDA_ARCH__ > MIN_CC_DP4A
+#if __CUDA_ARCH__ >= MIN_CC_DP4A
 
     const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c;
     GGML_UNUSED(Q_v);
@@ -261,13 +262,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
     GGML_UNUSED(Q_q8);
     GGML_UNUSED(Q_ds_v);
     NO_DEVICE_CODE;
-#endif  // __CUDA_ARCH__ > MIN_CC_DP4A
+#endif  // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
 template <typename T, int D>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
-#if __CUDA_ARCH__ > MIN_CC_DP4A
+#if __CUDA_ARCH__ >= MIN_CC_DP4A
 
     const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c;
     GGML_UNUSED(Q_v);
@@ -302,7 +303,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
     GGML_UNUSED(Q_q8);
     GGML_UNUSED(Q_ds_v);
     NO_DEVICE_CODE;
-#endif  // __CUDA_ARCH__ > MIN_CC_DP4A
+#endif  // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
 template <typename T, int D>
@@ -620,7 +621,10 @@ static void on_no_fattn_vec_case(const int D) {
 }
 
 template <int D, int parallel_blocks>
-void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, int nwarps, int cols_per_block) {
+void launch_fattn(
+    ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
+    const int nwarps, const int cols_per_block, const bool need_f16_K, const bool need_f16_V
+) {
     const ggml_tensor * Q = dst->src[0];
     const ggml_tensor * K = dst->src[1];
     const ggml_tensor * V = dst->src[2];
@@ -641,9 +645,49 @@ void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kern
     ggml_cuda_pool & pool = ctx.pool();
     cudaStream_t main_stream = ctx.stream();
 
+    ggml_cuda_pool_alloc<half>   K_f16(pool);
+    ggml_cuda_pool_alloc<half>   V_f16(pool);
     ggml_cuda_pool_alloc<float>  dst_tmp(pool);
     ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
 
+    char * K_data = (char *) K->data;
+    size_t nb11 = K->nb[1];
+    size_t nb12 = K->nb[2];
+    size_t nb13 = K->nb[3];
+
+    char * V_data = (char *) V->data;
+    size_t nb21 = V->nb[1];
+    size_t nb22 = V->nb[2];
+    size_t nb23 = V->nb[3];
+
+    if (need_f16_K && K->type != GGML_TYPE_F16) {
+        K_f16.alloc(ggml_nelements(K));
+        to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type);
+        to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream);
+        K_data = (char *) K_f16.ptr;
+
+        const size_t bs = ggml_blck_size(K->type);
+        const size_t ts = ggml_type_size(K->type);
+
+        nb11 = nb11*bs*sizeof(half)/ts;
+        nb12 = nb12*bs*sizeof(half)/ts;
+        nb13 = nb13*bs*sizeof(half)/ts;
+    }
+
+    if (need_f16_V && V->type != GGML_TYPE_F16) {
+        V_f16.alloc(ggml_nelements(V));
+        to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
+        to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
+        V_data = (char *) V_f16.ptr;
+
+        const size_t bs = ggml_blck_size(V->type);
+        const size_t ts = ggml_type_size(V->type);
+
+        nb21 = nb21*bs*sizeof(half)/ts;
+        nb22 = nb22*bs*sizeof(half)/ts;
+        nb23 = nb23*bs*sizeof(half)/ts;
+    }
+
     if (parallel_blocks > 1) {
         dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
         dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
@@ -667,8 +711,8 @@ void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kern
 
     fattn_kernel<<<blocks_num, block_dim, shmem, main_stream>>>(
         (const char *) Q->data,
-        (const char *) K->data,
-        (const char *) V->data,
+        K_data,
+        V_data,
         mask ? ((const char *) mask->data) : nullptr,
         (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
         scale, max_bias, m0, m1, n_head_log2,
@@ -676,8 +720,8 @@ void launch_fattn(ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kern
         K->ne[0], K->ne[1], K->ne[2], K->ne[3],
         mask ? mask->ne[1] : 0, mask ?  mask->nb[1] : 0,
         Q->nb[1], Q->nb[2], Q->nb[3],
-        K->nb[1], K->nb[2], K->nb[3],
-        V->nb[1], V->nb[2], V->nb[3],
+        nb11, nb12, nb13,
+        nb21, nb22, nb23,
         KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
     );
     CUDA_CHECK(cudaGetLastError());
index 3d64a9eba629d7164acb903872659db2a153a00e..cb11d7212ca28ce94d3c160c432beb32477a24c3 100644 (file)
@@ -278,13 +278,13 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
             constexpr int      D = 64;
             constexpr int nwarps = 8;
             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
-            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         } break;
         case 128: {
             constexpr int      D = 128;
             constexpr int nwarps = 8;
             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks>;
-            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         } break;
         default: {
             GGML_ASSERT(false && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
index 61fce0a7eb0ff62a6f6b9e9825526e2f4ee47667..15e22f495ffaa8a49f3af1a6a27a72c68e29c4b0 100644 (file)
@@ -275,13 +275,13 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
             constexpr int      D = 64;
             constexpr int nwarps = 8;
             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks>;
-            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         } break;
         case 128: {
             constexpr int      D = 128;
             constexpr int nwarps = 8;
             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks>;
-            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+            launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         } break;
         default: {
             GGML_ASSERT(false && "FlashAttention without tensor cores only supports head sizes 64 and 128.");
index ea4fcc972096da7c146a638ec251603396e66219..9e1aa2c6b688520c1e8d9df843116998bf00900b 100644 (file)
@@ -290,7 +290,9 @@ template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml
 void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     constexpr int nwarps = D/WARP_SIZE;
     fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks, type_K, type_V>;
-    launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+    constexpr bool need_f16_K = D != 128;
+    constexpr bool need_f16_V = D != 128 && D != 64;
+    launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
 }
 
 template <int D, ggml_type type_K, ggml_type type_V>
index 3009f0f43c0275e6fccf3fac19d0560af1054eb7..ce23a4ebd0088f5e0618a67cc33ce7d2082b0e70 100644 (file)
@@ -271,7 +271,9 @@ template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml
 void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     constexpr int nwarps = D/WARP_SIZE;
     fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks, type_K, type_V>;
-    launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+    constexpr bool need_f16_K = D != 128;
+    constexpr bool need_f16_V = D != 128 && D != 64;
+    launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
 }
 
 template <int D, ggml_type type_K, ggml_type type_V>
index 65ed3185244a6d0ed8ca71aacc1b28341bfc44be..59cd30d7837c90bd41add5e00296e24380d36b2d 100644 (file)
@@ -438,18 +438,18 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
     if (4*blocks_num_pb1 < 2*nsm) {
         constexpr int parallel_blocks = 4;
         fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
-        launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+        launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         return;
     }
     if (2*blocks_num_pb1 < 2*nsm) {
         constexpr int parallel_blocks = 2;
         fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
-        launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+        launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
         return;
     }
     constexpr int parallel_blocks = 1;
     fattn_kernel_t fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>;
-    launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block);
+    launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
 }
 
 #define DECL_FATTN_WMMA_F16_CASE(D, cols_per_block, KQ_acc_t)                         \
index b35ab67a8080cde2f09432ab1da81afbb6e92ddf..38d30b21026314efd87d4dbaff8a23fa6376afa7 100644 (file)
@@ -298,17 +298,13 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
 void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * KQV = dst;
     const ggml_tensor * Q   = dst->src[0];
-    const ggml_tensor * K   = dst->src[1];
-    const ggml_tensor * V   = dst->src[2];
 
     ggml_cuda_set_device(ctx.device);
     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     const int32_t precision = KQV->op_params[2];
 
-    const bool quantized_KV = ggml_is_quantized(K->type) || ggml_is_quantized(V->type);
-
     // On AMD the tile kernels perform poorly, use the vec kernel instead:
-    if (cc >= CC_OFFSET_AMD || quantized_KV) {
+    if (cc >= CC_OFFSET_AMD) {
         if (precision == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
             ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
         } else {