]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
feat: implemented sigmoid function (ggml/806)
authorJustina Cho <redacted>
Wed, 1 May 2024 21:44:26 +0000 (14:44 -0700)
committerGeorgi Gerganov <redacted>
Mon, 13 May 2024 08:02:26 +0000 (11:02 +0300)
* added sigmoid function

* implemented metal kernel for sigmoid

* implemented cuda kernel for sigmoid

* added sigmoid unary op and incremented count

ggml-cuda.cu
ggml-cuda/unary.cu
ggml-cuda/unary.cuh
ggml-metal.m
ggml-metal.metal
ggml.c
ggml.h

index bff8ad9d96e887aaf0c8206955642d1bcbf22eec..4a2bbdabfa70c2c45eaec7402a896616bb0647dd 100644 (file)
@@ -2115,6 +2115,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
                 case GGML_UNARY_OP_RELU:
                     ggml_cuda_op_relu(ctx, dst);
                     break;
+                case GGML_UNARY_OP_SIGMOID:
+                    ggml_cuda_op_sigmoid(ctx, dst);
+                    break;
                 case GGML_UNARY_OP_HARDSIGMOID:
                     ggml_cuda_op_hardsigmoid(ctx, dst);
                     break;
@@ -2355,6 +2358,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
                 case GGML_UNARY_OP_GELU:
                 case GGML_UNARY_OP_SILU:
                 case GGML_UNARY_OP_RELU:
+                case GGML_UNARY_OP_SIGMOID:
                 case GGML_UNARY_OP_HARDSIGMOID:
                 case GGML_UNARY_OP_HARDSWISH:
                 case GGML_UNARY_OP_GELU_QUICK:
index 1a7f0946972c199e5f51a49c174f4518648c4cc1..ac03d5c6fce546b82822248887f3a356e6d2320c 100644 (file)
@@ -48,6 +48,15 @@ static __global__ void relu_f32(const float * x, float * dst, const int k) {
     dst[i] = fmaxf(x[i], 0);
 }
 
+static __global__ void sigmoid_f32(const float * x, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = 1.0f / (1.0f + expf(-x[i]));
+}
+
 static __global__ void hardsigmoid_f32(const float * x, float * dst, const int k) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
@@ -108,6 +117,11 @@ static void relu_f32_cuda(const float * x, float * dst, const int k, cudaStream_
     relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 }
 
+static void sigmoid_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_SIGMOID_BLOCK_SIZE - 1) / CUDA_SIGMOID_BLOCK_SIZE;
+    sigmoid_f32<<<num_blocks, CUDA_SIGMOID_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
 static void hardsigmoid_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_HARDSIGMOID_BLOCK_SIZE - 1) / CUDA_HARDSIGMOID_BLOCK_SIZE;
     hardsigmoid_f32<<<num_blocks, CUDA_HARDSIGMOID_BLOCK_SIZE, 0, stream>>>(x, dst, k);
@@ -188,6 +202,18 @@ void ggml_cuda_op_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     relu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
 }
 
+void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    sigmoid_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+}
+
 void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const float * src0_d = (const float *)src0->data;
index 2002ed989209ca119cb245d9a14840e04f5beb87..a1d07c04fcd4350a321690dbb6f824c8fdf7a2df 100644 (file)
@@ -4,6 +4,7 @@
 #define CUDA_SILU_BLOCK_SIZE 256
 #define CUDA_TANH_BLOCK_SIZE 256
 #define CUDA_RELU_BLOCK_SIZE 256
+#define CUDA_SIGMOID_BLOCK_SIZE 256
 #define CUDA_HARDSIGMOID_BLOCK_SIZE 256
 #define CUDA_HARDSWISH_BLOCK_SIZE 256
 #define CUDA_SQR_BLOCK_SIZE 256
@@ -18,6 +19,8 @@ void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
+void ggml_cuda_op_sigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
 void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
index 419d8b9e56878f7638c984f129225dc6f3474e1d..86426e933e04b84bcac1c1970feb9ab3298b1048 100644 (file)
@@ -39,6 +39,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_SCALE_4,
     GGML_METAL_KERNEL_TYPE_TANH,
     GGML_METAL_KERNEL_TYPE_RELU,
+    GGML_METAL_KERNEL_TYPE_SIGMOID,
     GGML_METAL_KERNEL_TYPE_GELU,
     GGML_METAL_KERNEL_TYPE_GELU_QUICK,
     GGML_METAL_KERNEL_TYPE_SILU,
@@ -470,6 +471,7 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4,                   scale_4,                true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH,                      tanh,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU,                      relu,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIGMOID,                   sigmoid,                true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU,                      gelu,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK,                gelu_quick,             true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU,                      silu,                   true);
@@ -695,6 +697,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
             switch (ggml_get_unary_op(op)) {
                 case GGML_UNARY_OP_TANH:
                 case GGML_UNARY_OP_RELU:
+                case GGML_UNARY_OP_SIGMOID:
                 case GGML_UNARY_OP_GELU:
                 case GGML_UNARY_OP_GELU_QUICK:
                 case GGML_UNARY_OP_SILU:
@@ -1178,6 +1181,18 @@ static enum ggml_status ggml_metal_graph_compute(
 
                                 const int64_t n = ggml_nelements(dst);
 
+                                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                            } break;
+                        case GGML_UNARY_OP_SIGMOID:
+                            {
+                                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SIGMOID].pipeline;
+
+                                [encoder setComputePipelineState:pipeline];
+                                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                                const int64_t n = ggml_nelements(dst);
+
                                 [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                             } break;
                         case GGML_UNARY_OP_GELU:
index 9a29f57a38c6b7cf501e09e034c6f66631074cdf..7f840ab089b9b18b6a1833aa6014b33666e2bffd 100644 (file)
@@ -220,6 +220,13 @@ kernel void kernel_relu(
     dst[tpig] = max(0.0f, src0[tpig]);
 }
 
+kernel void kernel_sigmoid(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig]));
+}
+
 kernel void kernel_tanh(
         device const float * src0,
         device       float * dst,
diff --git a/ggml.c b/ggml.c
index 793b67f4c70209e37098c26d1833abd3d5a9b4a6..3256dda8a08e65f550fca7e3006eba1d78252274 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -1763,6 +1763,7 @@ inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) {
 inline static void ggml_vec_tanh_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = tanhf(x[i]);  }
 inline static void ggml_vec_elu_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expf(x[i])-1; }
 inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
+inline static void ggml_vec_sigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = 1.f / (1.f + expf(-x[i])); }
 inline static void ggml_vec_leaky_relu_f32 (const int n, float * y, const float * x, const float ns) { for (int i = 0; i < n; ++i) y[i] = ((x[i] > 0.f) ? x[i] : 0.f) + ns * ((x[i] < 0.0f) ? x[i] : 0.f); }
 // TODO: optimize performance
 inline static void ggml_vec_hardswish_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); }
@@ -2136,6 +2137,7 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
     "TANH",
     "ELU",
     "RELU",
+    "SIGMOID",
     "GELU",
     "GELU_QUICK",
     "SILU",
@@ -2143,7 +2145,7 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
     "HARDSIGMOID",
 };
 
-static_assert(GGML_UNARY_OP_COUNT == 12, "GGML_UNARY_OP_COUNT != 12");
+static_assert(GGML_UNARY_OP_COUNT == 13, "GGML_UNARY_OP_COUNT != 13");
 
 
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
@@ -4295,6 +4297,20 @@ struct ggml_tensor * ggml_relu_inplace(
     return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_RELU);
 }
 
+// ggml_sigmoid
+
+struct ggml_tensor * ggml_sigmoid(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_unary(ctx, a, GGML_UNARY_OP_SIGMOID);
+}
+
+struct ggml_tensor * ggml_sigmoid_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SIGMOID);
+}
+
 // ggml_leaky_relu
 
 struct ggml_tensor * ggml_leaky_relu(
@@ -9838,6 +9854,52 @@ static void ggml_compute_forward_relu(
     }
 }
 
+// ggml_compute_forward_sigmoid
+
+static void ggml_compute_forward_sigmoid_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    assert(params->ith == 0);
+    assert(ggml_are_same_shape(src0, dst));
+
+    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
+        return;
+    }
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    assert(dst->nb[0]  == sizeof(float));
+    assert(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_sigmoid_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_sigmoid(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_sigmoid_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_gelu
 
 static void ggml_compute_forward_gelu_f32(
@@ -15485,6 +15547,10 @@ static void ggml_compute_forward_unary(
             {
                 ggml_compute_forward_relu(params, dst);
             } break;
+        case GGML_UNARY_OP_SIGMOID:
+            {
+                ggml_compute_forward_sigmoid(params, dst);
+            } break;
         case GGML_UNARY_OP_GELU:
             {
                 ggml_compute_forward_gelu(params, dst);
@@ -17471,6 +17537,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                         zero_table);
                             }
                         } break;
+                    case GGML_UNARY_OP_SIGMOID:
+                        {
+                            GGML_ASSERT(false); // TODO: not implemented
+                        } break;
                     case GGML_UNARY_OP_GELU:
                         {
                             GGML_ASSERT(false); // TODO: not implemented
@@ -18000,6 +18070,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_
                 case GGML_UNARY_OP_TANH:
                 case GGML_UNARY_OP_ELU:
                 case GGML_UNARY_OP_RELU:
+                case GGML_UNARY_OP_SIGMOID:
                 case GGML_UNARY_OP_HARDSWISH: // to opt for multiple threads
                 case GGML_UNARY_OP_HARDSIGMOID: // to opt for multiple threads
                     {
diff --git a/ggml.h b/ggml.h
index abe3767f22418b9580bee4c00b435692a94fbb27..fbc34f0c9d0d089da5e9673d513e7e8887434fed 100644 (file)
--- a/ggml.h
+++ b/ggml.h
@@ -511,6 +511,7 @@ extern "C" {
         GGML_UNARY_OP_TANH,
         GGML_UNARY_OP_ELU,
         GGML_UNARY_OP_RELU,
+        GGML_UNARY_OP_SIGMOID,
         GGML_UNARY_OP_GELU,
         GGML_UNARY_OP_GELU_QUICK,
         GGML_UNARY_OP_SILU,
@@ -1055,6 +1056,10 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    GGML_API struct ggml_tensor * ggml_sigmoid(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
     GGML_API struct ggml_tensor * ggml_leaky_relu(
             struct ggml_context * ctx,
             struct ggml_tensor  * a, float negative_slope, bool inplace);
@@ -1063,6 +1068,10 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    GGML_API struct ggml_tensor * ggml_sigmoid_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
     GGML_API struct ggml_tensor * ggml_gelu(
             struct ggml_context * ctx,
             struct ggml_tensor  * a);