From: Johannes Gäßler
Date: Tue, 9 Sep 2025 12:04:43 +0000 (+0200)
Subject: HIP: use v_dot2_f32_f16 instruction for FA (#15884)
X-Git-Tag: upstream/0.0.6527~97
X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=17bc5a815f0bebc1844c99796760cd7df5e9b8d9;p=pkg%2Fggml%2Fsources%2Fllama.cpp

HIP: use v_dot2_f32_f16 instruction for FA (#15884)
---

diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index 931524a2..394595be 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -545,6 +545,31 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i
 #endif // defined(GGML_USE_HIP)
 }
 
+static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const float v, const float u) {
+    acc += v*u;
+}
+
+static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const float2 v, const float2 u) {
+    acc += v.x*u.x;
+    acc += v.y*u.y;
+}
+
+static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v, const half2 u) {
+#if defined(GGML_USE_HIP) && defined(GCN)
+    asm volatile("v_dot2_f32_f16 %0, %1, %2, %0" : "+v"(acc) : "v"(v), "v"(u));
+#else
+#ifdef FAST_FP16_AVAILABLE
+    const float2 tmp = __half22float2(v*u);
+    acc += tmp.x + tmp.y;
+#else
+    const float2 tmpv = __half22float2(v);
+    const float2 tmpu = __half22float2(u);
+    acc += tmpv.x * tmpu.x;
+    acc += tmpv.y * tmpu.y;
+#endif // FAST_FP16_AVAILABLE
+#endif // defined(GGML_USE_HIP) && defined(GCN)
+}
+
 static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
 #if CUDART_VERSION >= 12080
     const nv_bfloat16 e = __nv_cvt_e8m0_to_bf16raw(x);
diff --git a/ggml/src/ggml-cuda/fattn-tile.cu b/ggml/src/ggml-cuda/fattn-tile.cu
index fb2163ac..64f7d4a1 100644
--- a/ggml/src/ggml-cuda/fattn-tile.cu
+++ b/ggml/src/ggml-cuda/fattn-tile.cu
@@ -304,12 +304,7 @@ static __global__ void flash_attn_tile(
         for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) {
 #pragma unroll
             for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
-#ifdef FAST_FP16_AVAILABLE
-                const float2 tmp = __half22float2(K_k[i_KQ_0/warp_size] * Q_k[j_KQ_0/nwarps]);
-                sum[i_KQ_0/warp_size][j_KQ_0/nwarps] += tmp.x + tmp.y;
-#else
-                sum[i_KQ_0/warp_size][j_KQ_0/nwarps] += K_k[i_KQ_0/warp_size] * Q_k[j_KQ_0/nwarps];
-#endif // FAST_FP16_AVAILABLE
+                ggml_cuda_mad(sum[i_KQ_0/warp_size][j_KQ_0/nwarps], K_k[i_KQ_0/warp_size], Q_k[j_KQ_0/nwarps]);
             }
         }
     }
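
As a sanity check on the new helper's semantics, here is a minimal standalone sketch. It is not part of the commit: the file name, the mad_ref helper, and the dot2_demo kernel are hypothetical, and an NVIDIA toolchain with nvcc and cuda_fp16.h is assumed. All three ggml_cuda_mad overloads compute acc += dot(v, u); on GCN the half2 overload collapses into the single v_dot2_f32_f16 instruction, whose read-write "+v" operand expresses exactly the acc += v.x*u.x + v.y*u.y accumulation that the portable branches spell out.

// dot2_demo.cu -- hypothetical sketch, verifies the accumulate-dot-product
// semantics of the half2 ggml_cuda_mad overload with plain f32 arithmetic.
#include <cstdio>
#include <cuda_fp16.h>

// Portable reference for the half2 overload: unpack both f16 lanes to f32
// and accumulate their products into acc (what v_dot2_f32_f16 does in one
// instruction on GCN).
static __device__ __forceinline__ void mad_ref(float & acc, const half2 v, const half2 u) {
    const float2 vf = __half22float2(v);
    const float2 uf = __half22float2(u);
    acc += vf.x * uf.x;
    acc += vf.y * uf.y;
}

__global__ void dot2_demo(float * out) {
    const half2 v = __floats2half2_rn(1.5f, -2.0f);
    const half2 u = __floats2half2_rn(4.0f,  0.5f);
    float acc = 10.0f;
    mad_ref(acc, v, u); // acc = 10 + 1.5*4.0 + (-2.0)*0.5 = 15.0
    out[0] = acc;
}

int main() {
    float * d_out = nullptr;
    cudaMalloc(&d_out, sizeof(float));
    dot2_demo<<<1, 1>>>(d_out);
    float h_out = 0.0f;
    cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
    cudaFree(d_out);
    printf("acc = %.1f (expected 15.0)\n", h_out);
    return 0;
}

Compiled with nvcc dot2_demo.cu -o dot2_demo this prints acc = 15.0; on ROCm the same sketch should build with hipcc after swapping the include for <hip/hip_fp16.h>. The fattn-tile.cu change is then purely mechanical: the per-pair FP16 dot products in the KQ inner loop all funnel through the one overload set, so the GCN-specific instruction selection lives in a single place.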