]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Add a workaround for compilation with ROCWMMA_FATTN and gfx9 (#19461)
authorMario Limonciello <redacted>
Thu, 12 Feb 2026 08:38:35 +0000 (02:38 -0600)
committerGitHub <redacted>
Thu, 12 Feb 2026 08:38:35 +0000 (09:38 +0100)
There is an upstream problem [1] with AMD's LLVM 22 fork and
rocWMMA 2.2.0 causing compilation issues on devices without
native fp16 support (CDNA devices).

The specialized types aren't resolved properly:
```
/opt/rocm/include/rocwmma/internal/mfma_impl.hpp:2549:37: error: ambiguous partial specializations of 'amdgcn_mfma<__half, __half, __half, 16, 16, 16>'
 2549 |             using ARegsT = typename Impl::ARegsT;
```

Add a workaround to explicitly declare the types and cast when
compiling with HIP and ROCWMMA_FATTN [2].  When this is actually
fixed upstream some guards can be used to detect and wrap the
version that has the fix to only apply when necessary.

Link: https://github.com/ROCm/rocm-libraries/issues/4398
Link: https://github.com/ggml-org/llama.cpp/issues/19269
Signed-off-by: Mario Limonciello <redacted>
ggml/src/ggml-cuda/fattn-wmma-f16.cu

index 8694fd06c7bc4f31d9702933ca0aa720389ef240..35735d48b2e84be81e3ca21009f31f7ccde1b69d 100644 (file)
@@ -63,11 +63,19 @@ static __global__ void flash_attn_ext_f16(
     constexpr int frag_m = ncols == 8 ? 32 : 16;
     constexpr int frag_n = ncols == 8 ?  8 : 16;
     static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
+#if defined(GGML_USE_HIP)
+    typedef wmma::fragment<wmma::matrix_a,    frag_m, frag_n, 16, _Float16, wmma::row_major> frag_a_K;
+    typedef wmma::fragment<wmma::matrix_a,    frag_m, frag_n, 16, _Float16, wmma::col_major> frag_a_V;
+    typedef wmma::fragment<wmma::matrix_b,    frag_m, frag_n, 16, _Float16, wmma::col_major> frag_b;
+    typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t>                      frag_c_KQ;
+    typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, _Float16>                          frag_c_VKQ;
+#else
     typedef wmma::fragment<wmma::matrix_a,    frag_m, frag_n, 16, half, wmma::row_major> frag_a_K;
     typedef wmma::fragment<wmma::matrix_a,    frag_m, frag_n, 16, half, wmma::col_major> frag_a_V;
     typedef wmma::fragment<wmma::matrix_b,    frag_m, frag_n, 16, half, wmma::col_major> frag_b;
     typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t>                      frag_c_KQ;
     typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, half>                          frag_c_VKQ;
+#endif
 
     constexpr int KQ_stride_tc  = nwarps*frag_m; // Number of KQ rows calculated in parallel.
     constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
@@ -126,6 +134,19 @@ static __global__ void flash_attn_ext_f16(
 
     __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice.
     half2 * VKQ2 = (half2 *) VKQ;
+
+#if defined(GGML_USE_HIP)
+    const _Float16 * K_h_f16  = reinterpret_cast<const _Float16 *>(K_h);
+    const _Float16 * V_h_f16  = reinterpret_cast<const _Float16 *>(V_h);
+    _Float16       * KQ_f16   = reinterpret_cast<_Float16 *>(KQ);
+    _Float16       * VKQ_f16  = reinterpret_cast<_Float16 *>(VKQ);
+#else
+    const half * K_h_f16  = K_h;
+    const half * V_h_f16  = V_h;
+    half       * KQ_f16   = KQ;
+    half       * VKQ_f16  = VKQ;
+#endif
+
 #pragma unroll
     for (int j0 = 0; j0 < ncols; j0 += nwarps) {
         const int j = j0 + threadIdx.y;
@@ -160,7 +181,7 @@ static __global__ void flash_attn_ext_f16(
     for (int i0 = 0; i0 < D; i0 += 16) {
 #pragma unroll
         for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-            wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
+            wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ_f16 + j0*D_padded + i0, D_padded);
         }
     }
 
@@ -180,7 +201,7 @@ static __global__ void flash_attn_ext_f16(
 #pragma unroll
             for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
                 frag_a_K K_a;
-                wmma::load_matrix_sync(K_a, K_h + int64_t(k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
+                wmma::load_matrix_sync(K_a, K_h_f16 + int64_t(k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
 #pragma unroll
                 for (int j = 0; j < ncols/frag_n; ++j) {
                     wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
@@ -310,7 +331,7 @@ static __global__ void flash_attn_ext_f16(
                 const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
                 wmma::load_matrix_sync(
                     KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
-                    KQ + j0*(kqar*kqs_padded) + k,
+                    KQ_f16 + j0*(kqar*kqs_padded) + k,
                     kqar*kqs_padded);
             }
         }
@@ -328,7 +349,7 @@ static __global__ void flash_attn_ext_f16(
                 const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
 
                 frag_a_V v_a;
-                wmma::load_matrix_sync(v_a, V_h + int64_t(k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
+                wmma::load_matrix_sync(v_a, V_h_f16 + int64_t(k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
 #pragma unroll
                 for (int j = 0; j < ncols/frag_n; ++j) {
                     wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
@@ -344,7 +365,7 @@ static __global__ void flash_attn_ext_f16(
 #pragma unroll
             for (int j0 = 0; j0 < ncols; j0 += frag_n) {
                 wmma::store_matrix_sync(
-                    KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
+                    KQ_f16 + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
                     VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
                     D_padded, wmma::mem_col_major);
             }