]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
CUDA: CUDART < 11.7 workaround for __hmax, __hmax2 (llama/7019)
authorJohannes Gäßler <redacted>
Wed, 1 May 2024 12:46:37 +0000 (14:46 +0200)
committerGeorgi Gerganov <redacted>
Sat, 11 May 2024 18:30:08 +0000 (21:30 +0300)
src/ggml-cuda/common.cuh
src/ggml-cuda/fattn.cu

index 156eba6d1ef74d780ccaa74e72eea9428ccac7e8..b2627b7b4b77ff65a17de8199b3163a6a3c9c6be 100644 (file)
 #define STRINGIZE(...) STRINGIZE_IMPL(__VA_ARGS__)
 
 #define WARP_SIZE 32
-#define CUDART_HMAX     11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed)
+#define CUDART_HMAX   11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed)
+#define CUDART_HMASK  12000 // CUDA 12.0, min. ver. for half2 -> uint mask comparisons
 
 #define CC_PASCAL     600
 #define MIN_CC_DP4A   610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
@@ -293,20 +294,54 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
     return x;
 }
 
+static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+
+#if CUDART_VERSION >= CUDART_HMAX
+    return __hmax(a, b);
+#else
+    return __half2float(a) > __half2float(b) ? a : b;
+#endif // CUDART_VERSION >= CUDART_HMAX
+
+#else
+    GGML_UNUSED(a);
+    GGML_UNUSED(b);
+    NO_DEVICE_CODE;
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
+}
+static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const half2 b) {
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+
+#if CUDART_VERSION >= CUDART_HMAX
+    return __hmax2(a, b);
+#else
+    half2 ret;
+    reinterpret_cast<half&>(ret.x) =  __low2float(a) >  __low2float(b) ?  __low2half(a) :  __low2half(b);
+    reinterpret_cast<half&>(ret.y) = __high2float(a) > __high2float(b) ? __high2half(a) : __high2half(b);
+    return ret;
+#endif // CUDART_VERSION >= CUDART_HMAX
+
+#else
+    GGML_UNUSED(a);
+    GGML_UNUSED(b);
+    NO_DEVICE_CODE;
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
+}
+
 static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
 #pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1) {
-       x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
+       x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
    }
    return x;
 #else
    GGML_UNUSED(x);
    NO_DEVICE_CODE;
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
 }
 
-#if CUDART_VERSION < 12000
+#if CUDART_VERSION < CUDART_HMASK
 static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half2 b) {
     const uint32_t mask_low  = 0x0000FFFF * (float( __low2half(a)) > float( __low2half(b)));
     const uint32_t mask_high = 0xFFFF0000 * (float(__high2half(a)) > float(__high2half(b)));
index df1e80068b3345eb5714babb3fceaa216bc3bb88..c8a11d173346454f05f1d5cf5ed70627db9efeea 100644 (file)
@@ -116,7 +116,7 @@ static __global__ void flash_attn_vec_ext_f16(
             sum2 = warp_reduce_sum(sum2);
             half sum = __low2half(sum2) + __high2half(sum2);
             sum += mask ? maskh[k_VKQ_0 + i_KQ] : __float2half(0.0f);
-            kqmax_new = __hmax(kqmax_new, sum);
+            kqmax_new = ggml_cuda_hmax(kqmax_new, sum);
             if (threadIdx.x == 0) {
                 KQ[i_KQ] = sum;
             }
@@ -416,9 +416,9 @@ static __global__ void flash_attn_ext_f16(
                     const int k = k0 + threadIdx.x;
 
                     KQ2_tmp[k0/WARP_SIZE] += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
-                    KQ_max_new = __hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]);
+                    KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]);
                 }
-                KQ_max_new = __half2half2(warp_reduce_max(__hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
+                KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
                 const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new;
                 KQ_max_scale_h2[j0/nwarps] = h2exp(diff);
                 const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));