]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : support norm op on CUDA (#364)
authorJiahao Li <redacted>
Tue, 11 Jul 2023 17:58:02 +0000 (01:58 +0800)
committerGitHub <redacted>
Tue, 11 Jul 2023 17:58:02 +0000 (20:58 +0300)
src/ggml-cuda.cu

index 1673e7e4c9ef4263b11f89499925dbeed305b643..f9b55173b9a2cbe812086c0f413eec37d9e3f421 100644 (file)
@@ -275,16 +275,46 @@ static __global__ void silu_f32(const float * x, float * dst, const int k) {
     dst[i] = x[i] / (1.0f + expf(-x[i]));
 }
 
+static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
+    const int row = blockIdx.x*blockDim.y + threadIdx.y;
+    const int tid = threadIdx.x;
+
+    const float eps = 1e-5f;
+
+    float mean = 0.0f;
+    float var = 0.0f;
+
+    for (int col = tid; col < ncols; col += WARP_SIZE) {
+        const float xi = x[row*ncols + col];
+        mean += xi;
+        var += xi * xi;
+    }
+
+    // sum up partial sums
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        mean += __shfl_xor_sync(0xffffffff, mean, mask, 32);
+        var += __shfl_xor_sync(0xffffffff, var, mask, 32);
+    }
+
+    mean /= ncols;
+    var = var / ncols - mean * mean;
+    const float inv_var = rsqrtf(var + eps);
+
+    for (int col = tid; col < ncols; col += WARP_SIZE) {
+        dst[row*ncols + col] = (x[row*ncols + col] - mean) * inv_var;
+    }
+}
+
 static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
     const int tid = threadIdx.x;
 
-    const float eps = 1e-6;
+    const float eps = 1e-6f;
 
     float tmp = 0.0f; // partial sum for thread in warp
 
-    for (int i = 0; i < ncols; i += WARP_SIZE) {
-        const int col = i + tid;
+    for (int col = tid; col < ncols; col += WARP_SIZE) {
         const float xi = x[row*ncols + col];
         tmp += xi * xi;
     }
@@ -296,10 +326,9 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol
     }
 
     const float mean = tmp / ncols;
-    const float scale = 1.0f / sqrtf(mean + eps);
+    const float scale = rsqrtf(mean + eps);
 
-    for (int i = 0; i < ncols; i += WARP_SIZE) {
-        const int col = i + tid;
+    for (int col = tid; col < ncols; col += WARP_SIZE) {
         dst[row*ncols + col] = scale * x[row*ncols + col];
     }
 }
@@ -1709,6 +1738,12 @@ static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_
     silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 }
 
+static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % WARP_SIZE == 0);
+    const dim3 block_dims(WARP_SIZE, 1, 1);
+    norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
+}
+
 static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % WARP_SIZE == 0);
     const dim3 block_dims(WARP_SIZE, 1, 1);
@@ -2310,6 +2345,28 @@ inline void ggml_cuda_op_silu(
     (void) i1;
 }
 
+inline void ggml_cuda_op_norm(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
+    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
+    cudaStream_t & cudaStream_main){
+
+    GGML_ASSERT(src0_ddf_i != nullptr);
+    GGML_ASSERT(dst_ddf_i != nullptr);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t i01_diff = i01_high - i01_low;
+
+    // compute
+    norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main);
+
+    (void) src1;
+    (void) dst;
+    (void) src0_ddq_i;
+    (void) src1_ddf_i;
+    (void) i02;
+    (void) i1;
+}
+
 inline void ggml_cuda_op_rms_norm(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
     float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
@@ -2930,6 +2987,11 @@ void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_ten
     ggml_cuda_op(src0, src1, dst, ggml_cuda_op_silu, true, true);
 }
 
+void ggml_cuda_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
+    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_norm, true, true);
+}
+
 void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
     ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rms_norm, true, true);
@@ -3160,7 +3222,7 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
         }
 
 
-        cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice);
+        CUDA_CHECK(cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice));
 
         extra->data_device[id] = buf;
 
@@ -3322,6 +3384,12 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
             }
             func = ggml_cuda_silu;
             break;
+        case GGML_OP_NORM:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cuda_norm;
+            break;
         case GGML_OP_RMS_NORM:
             if (!any_on_device) {
                 return false;