]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
cuda : add half2 __shfl_xor() for ROCm 5.5 (llama/7263)
authorEngininja2 <redacted>
Sat, 18 May 2024 08:05:17 +0000 (02:05 -0600)
committerGeorgi Gerganov <redacted>
Tue, 28 May 2024 11:41:08 +0000 (14:41 +0300)
src/ggml-cuda/common.cuh

index b6f0bc36a4f8ae39c37510f867cab29c77045499..784792ba0dfcda65382bd1df4d466b7cd5e398fe 100644 (file)
@@ -315,6 +315,20 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
 #endif
     return c;
 }
+
+#if defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 50600000
+// __shfl_xor() for half2 was added in ROCm 5.6
+static __device__ __forceinline__ half2 __shfl_xor(half2 var, int laneMask, int width) {
+    typedef union half2_b32 {
+        half2 val;
+        int   b32;
+    } half2_b32_t;
+    half2_b32_t tmp;
+    tmp.val = var;
+    tmp.b32 = __shfl_xor(tmp.b32, laneMask, width);
+    return tmp.val;
+}
+#endif // defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 50600000
 #endif // defined(GGML_USE_HIPBLAS)
 
 #define FP16_AVAILABLE (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL