]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
feat: add new `sin` and `cos` operators (#919)
authorRonsor <redacted>
Mon, 12 Aug 2024 13:02:08 +0000 (06:02 -0700)
committerGitHub <redacted>
Mon, 12 Aug 2024 13:02:08 +0000 (15:02 +0200)
* ggml : add sin/cos operators

* ggml-cuda : add sin/cos operators

* ggml : add corresponding tests for sin/cos

* ggml : add backward computation for sin/cos operators

* ggml-vulkan : add sin/cos operators

* ggml-vulkan : add sin/cos shader source

* metal : add sin, cos

---------

Co-authored-by: Georgi Gerganov <redacted>
12 files changed:
include/ggml.h
src/ggml-cuda.cu
src/ggml-cuda/unary.cu
src/ggml-cuda/unary.cuh
src/ggml-metal.m
src/ggml-metal.metal
src/ggml-vulkan.cpp
src/ggml.c
src/vulkan-shaders/cos.comp [new file with mode: 0644]
src/vulkan-shaders/sin.comp [new file with mode: 0644]
src/vulkan-shaders/vulkan-shaders-gen.cpp
tests/test-backend-ops.cpp

index 15602a96df7ad3ef4df675d2a786a51d888c6cdb..4ea7aa91124aa8166e60dda87f315c11f9e81481 100644 (file)
@@ -451,6 +451,8 @@ extern "C" {
         GGML_OP_SQR,
         GGML_OP_SQRT,
         GGML_OP_LOG,
+        GGML_OP_SIN,
+        GGML_OP_COS,
         GGML_OP_SUM,
         GGML_OP_SUM_ROWS,
         GGML_OP_MEAN,
@@ -967,6 +969,22 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    GGML_API struct ggml_tensor * ggml_sin(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_sin_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_cos(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
+    GGML_API struct ggml_tensor * ggml_cos_inplace(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a);
+
     // return scalar
     GGML_API struct ggml_tensor * ggml_sum(
             struct ggml_context * ctx,
index 682c30d45bcf4323a4e2506ed4b2ecacf7d63e99..8ff154f729ed612f8bc6241fe09348a3936dfddd 100644 (file)
@@ -2267,6 +2267,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_SQRT:
             ggml_cuda_op_sqrt(ctx, dst);
             break;
+        case GGML_OP_SIN:
+            ggml_cuda_op_sin(ctx, dst);
+            break;
+        case GGML_OP_COS:
+            ggml_cuda_op_cos(ctx, dst);
+            break;
         case GGML_OP_CLAMP:
             ggml_cuda_op_clamp(ctx, dst);
             break;
@@ -2859,6 +2865,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_SCALE:
         case GGML_OP_SQR:
         case GGML_OP_SQRT:
+        case GGML_OP_SIN:
+        case GGML_OP_COS:
         case GGML_OP_CLAMP:
         case GGML_OP_CONT:
         case GGML_OP_DIAG_MASK_INF:
index f9e208011e2a8f2b1387d2c2173e8056bf0ad068..89abfc21d8a56c85bb191fd75270cf713a4a7b64 100644 (file)
@@ -101,6 +101,24 @@ static __global__ void sqrt_f32(const float * x, float * dst, const int k) {
     dst[i] = sqrtf(x[i]);
 }
 
+static __global__ void sin_f32(const float * x, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = sinf(x[i]);
+}
+
+static __global__ void cos_f32(const float * x, float * dst, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+    dst[i] = cosf(x[i]);
+}
+
 static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
     gelu_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
@@ -156,6 +174,16 @@ static void sqrt_f32_cuda(const float * x, float * dst, const int k, cudaStream_
     sqrt_f32<<<num_blocks, CUDA_SQRT_BLOCK_SIZE, 0, stream>>>(x, dst, k);
 }
 
+static void sin_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_SIN_BLOCK_SIZE - 1) / CUDA_SIN_BLOCK_SIZE;
+    sin_f32<<<num_blocks, CUDA_SIN_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
+static void cos_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_COS_BLOCK_SIZE - 1) / CUDA_COS_BLOCK_SIZE;
+    cos_f32<<<num_blocks, CUDA_COS_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
 void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const float * src0_d = (const float *)src0->data;
@@ -312,3 +340,31 @@ void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     sqrt_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
 }
+
+void ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    sin_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+}
+
+void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const float * src0_d = (const float *)src0->data;
+    float * dst_d = (float *)dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    cos_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+}
index 4cfb0479e7169a83a8b433b21b9b381aeb9158d9..c610e996abeb62a2fdcc2ea424774513f7e5c443 100644 (file)
@@ -9,6 +9,8 @@
 #define CUDA_HARDSWISH_BLOCK_SIZE 256
 #define CUDA_SQR_BLOCK_SIZE 256
 #define CUDA_SQRT_BLOCK_SIZE 256
+#define CUDA_SIN_BLOCK_SIZE 256
+#define CUDA_COS_BLOCK_SIZE 256
 
 void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
@@ -31,3 +33,7 @@ void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
 void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
index aad189430ab0b3f425de57b11b5cd528656906e7..f6bd6e3407e54f093218e4eedcb51a8585842860 100644 (file)
@@ -205,6 +205,8 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
     GGML_METAL_KERNEL_TYPE_CONCAT,
     GGML_METAL_KERNEL_TYPE_SQR,
+    GGML_METAL_KERNEL_TYPE_SIN,
+    GGML_METAL_KERNEL_TYPE_COS,
     GGML_METAL_KERNEL_TYPE_SUM_ROWS,
 
     GGML_METAL_KERNEL_TYPE_COUNT
@@ -665,6 +667,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,                cpy_f32_iq4_nl,                 true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT,                        concat,                         true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR,                           sqr,                            true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN,                           sin,                            true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS,                           cos,                            true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS,                      sum_rows,                       true);
     }
 
@@ -771,9 +775,12 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
         case GGML_OP_REPEAT:
         case GGML_OP_SCALE:
         case GGML_OP_CLAMP:
+            return true;
         case GGML_OP_SQR:
+        case GGML_OP_SIN:
+        case GGML_OP_COS:
+            return ggml_is_contiguous(op->src[0]);
         case GGML_OP_SUM_ROWS:
-            return true;
         case GGML_OP_SOFT_MAX:
         case GGML_OP_RMS_NORM:
         case GGML_OP_GROUP_NORM:
@@ -1409,6 +1416,34 @@ static enum ggml_status ggml_metal_graph_compute(
 
                         const int64_t n = ggml_nelements(dst);
 
+                        [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                    } break;
+                case GGML_OP_SIN:
+                    {
+                        GGML_ASSERT(ggml_is_contiguous(src0));
+
+                        id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SIN].pipeline;
+
+                        [encoder setComputePipelineState:pipeline];
+                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                        [encoder setBuffer:id_dst  offset:offs_dst atIndex:1];
+
+                        const int64_t n = ggml_nelements(dst);
+
+                        [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                    } break;
+                case GGML_OP_COS:
+                    {
+                        GGML_ASSERT(ggml_is_contiguous(src0));
+
+                        id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_COS].pipeline;
+
+                        [encoder setComputePipelineState:pipeline];
+                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                        [encoder setBuffer:id_dst  offset:offs_dst atIndex:1];
+
+                        const int64_t n = ggml_nelements(dst);
+
                         [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                     } break;
                 case GGML_OP_SUM_ROWS:
index 3bb37d32aced02644401fb10587ffa8a6165eb52..3e4b685bb51c4aa43cf79744bbbddd82752138a1 100644 (file)
@@ -358,6 +358,20 @@ kernel void kernel_sqr(
     dst[tpig] = src0[tpig] * src0[tpig];
 }
 
+kernel void kernel_sin(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = sin(src0[tpig]);
+}
+
+kernel void kernel_cos(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = cos(src0[tpig]);
+}
+
 kernel void kernel_sum_rows(
         device const float * src0,
         device       float * dst,
index b0f36a513f84bdd851a640462addd24d70431e7e..268fa3cd35ec584d73c9063d53b3d75f7a7c487c 100644 (file)
@@ -184,6 +184,8 @@ struct vk_device_struct {
     vk_pipeline pipeline_upscale_f32;
     vk_pipeline pipeline_scale_f32;
     vk_pipeline pipeline_sqr_f32;
+    vk_pipeline pipeline_sin_f32;
+    vk_pipeline pipeline_cos_f32;
     vk_pipeline pipeline_clamp_f32;
     vk_pipeline pipeline_pad_f32;
     vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16;
@@ -1654,6 +1656,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_sqr_f32, "sqr_f32", sqr_f32_len, sqr_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_sin_f32, "sin_f32", sin_f32_len, sin_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_cos_f32, "cos_f32", cos_f32_len, cos_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 
@@ -3972,6 +3976,16 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_sqr_f32;
         }
         return nullptr;
+    case GGML_OP_SIN:
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_sin_f32;
+        }
+        return nullptr;
+    case GGML_OP_COS:
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_cos_f32;
+        }
+        return nullptr;
     case GGML_OP_CLAMP:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_clamp_f32;
@@ -4124,6 +4138,8 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
     case GGML_OP_UPSCALE:
     case GGML_OP_SCALE:
     case GGML_OP_SQR:
+    case GGML_OP_SIN:
+    case GGML_OP_COS:
     case GGML_OP_CLAMP:
     case GGML_OP_PAD:
         return true;
@@ -4335,6 +4351,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
         case GGML_OP_MUL:
         case GGML_OP_SCALE:
         case GGML_OP_SQR:
+        case GGML_OP_SIN:
+        case GGML_OP_COS:
         case GGML_OP_CLAMP:
         case GGML_OP_PAD:
         case GGML_OP_CPY:
@@ -4576,6 +4594,32 @@ static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context& subctx, const
     });
 }
 
+static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+    const uint32_t src0_type_size = ggml_type_size(src0->type);
+    const uint32_t dst_type_size = ggml_type_size(dst->type);
+
+    ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SIN, {
+        (uint32_t)ggml_nelements(src0),
+        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
+        0,
+        0.0f, 0.0f,
+    });
+}
+
+static void ggml_vk_cos(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+    const uint32_t src0_type_size = ggml_type_size(src0->type);
+    const uint32_t dst_type_size = ggml_type_size(dst->type);
+
+    ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_COS, {
+        (uint32_t)ggml_nelements(src0),
+        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
+        0,
+        0.0f, 0.0f,
+    });
+}
+
 static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
     float * op_params = (float *)dst->op_params;
     const uint32_t src0_type_size = ggml_type_size(src0->type);
@@ -5481,6 +5525,8 @@ static void ggml_vk_preallocate_buffers_graph(ggml_backend_vk_context * ctx, ggm
     case GGML_OP_ADD:
     case GGML_OP_SCALE:
     case GGML_OP_SQR:
+    case GGML_OP_SIN:
+    case GGML_OP_COS:
     case GGML_OP_CLAMP:
     case GGML_OP_PAD:
     case GGML_OP_CPY:
@@ -5761,6 +5807,8 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_UPSCALE:
     case GGML_OP_SCALE:
     case GGML_OP_SQR:
+    case GGML_OP_SIN:
+    case GGML_OP_COS:
     case GGML_OP_CLAMP:
     case GGML_OP_PAD:
     case GGML_OP_CPY:
@@ -5832,6 +5880,14 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_SQR:
         ggml_vk_sqr(ctx, compute_ctx, src0, node);
 
+        break;
+    case GGML_OP_SIN:
+        ggml_vk_sin(ctx, compute_ctx, src0, node);
+
+        break;
+    case GGML_OP_COS:
+        ggml_vk_cos(ctx, compute_ctx, src0, node);
+
         break;
     case GGML_OP_CLAMP:
         ggml_vk_clamp(ctx, compute_ctx, src0, node);
@@ -5943,6 +5999,8 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
     case GGML_OP_UPSCALE:
     case GGML_OP_SCALE:
     case GGML_OP_SQR:
+    case GGML_OP_SIN:
+    case GGML_OP_COS:
     case GGML_OP_CLAMP:
     case GGML_OP_PAD:
     case GGML_OP_CPY:
@@ -6658,6 +6716,8 @@ GGML_CALL static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const
         case GGML_OP_UPSCALE:
         case GGML_OP_SCALE:
         case GGML_OP_SQR:
+        case GGML_OP_SIN:
+        case GGML_OP_COS:
         case GGML_OP_CLAMP:
         case GGML_OP_PAD:
         case GGML_OP_CONT:
index 74e1c596bb1976221a7144d663049e06b3f4139e..a56c2ffd9e2de38745ca62201f8cf0d46a932094 100644 (file)
@@ -2310,7 +2310,9 @@ inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float
 inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1); *s = sqrtf(*s);   }
 inline static void ggml_vec_sqr_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i];   }
 inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); }
-inline static void ggml_vec_log_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = logf(x[i]);   }
+inline static void ggml_vec_log_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = logf(x[i]);  }
+inline static void ggml_vec_sin_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sinf(x[i]);  }
+inline static void ggml_vec_cos_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = cosf(x[i]);  }
 inline static void ggml_vec_abs_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fabsf(x[i]); }
 inline static void ggml_vec_sgn_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : ((x[i] < 0.f) ? -1.f : 0.f); }
 inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : 0.f; }
@@ -2760,6 +2762,8 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "SQR",
     "SQRT",
     "LOG",
+    "SIN",
+    "COS",
     "SUM",
     "SUM_ROWS",
     "MEAN",
@@ -2833,7 +2837,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS_BACK",
 };
 
-static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
+static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -2848,6 +2852,8 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "x^2",
     "√x",
     "log(x)",
+    "sin(x)",
+    "cos(x)",
     "Σx",
     "Σx_k",
     "Σx/n",
@@ -2921,7 +2927,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss_back(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
+static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -4882,6 +4888,72 @@ struct ggml_tensor * ggml_log_inplace(
     return ggml_log_impl(ctx, a, true);
 }
 
+// ggml_sin
+
+static struct ggml_tensor * ggml_sin_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        bool inplace) {
+    bool is_node = false;
+
+    if (!inplace && (a->grad)) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_SIN;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = a;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_sin(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_sin_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_sin_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_sin_impl(ctx, a, true);
+}
+
+// ggml_cos
+
+static struct ggml_tensor * ggml_cos_impl(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        bool inplace) {
+    bool is_node = false;
+
+    if (!inplace && (a->grad)) {
+        is_node = true;
+    }
+
+    struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+    result->op   = GGML_OP_COS;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = a;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_cos(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_cos_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_cos_inplace(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a) {
+    return ggml_cos_impl(ctx, a, true);
+}
+
 // ggml_sum
 
 struct ggml_tensor * ggml_sum(
@@ -10512,6 +10584,96 @@ static void ggml_compute_forward_log(
     }
 }
 
+// ggml_compute_forward_sin
+
+static void ggml_compute_forward_sin_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    GGML_ASSERT( dst->nb[0] == sizeof(float));
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_sin_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_sin(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_sin_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
+// ggml_compute_forward_cos
+
+static void ggml_compute_forward_cos_f32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    if (params->ith != 0) {
+        return;
+    }
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+    const int n  = ggml_nrows(src0);
+    const int nc = src0->ne[0];
+
+    GGML_ASSERT( dst->nb[0] == sizeof(float));
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+    for (int i = 0; i < n; i++) {
+        ggml_vec_cos_f32(nc,
+                (float *) ((char *) dst->data  + i*( dst->nb[1])),
+                (float *) ((char *) src0->data + i*(src0->nb[1])));
+    }
+}
+
+static void ggml_compute_forward_cos(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_cos_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_sum
 
 static void ggml_compute_forward_sum_f32(
@@ -16787,6 +16949,14 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_log(params, tensor);
             } break;
+        case GGML_OP_SIN:
+            {
+                ggml_compute_forward_sin(params, tensor);
+            } break;
+        case GGML_OP_COS:
+            {
+                ggml_compute_forward_cos(params, tensor);
+            } break;
         case GGML_OP_SUM:
             {
                 ggml_compute_forward_sum(params, tensor);
@@ -17433,6 +17603,30 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                 zero_table);
                 }
             } break;
+        case GGML_OP_SIN:
+            {
+                if (src0->grad) {
+                    src0->grad =
+                        ggml_add_or_set(ctx,
+                                src0->grad,
+                                ggml_mul(ctx,
+                                    tensor->grad,
+                                    ggml_cos(ctx, src0)),
+                                zero_table);
+                }
+            } break;
+        case GGML_OP_COS:
+            {
+                if (src0->grad) {
+                    src0->grad =
+                        ggml_sub_or_set(ctx,
+                                src0->grad,
+                                ggml_mul(ctx,
+                                    tensor->grad,
+                                    ggml_sin(ctx, src0)),
+                                zero_table);
+                }
+            } break;
         case GGML_OP_SUM:
             {
                 if (src0->grad) {
@@ -18520,6 +18714,8 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_SQR:
         case GGML_OP_SQRT:
         case GGML_OP_LOG:
+        case GGML_OP_SIN:
+        case GGML_OP_COS:
         case GGML_OP_SUM:
         case GGML_OP_SUM_ROWS:
         case GGML_OP_MEAN:
diff --git a/src/vulkan-shaders/cos.comp b/src/vulkan-shaders/cos.comp
new file mode 100644 (file)
index 0000000..f9a858c
--- /dev/null
@@ -0,0 +1,15 @@
+#version 450
+
+#include "types.comp"
+#include "generic_unary_head.comp"
+
+void main() {
+    const uint idx = get_idx();
+
+    if (idx >= p.ne) {
+        return;
+    }
+
+    const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(idx)]);
+    data_d[p.d_offset + dst_idx(idx)] = D_TYPE(cos(val));
+}
diff --git a/src/vulkan-shaders/sin.comp b/src/vulkan-shaders/sin.comp
new file mode 100644 (file)
index 0000000..7faf9be
--- /dev/null
@@ -0,0 +1,15 @@
+#version 450
+
+#include "types.comp"
+#include "generic_unary_head.comp"
+
+void main() {
+    const uint idx = get_idx();
+
+    if (idx >= p.ne) {
+        return;
+    }
+
+    const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(idx)]);
+    data_d[p.d_offset + dst_idx(idx)] = D_TYPE(sin(val));
+}
index a792e203b273a018fa779555896b4dcf064731c4..a451d24705136572b8f58b03f505835f6895233a 100644 (file)
@@ -388,6 +388,14 @@ void process_shaders(std::vector<std::future<void>>& tasks) {
         string_to_spv("sqr_f32", "square.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
     }));
 
+    tasks.push_back(std::async(std::launch::async, [] {
+        string_to_spv("sin_f32", "sin.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
+    }));
+
+    tasks.push_back(std::async(std::launch::async, [] {
+        string_to_spv("cos_f32", "cos.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
+    }));
+
     tasks.push_back(std::async(std::launch::async, [] {
         string_to_spv("clamp_f32", "clamp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
     }));
index 2f4117a627de0f5582626da0893558237934aeb0..01702b109db944bc60f9ae9b3fe8417bc31e0b57 100644 (file)
@@ -1108,6 +1108,58 @@ struct test_sqrt : public test_case {
     }
 };
 
+// GGML_OP_SIN
+struct test_sin : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_sin(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {10, 10, 10, 10})
+        : type(type), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_tensor * out = ggml_sin(ctx, a);
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            init_tensor_uniform(t, -100.0f, 100.0f);
+        }
+    }
+};
+
+// GGML_OP_COS
+struct test_cos : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+
+    std::string vars() override {
+        return VARS_TO_STR2(type, ne);
+    }
+
+    test_cos(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {10, 10, 10, 10})
+        : type(type), ne(ne) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_tensor * out = ggml_cos(ctx, a);
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            init_tensor_uniform(t, -100.0f, 100.0f);
+        }
+    }
+};
+
 // GGML_OP_CLAMP
 struct test_clamp : public test_case {
     const ggml_type type;
@@ -2321,6 +2373,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
 
     test_cases.emplace_back(new test_sqr());
     test_cases.emplace_back(new test_sqrt());
+    test_cases.emplace_back(new test_sin());
+    test_cases.emplace_back(new test_cos());
     test_cases.emplace_back(new test_clamp());
 
     test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10,  1,  1}, 5));