git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Adjust workaround for ROCWMMA_FATTN/GFX9 to only newer ROCm versions (#19591)
author Mario Limonciello <redacted>
Mon, 16 Feb 2026 13:46:08 +0000 (07:46 -0600)
committer GitHub <redacted>
Mon, 16 Feb 2026 13:46:08 +0000 (14:46 +0100)
Avoids issues with ROCm 6.4.4.

Closes: https://github.com/ggml-org/llama.cpp/issues/19580
Fixes: 6845f7f87 ("Add a workaround for compilation with ROCWMMA_FATTN and gfx9 (#19461)")
Signed-off-by: Mario Limonciello (AMD) <redacted>
ggml/src/ggml-cuda/fattn-wmma-f16.cu

index 35735d48b2e84be81e3ca21009f31f7ccde1b69d..f19defbff939c9c2c1b31d6ec77e921d21aabba3 100644 (file)
@@ -63,7 +63,7 @@ static __global__ void flash_attn_ext_f16(
     constexpr int frag_m = ncols == 8 ? 32 : 16;
     constexpr int frag_n = ncols == 8 ?  8 : 16;
     static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
-#if defined(GGML_USE_HIP)
+#if defined(GGML_USE_HIP) && HIP_VERSION >= 60500000
     typedef wmma::fragment<wmma::matrix_a,    frag_m, frag_n, 16, _Float16, wmma::row_major> frag_a_K;
     typedef wmma::fragment<wmma::matrix_a,    frag_m, frag_n, 16, _Float16, wmma::col_major> frag_a_V;
     typedef wmma::fragment<wmma::matrix_b,    frag_m, frag_n, 16, _Float16, wmma::col_major> frag_b;
@@ -135,7 +135,7 @@ static __global__ void flash_attn_ext_f16(
     __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice.
     half2 * VKQ2 = (half2 *) VKQ;
 
-#if defined(GGML_USE_HIP)
+#if defined(GGML_USE_HIP) && HIP_VERSION >= 60500000
     const _Float16 * K_h_f16  = reinterpret_cast<const _Float16 *>(K_h);
     const _Float16 * V_h_f16  = reinterpret_cast<const _Float16 *>(V_h);
     _Float16       * KQ_f16   = reinterpret_cast<_Float16 *>(KQ);