]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
CUDA/HIP: Fix fattn-vec-* when device warp size is not 32 (llama/12315)
authoruvos <redacted>
Wed, 12 Mar 2025 09:14:11 +0000 (10:14 +0100)
committerGeorgi Gerganov <redacted>
Thu, 27 Mar 2025 09:06:03 +0000 (11:06 +0200)
When fattn-wmma was ported over to warp64 various bits that also touch fattn-vec where converted to
selectable warp size, however the fattn-vec kernels dont work with 64 wide warps for now, so we need
to avoid launching them with parameters for warp64

ggml/src/ggml-cuda/fattn-common.cuh
ggml/src/ggml-cuda/fattn-wmma-f16.cu

index 46de14093545c64f0ec50b6e951e93565dea4e31..4067fd41bc247947b788d2cd6214ae1b2d263a0f 100644 (file)
@@ -52,12 +52,11 @@ typedef half (*vec_dot_KQ_f16_t)(
 typedef float (*vec_dot_KQ_f32_t)(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);
 
-template<typename T, int D>
+template<typename T, int D, int warp_size>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
 
     const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c;
-    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);
 
     T sum = 0.0f;
@@ -93,12 +92,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
     return sum;
 }
 
-template<typename T, int D>
+template<typename T, int D, int warp_size>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
 
     const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c;
-    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);
 
     T sum = 0.0f;
@@ -138,12 +136,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
     return sum;
 }
 
-template<typename T, int D>
+template<typename T, int D, int warp_size>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
 
     const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c;
-    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);
 
     T sum = 0.0f;
@@ -186,12 +183,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
     return sum;
 }
 
-template<typename T, int D>
+template<typename T, int D, int warp_size>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
 
     const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c;
-    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);
 
     T sum = 0.0f;
@@ -238,12 +234,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
     return sum;
 }
 
-template <typename T, int D>
+template <typename T, int D, int warp_size>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
 
     const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c;
-    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);
 
     T sum = 0.0f;
@@ -272,12 +267,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
     return sum;
 }
 
-template <typename T, int D>
+template <typename T, int D, int warp_size>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) {
 
     const half2 * K_h2 = (const half2 *) K_c;
-    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_q8);
     GGML_UNUSED(Q_ds_v);
 
@@ -480,25 +474,25 @@ static __device__ __forceinline__ T dequantize_1_f16(const void * __restrict__ v
     return x[i];
 }
 
-template <int D>
+template <int D, int warp_size = WARP_SIZE>
 constexpr __device__ vec_dot_KQ_f16_t get_vec_dot_KQ_f16(ggml_type type_K) {
-    return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<half, D> :
-        type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<half, D> :
-        type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<half, D> :
-        type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<half, D> :
-        type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<half, D> :
-        type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<half, D> :
+    return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<half, D, warp_size> :
+        type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<half, D, warp_size> :
+        type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<half, D, warp_size> :
+        type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<half, D, warp_size> :
+        type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<half, D, warp_size> :
+        type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<half, D, warp_size> :
         nullptr;
 }
 
-template <int D>
+template <int D, int warp_size = WARP_SIZE>
 constexpr __device__ vec_dot_KQ_f32_t get_vec_dot_KQ_f32(ggml_type type_K) {
-    return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<float, D> :
-        type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<float, D> :
-        type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<float, D> :
-        type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<float, D> :
-        type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<float, D> :
-        type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<float, D> :
+    return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<float, D, warp_size> :
+        type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<float, D, warp_size> :
+        type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<float, D, warp_size> :
+        type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<float, D, warp_size> :
+        type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<float, D, warp_size> :
+        type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<float, D, warp_size> :
         nullptr;
 }
 
@@ -681,7 +675,8 @@ static void on_no_fattn_vec_case(const int D) {
 template <int D, int ncols1, int ncols2, int parallel_blocks, int KQ_stride>
 void launch_fattn(
     ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
-    const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V
+    const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V,
+    const int warp_size = WARP_SIZE
 ) {
     constexpr int ncols = ncols1 * ncols2;
 
@@ -704,8 +699,6 @@ void launch_fattn(
 
     GGML_ASSERT(Q->ne[3] == 1);
 
-    const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size;
-
     ggml_cuda_pool & pool = ctx.pool();
     cudaStream_t main_stream = ctx.stream();
     const int id  = ggml_cuda_get_device();
@@ -805,7 +798,6 @@ void launch_fattn(
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
     GGML_ASSERT(block_dim.x % warp_size == 0);
-    GGML_ASSERT(!GGML_CUDA_CC_IS_AMD(cc) || block_dim.x * block_dim.y <= 4 * (unsigned int)warp_size);
     fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>(
         (const char *) Q->data,
         K_data,
index 622cf28576d2997361b05a7bd6fd475b25f6fadb..dab1d5cbcace4b43fdd9aa78c6531ef92b2aacc8 100644 (file)
@@ -469,6 +469,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
     constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16;
     const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
     const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
+    const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
 
     float logit_softcap;
     memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
@@ -485,7 +486,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
             fattn_kernel = flash_attn_ext_f16<
                 D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
         }
-        launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
+        launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true, warp_size);
         return;
     }
     if (2*blocks_num_pb1 < 2*nsm) {
@@ -500,7 +501,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
             fattn_kernel = flash_attn_ext_f16<
                 D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
         }
-        launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
+        launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true, warp_size);
         return;
     }
     constexpr int parallel_blocks = 1;
@@ -514,7 +515,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
         fattn_kernel = flash_attn_ext_f16<
             D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
     }
-    launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
+    launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true, warp_size);
 }
 
 void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {