HIP: use v_dot2_f32_f16 instruction for FA (llama/15884)
author Johannes Gäßler <redacted>
Tue, 9 Sep 2025 12:04:43 +0000 (14:04 +0200)
committer Georgi Gerganov <redacted>
Sat, 20 Sep 2025 10:33:50 +0000 (13:33 +0300)
src/ggml-cuda/common.cuh
src/ggml-cuda/fattn-tile.cu

diff --git a/src/ggml-cuda/common.cuh b/src/ggml-cuda/common.cuh
index 931524a2008974889e47868913ff93dbc004dfaa..394595be0eada63bb91e5b5b84fd718485e9061c 100644
@@ -545,6 +545,31 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i
 #endif // defined(GGML_USE_HIP)
 }
 
+static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const float v, const float u) {
+    acc += v*u;
+}
+
+static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const float2 v, const float2 u) {
+    acc += v.x*u.x;
+    acc += v.y*u.y;
+}
+
+static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v, const half2 u) {
+#if defined(GGML_USE_HIP) && defined(GCN)
+    asm volatile("v_dot2_f32_f16 %0, %1, %2, %0" : "+v"(acc) : "v"(v), "v"(u));
+#else
+#ifdef FAST_FP16_AVAILABLE
+    const float2 tmp = __half22float2(v*u);
+    acc += tmp.x + tmp.y;
+#else
+    const float2 tmpv = __half22float2(v);
+    const float2 tmpu = __half22float2(u);
+    acc += tmpv.x * tmpu.x;
+    acc += tmpv.y * tmpu.y;
+#endif // FAST_FP16_AVAILABLE
+#endif // defined(GGML_USE_HIP) && defined(GCN)
+}
+
 static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
 #if CUDART_VERSION >= 12080
     const nv_bfloat16 e = __nv_cvt_e8m0_to_bf16raw(x);
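For reference (not part of the patch): the new half2 overload accumulates a two-way f16 dot product into an f32 register, acc += v.x*u.x + v.y*u.y. On GCN that is exactly what the v_dot2_f32_f16 instruction computes in a single operation; the FAST_FP16_AVAILABLE path multiplies in fp16 and widens the packed result, and the fallback widens first and multiplies in fp32. A minimal host-side sketch of those semantics, using a plain struct as a stand-in for __half2 with values already widened to fp32:

// Illustrative reference for ggml_cuda_mad(float&, half2, half2); not from the tree.
#include <cstdio>

struct h2 { float x, y; };  // stand-in for __half2, lanes already widened to fp32

static void mad_ref(float & acc, h2 v, h2 u) {
    // Same result v_dot2_f32_f16 produces in one instruction on GCN:
    // a 2-way f16 dot product accumulated into an f32 register.
    acc += v.x * u.x;
    acc += v.y * u.y;
}

int main() {
    float acc = 0.5f;
    mad_ref(acc, {1.0f, 2.0f}, {3.0f, 4.0f});
    printf("%g\n", acc);  // 11.5 = 0.5 + 1*3 + 2*4
    return 0;
}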
diff --git a/src/ggml-cuda/fattn-tile.cu b/src/ggml-cuda/fattn-tile.cu
index fb2163acd0d73fc1738cece4856c28f3cf78d6b9..64f7d4a1a14701851fac6f70c8f2aa5a5e4db53e 100644
@@ -304,12 +304,7 @@ static __global__ void flash_attn_tile(
                 for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) {
 #pragma unroll
                     for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
-#ifdef FAST_FP16_AVAILABLE
-                        const float2 tmp = __half22float2(K_k[i_KQ_0/warp_size] * Q_k[j_KQ_0/nwarps]);
-                        sum[i_KQ_0/warp_size][j_KQ_0/nwarps] += tmp.x + tmp.y;
-#else
-                        sum[i_KQ_0/warp_size][j_KQ_0/nwarps] += K_k[i_KQ_0/warp_size] * Q_k[j_KQ_0/nwarps];
-#endif // FAST_FP16_AVAILABLE
+                        ggml_cuda_mad(sum[i_KQ_0/warp_size][j_KQ_0/nwarps], K_k[i_KQ_0/warp_size], Q_k[j_KQ_0/nwarps]);
                     }
                 }
             }
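Design note on this hunk: the FAST_FP16_AVAILABLE branch moves out of the kernel's inner loop and into the ggml_cuda_mad overload set, so the call site is branch-free and C++ overload resolution on the operand types (float, float2, or half2, depending on the build) selects the right body at compile time. A small sketch of that dispatch pattern; the names here are illustrative, not from the tree:

// Illustrative sketch of per-type overload dispatch; not from the tree.
#include <cstdio>

struct v2f { float x, y; };

// One overload per operand type; callers never see the #ifdefs.
static void mad_ref(float & acc, float v, float u) { acc += v * u; }
static void mad_ref(float & acc, v2f v, v2f u)     { acc += v.x * u.x; acc += v.y * u.y; }

int main() {
    float acc = 0.0f;
    mad_ref(acc, 2.0f, 3.0f);                  // scalar accumulate: +6
    mad_ref(acc, {1.0f, 2.0f}, {3.0f, 4.0f});  // packed accumulate: +11
    printf("%g\n", acc);  // 17
    return 0;
}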