]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
cuda : replace remaining shfl_xor with calls to warp_reduce functions (llama/5744)
authorEngininja2 <redacted>
Tue, 27 Feb 2024 13:22:45 +0000 (07:22 -0600)
committerGeorgi Gerganov <redacted>
Wed, 28 Feb 2024 11:00:29 +0000 (13:00 +0200)
ggml-cuda.cu

index 6a4d57cb628ddbab381bd509a558592caaa87277..1242c0410df497efb7e822cd875a78aef86276a3 100644 (file)
@@ -696,18 +696,20 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
     return a;
 }
 
-//static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
-//#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
-//#pragma unroll
-//    for (int mask = 16; mask > 0; mask >>= 1) {
-//        a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
-//    }
-//    return a;
-//#else
-//    (void) a;
-//    NO_DEVICE_CODE;
-//#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
-//}
+#ifdef GGML_CUDA_F16
+static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
+#pragma unroll
+   for (int mask = 16; mask > 0; mask >>= 1) {
+       a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
+   }
+   return a;
+#else
+   (void) a;
+   NO_DEVICE_CODE;
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
+}
+#endif // GGML_CUDA_F16
 
 static __device__ __forceinline__ float warp_reduce_max(float x) {
 #pragma unroll
@@ -2521,10 +2523,7 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx,
 #endif
 
     // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
-    }
+    tmp = warp_reduce_sum(tmp);
 
     if (threadIdx.x == 0) {
         dst[row] = tmp;
@@ -2625,10 +2624,7 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx,
 #endif
 
     // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
-    }
+    tmp = warp_reduce_sum(tmp);
 
     if (threadIdx.x == 0) {
         dst[row] = tmp;
@@ -2761,10 +2757,7 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx,
 #endif
 
     // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
-    }
+    tmp = warp_reduce_sum(tmp);
 
     if (tid == 0) {
         dst[row] = tmp;
@@ -2877,10 +2870,7 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx,
 #endif
 
     // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
-    }
+    tmp = warp_reduce_sum(tmp);
 
     if (threadIdx.x == 0) {
         dst[row] = tmp;
@@ -2987,10 +2977,7 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx,
 #endif
 
     // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
-    }
+    tmp = warp_reduce_sum(tmp);
 
     if (tid == 0) {
         dst[row] = tmp;
@@ -3025,11 +3012,8 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
     float amax = fabsf(xi);
     float sum = xi;
 
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, mask, 32));
-        sum += __shfl_xor_sync(0xffffffff, sum, mask, 32);
-    }
+    amax = warp_reduce_max(amax);
+    sum = warp_reduce_sum(sum);
 
     const float d = amax / 127;
     const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
@@ -6222,10 +6206,7 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons
     }
 
     // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
-    }
+    tmp = warp_reduce_sum(tmp);
 
     if (tid == 0) {
 #ifdef GGML_CUDA_F16
@@ -6275,10 +6256,7 @@ static __global__ void mul_mat_p021_f16_f32(
     const int idst = channel*nrows_dst + row_dst;
 
     // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
-    }
+    tmp = warp_reduce_sum(tmp);
 
     if (threadIdx.x == 0) {
         dst[idst] = tmp;
@@ -6321,10 +6299,7 @@ static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
     }
 
     // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
-    }
+    tmp = warp_reduce_sum(tmp);
 
     if (threadIdx.x == 0) {
         dst[idst] = tmp;