From: Engininja2
Date: Sat, 18 May 2024 08:05:17 +0000 (-0600)
Subject: cuda : add half2 __shfl_xor() for ROCm 5.5 (llama/7263)
X-Git-Tag: upstream/0.0.1642~675
X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=0d7036940257c60e2a828896be8073a860dec860;p=pkg%2Fggml%2Fsources%2Fggml

cuda : add half2 __shfl_xor() for ROCm 5.5 (llama/7263)
---

diff --git a/src/ggml-cuda/common.cuh b/src/ggml-cuda/common.cuh
index b6f0bc36..784792ba 100644
--- a/src/ggml-cuda/common.cuh
+++ b/src/ggml-cuda/common.cuh
@@ -315,6 +315,20 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
 #endif
     return c;
 }
+
+#if defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 50600000
+// __shfl_xor() for half2 was added in ROCm 5.6; for older HIP versions this overload copies var into a union, shuffles the union's int member, and returns the half2 member.
+static __device__ __forceinline__ half2 __shfl_xor(half2 var, int laneMask, int width) {
+    typedef union half2_b32 { // one storage location, accessible either as half2 or as int
+        half2 val;
+        int b32; // NOTE(review): assumes half2 and int have the same size (32 bits, per the b32 name) - confirm
+    } half2_b32_t;
+    half2_b32_t tmp;
+    tmp.val = var;
+    tmp.b32 = __shfl_xor(tmp.b32, laneMask, width); // int argument; laneMask and width are forwarded unchanged
+    return tmp.val; // read the shuffled int back through the half2 member
+}
+#endif // defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 50600000
 #endif // defined(GGML_USE_HIPBLAS)
 
 #define FP16_AVAILABLE (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL