]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
cuda : add gelu support
authorGeorgi Gerganov <redacted>
Wed, 12 Jul 2023 17:26:18 +0000 (20:26 +0300)
committerGeorgi Gerganov <redacted>
Wed, 12 Jul 2023 17:32:15 +0000 (20:32 +0300)
ggml-cuda.cu

index 89e69bdc14e1c4efe6c2b64455859584808afe2a..dc4b773a66a4451dab6f209cc44751dc852adaeb 100644 (file)
@@ -212,6 +212,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
 
 #define CUDA_ADD_BLOCK_SIZE 256
 #define CUDA_MUL_BLOCK_SIZE 256
+#define CUDA_GELU_BLOCK_SIZE 256
 #define CUDA_SILU_BLOCK_SIZE 256
 #define CUDA_CPY_BLOCK_SIZE 32
 #define CUDA_SCALE_BLOCK_SIZE 256
@@ -266,6 +267,20 @@ static __global__ void mul_f32(const float * x, const float * y, float * dst, co
     dst[i] = x[i] * y[i%ky];
 }
 
+static const float GELU_COEF_A    = 0.044715f;
+static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
+
+static __global__ void gelu_f32(const float * x, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+
+    float xi = x[i];
+    dst[i] = 0.5f*xi*(1.0f + tanhf(SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi)));
+}
+
 static __global__ void silu_f32(const float * x, float * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
@@ -1733,6 +1748,11 @@ static void mul_f32_cuda(const float * x, const float * y, float * dst, const in
     mul_f32<<<num_blocks, CUDA_MUL_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
 }
 
+static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
+    gelu_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
 static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_SILU_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE;
     silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
@@ -2327,6 +2347,28 @@ inline void ggml_cuda_op_mul(
     (void) i02;
 }
 
+inline void ggml_cuda_op_gelu(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
+    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
+    cudaStream_t & cudaStream_main){
+
+    GGML_ASSERT(src0_ddf_i != nullptr);
+    GGML_ASSERT(dst_ddf_i != nullptr);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t i01_diff = i01_high - i01_low;
+
+    // compute
+    gelu_f32_cuda(src0_ddf_i, dst_ddf_i, ne00*i01_diff, cudaStream_main);
+
+    (void) src1;
+    (void) dst;
+    (void) src0_ddq_i;
+    (void) src1_ddf_i;
+    (void) i02;
+    (void) i1;
+}
+
 inline void ggml_cuda_op_silu(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
     float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
@@ -2986,6 +3028,11 @@ void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens
     ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul, true, false); // TODO ggml_cuda_op needs modification for flatten
 }
 
+void ggml_cuda_gelu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
+    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_gelu, true, true);
+}
+
 void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
     ggml_cuda_op(src0, src1, dst, ggml_cuda_op_silu, true, true);
@@ -3382,6 +3429,12 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
             }
             func = ggml_cuda_mul;
             break;
+        case GGML_OP_GELU:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cuda_gelu;
+            break;
         case GGML_OP_SILU:
             if (!any_on_device) {
                 return false;