]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
HIP: Prepare reduction operators for wave 64
authoruvos <redacted>
Wed, 29 Jan 2025 18:12:42 +0000 (19:12 +0100)
committeruvos <redacted>
Thu, 30 Jan 2025 15:25:44 +0000 (16:25 +0100)
ggml/src/ggml-cuda/common.cuh
ggml/src/ggml-cuda/ggml-cuda.cu

index eec227dce3a1e878f8dd7a6e2ed90a35cd26b6ac..8d8d3932e0e5864a9343b467079a53f3f7107bb0 100644 (file)
@@ -190,53 +190,46 @@ static __device__ void no_device_code(
 #define NO_DEVICE_CODE //GGML_ABORT("NO_DEVICE_CODE not valid in host code.")
 #endif // __CUDA_ARCH__
 
+template<int width = WARP_SIZE>
 static __device__ __forceinline__ int warp_reduce_sum(int x) {
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
     return __reduce_add_sync(0xffffffff, x);
 #else
 #pragma unroll
-    for (int offset = 16; offset > 0; offset >>= 1) {
-        x += __shfl_xor_sync(0xffffffff, x, offset, 32);
+    for (int offset = width/2; offset > 0; offset >>= 1) {
+        x += __shfl_xor_sync(0xffffffff, x, offset, width);
     }
     return x;
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 }
 
+template<int width = WARP_SIZE>
 static __device__ __forceinline__ float warp_reduce_sum(float x) {
 #pragma unroll
-    for (int offset = 16; offset > 0; offset >>= 1) {
-        x += __shfl_xor_sync(0xffffffff, x, offset, 32);
+    for (int offset = width/2; offset > 0; offset >>= 1) {
+        x += __shfl_xor_sync(0xffffffff, x, offset, width);
     }
     return x;
 }
 
+template<int width = WARP_SIZE>
 static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
 #pragma unroll
-    for (int offset = 16; offset > 0; offset >>= 1) {
-        a.x += __shfl_xor_sync(0xffffffff, a.x, offset, 32);
-        a.y += __shfl_xor_sync(0xffffffff, a.y, offset, 32);
+    for (int offset = width/2; offset > 0; offset >>= 1) {
+        a.x += __shfl_xor_sync(0xffffffff, a.x, offset, width);
+        a.y += __shfl_xor_sync(0xffffffff, a.y, offset, width);
     }
     return a;
 }
 
+template<int width = WARP_SIZE>
 static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
 #ifdef FP16_AVAILABLE
-
-#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
-#pragma unroll
-    for (int offset = 16; offset > 0; offset >>= 1) {
-        const half2 a_other = __shfl_xor_sync(0xffffffff, a, offset, 32);
-        reinterpret_cast<half&>(a.x) +=  __low2half(a_other);
-        reinterpret_cast<half&>(a.y) += __high2half(a_other);
-    }
-    return a;
-#else
 #pragma unroll
-    for (int offset = 16; offset > 0; offset >>= 1) {
-        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, offset, 32));
+    for (int offset = width/2; offset > 0; offset >>= 1) {
+        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, offset, width));
     }
     return a;
-#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
 
 #else
     NO_DEVICE_CODE;
@@ -244,10 +237,11 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
 #endif // FP16_AVAILABLE
 }
 
+template<int width = WARP_SIZE>
 static __device__ __forceinline__ float warp_reduce_max(float x) {
 #pragma unroll
-    for (int offset = 16; offset > 0; offset >>= 1) {
-        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, offset, 32));
+    for (int offset = width/2; offset > 0; offset >>= 1) {
+        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, offset, width));
     }
     return x;
 }
@@ -269,35 +263,34 @@ static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b
 }
 
 static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const half2 b) {
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
-
-#if CUDART_VERSION >= CUDART_HMAX
+#if defined(GGML_USE_HIP) && HIP_VERSION >= 50700000
+    return half2(__hmax(a.x, b.x), __hmax(a.y, b.y));
+#elif !defined(GGML_USE_HIP) && CUDART_VERSION >= CUDART_HMAX
     return __hmax2(a, b);
-#else
+#elif !defined(GGML_USE_HIP)
     half2 ret;
     reinterpret_cast<half&>(ret.x) = __float2half(fmaxf( __low2float(a),  __low2float(b)));
     reinterpret_cast<half&>(ret.y) = __float2half(fmaxf(__high2float(a), __high2float(b)));
     return ret;
-#endif // CUDART_VERSION >= CUDART_HMAX
-
 #else
     GGML_UNUSED(a);
     GGML_UNUSED(b);
     NO_DEVICE_CODE;
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+#endif
 }
 
+template<int width = WARP_SIZE>
 static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || (defined(GGML_USE_HIP) && HIP_VERSION >= 50700000)
 #pragma unroll
-   for (int offset = 16; offset > 0; offset >>= 1) {
-       x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, offset, 32));
+   for (int offset = width/2; offset > 0; offset >>= 1) {
+       x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, offset, width));
    }
    return x;
 #else
    GGML_UNUSED(x);
    NO_DEVICE_CODE;
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL || (defined(GGML_USE_HIP) && HIP_VERSION >= 50700000)
 }
 
 #if CUDART_VERSION < CUDART_HMASK
index ecf06fec408bb822b1e96368bc80334e067369c4..383131c7789d5802616b5db819ee890b3226742b 100644 (file)
@@ -240,8 +240,8 @@ static ggml_cuda_device_info ggml_cuda_init() {
         info.default_tensor_split[id] = total_vram;
         total_vram += prop.totalGlobalMem;
 
-        info.devices[id].nsm   = prop.multiProcessorCount;
-        info.devices[id].smpb  = prop.sharedMemPerBlock;
+        info.devices[id].nsm       = prop.multiProcessorCount;
+        info.devices[id].smpb      = prop.sharedMemPerBlock;
         info.devices[id].warp_size = prop.warpSize;
 #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
         info.devices[id].smpbo = prop.sharedMemPerBlock;