]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
2x faster (rms) norm cuda kernels (3.7% e2e improvement) (#2985)
authorJiahao Li <redacted>
Mon, 4 Sep 2023 06:53:30 +0000 (14:53 +0800)
committerGitHub <redacted>
Mon, 4 Sep 2023 06:53:30 +0000 (08:53 +0200)
* 2x faster (rms) norm cuda kernels

* Fix code style

ggml-cuda.cu

index 8357f32f7c60d4444034d0d8ff7026f10e025fb4..d2dbf824ef2da9145161d1f43dbfedd7e107ce6e 100644 (file)
@@ -464,58 +464,91 @@ static __global__ void silu_f32(const float * x, float * dst, const int k) {
     dst[i] = x[i] / (1.0f + expf(-x[i]));
 }
 
+static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
+        a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
+    }
+    return a;
+}
+
+template <int block_size>
 static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
     const int tid = threadIdx.x;
 
     const float eps = 1e-5f;
 
-    float mean = 0.0f;
-    float var = 0.0f;
+    float2 mean_var = make_float2(0.f, 0.f);
 
-    for (int col = tid; col < ncols; col += WARP_SIZE) {
+    for (int col = tid; col < ncols; col += block_size) {
         const float xi = x[row*ncols + col];
-        mean += xi;
-        var += xi * xi;
+        mean_var.x += xi;
+        mean_var.y += xi * xi;
     }
 
     // sum up partial sums
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        mean += __shfl_xor_sync(0xffffffff, mean, mask, 32);
-        var += __shfl_xor_sync(0xffffffff, var, mask, 32);
+    mean_var = warp_reduce_sum(mean_var);
+    if (block_size > WARP_SIZE) {
+        __shared__ float2 s_sum[32];
+        int warp_id = threadIdx.x / WARP_SIZE;
+        int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            s_sum[warp_id] = mean_var;
+        }
+        __syncthreads();
+        mean_var = s_sum[lane_id];
+        mean_var = warp_reduce_sum(mean_var);
     }
 
-    mean /= ncols;
-    var = var / ncols - mean * mean;
-    const float inv_var = rsqrtf(var + eps);
+    const float mean = mean_var.x / ncols;
+    const float var = mean_var.y / ncols - mean * mean;
+    const float inv_std = rsqrtf(var + eps);
 
-    for (int col = tid; col < ncols; col += WARP_SIZE) {
-        dst[row*ncols + col] = (x[row*ncols + col] - mean) * inv_var;
+    for (int col = tid; col < ncols; col += block_size) {
+        dst[row*ncols + col] = (x[row*ncols + col] - mean) * inv_std;
     }
 }
 
+static __device__ __forceinline__ float warp_reduce_sum(float x) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
+    }
+    return x;
+}
+
+template <int block_size>
 static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
     const int tid = threadIdx.x;
 
     float tmp = 0.0f; // partial sum for thread in warp
 
-    for (int col = tid; col < ncols; col += WARP_SIZE) {
+    for (int col = tid; col < ncols; col += block_size) {
         const float xi = x[row*ncols + col];
         tmp += xi * xi;
     }
 
     // sum up partial sums
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+    tmp = warp_reduce_sum(tmp);
+    if (block_size > WARP_SIZE) {
+        __shared__ float s_sum[32];
+        int warp_id = threadIdx.x / WARP_SIZE;
+        int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            s_sum[warp_id] = tmp;
+        }
+        __syncthreads();
+        tmp = s_sum[lane_id];
+        tmp = warp_reduce_sum(tmp);
     }
 
     const float mean = tmp / ncols;
     const float scale = rsqrtf(mean + eps);
 
-    for (int col = tid; col < ncols; col += WARP_SIZE) {
+    for (int col = tid; col < ncols; col += block_size) {
         dst[row*ncols + col] = scale * x[row*ncols + col];
     }
 }
@@ -4203,14 +4236,24 @@ static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_
 
 static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % WARP_SIZE == 0);
-    const dim3 block_dims(WARP_SIZE, 1, 1);
-    norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
+    if (ncols < 1024) {
+        const dim3 block_dims(WARP_SIZE, 1, 1);
+        norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
+    } else {
+        const dim3 block_dims(1024, 1, 1);
+        norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
+    }
 }
 
 static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
     GGML_ASSERT(ncols % WARP_SIZE == 0);
-    const dim3 block_dims(WARP_SIZE, 1, 1);
-    rms_norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+    if (ncols < 1024) {
+        const dim3 block_dims(WARP_SIZE, 1, 1);
+        rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+    } else {
+        const dim3 block_dims(1024, 1, 1);
+        rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+    }
 }
 
 static void quantize_row_q8_1_cuda(const float * x, void * vy, const int kx, const int ky, const int kx_padded, cudaStream_t stream) {