]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
CUDA: fix BF16 FA compilation (llama/20865)
authorJohannes Gäßler <redacted>
Sun, 22 Mar 2026 16:53:33 +0000 (17:53 +0100)
committerGeorgi Gerganov <redacted>
Sat, 28 Mar 2026 11:39:09 +0000 (13:39 +0200)
src/ggml-cuda/convert.cuh

index b8caeacf094286b5e841c62b92fc5892c216c4b4..f5d37c7b99874b13a04dd2b5d3a1532fc3d5612c 100644 (file)
@@ -42,11 +42,15 @@ template<typename dst_t, typename src_t>
     } else if constexpr(std::is_same_v<src_t, float2> && std::is_same_v<dst_t, half2>) {
         return __float22half2_rn(x);
     } else if constexpr(std::is_same_v<src_t, nv_bfloat162> && std::is_same_v<dst_t, float2>) {
-#if !defined(GGML_USE_HIP) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+#ifdef GGML_USE_HIP
+        return make_float2(__bfloat162float(__low2bfloat16(x)), __bfloat162float(__high2bfloat16(x)));
+#else
+#if __CUDA_ARCH__ >= 800
         return __bfloat1622float2(x);
 #else
-        return make_float2(__bfloat162float(__low2bfloat16(x)), __bfloat162float(__high2bfloat16(x)));
-#endif // !defined(GGML_USE_HIP) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+        return make_float2(__bfloat162float(x.x), __bfloat162float(x.y));
+#endif // __CUDA_ARCH__ >= 800
+#endif // GGML_USE_HIP
     } else if constexpr(std::is_same_v<src_t, float2> && std::is_same_v<dst_t, nv_bfloat162>) {
         // bypass compile error on cuda 12.0.1
 #ifdef GGML_USE_HIP