GGML_OP_SQR,
GGML_OP_SQRT,
GGML_OP_LOG,
+ GGML_OP_SIN,
+ GGML_OP_COS,
GGML_OP_SUM,
GGML_OP_SUM_ROWS,
GGML_OP_MEAN,
struct ggml_context * ctx,
struct ggml_tensor * a);
+ GGML_API struct ggml_tensor * ggml_sin(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ GGML_API struct ggml_tensor * ggml_sin_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ GGML_API struct ggml_tensor * ggml_cos(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
+ GGML_API struct ggml_tensor * ggml_cos_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a);
+
// return scalar
GGML_API struct ggml_tensor * ggml_sum(
struct ggml_context * ctx,
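For orientation, a minimal caller sketch (editorial, not part of the diff): aside from the new ggml_sin/ggml_cos, everything below is long-standing ggml public API, and the memory size and thread count are arbitrary.

    #include "ggml.h"
    #include <stdio.h>

    int main(void) {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 16*1024*1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(params);

        struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
        ggml_set_f32(x, 0.5f);

        // the new operators compose like any other unary op
        struct ggml_tensor * y = ggml_cos(ctx, ggml_sin(ctx, x));

        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, y);
        ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 1);

        printf("cos(sin(0.5)) = %f\n", ggml_get_f32_1d(y, 0)); // ~0.887

        ggml_free(ctx);
        return 0;
    }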
case GGML_OP_SQRT:
ggml_cuda_op_sqrt(ctx, dst);
break;
+ case GGML_OP_SIN:
+ ggml_cuda_op_sin(ctx, dst);
+ break;
+ case GGML_OP_COS:
+ ggml_cuda_op_cos(ctx, dst);
+ break;
case GGML_OP_CLAMP:
ggml_cuda_op_clamp(ctx, dst);
break;
case GGML_OP_SCALE:
case GGML_OP_SQR:
case GGML_OP_SQRT:
+ case GGML_OP_SIN:
+ case GGML_OP_COS:
case GGML_OP_CLAMP:
case GGML_OP_CONT:
case GGML_OP_DIAG_MASK_INF:
dst[i] = sqrtf(x[i]);
}
+static __global__ void sin_f32(const float * x, float * dst, const int k) {
+ const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+ if (i >= k) {
+ return;
+ }
+ dst[i] = sinf(x[i]);
+}
+
+static __global__ void cos_f32(const float * x, float * dst, const int k) {
+ const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+ if (i >= k) {
+ return;
+ }
+ dst[i] = cosf(x[i]);
+}
+
static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
    const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
    gelu_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}

static void sqrt_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
    const int num_blocks = (k + CUDA_SQRT_BLOCK_SIZE - 1) / CUDA_SQRT_BLOCK_SIZE;
    sqrt_f32<<<num_blocks, CUDA_SQRT_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
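+// all launchers round the block count up: e.g. k = 1000 with 256-wide blocks
+// gives 4 blocks (1024 threads); the i >= k guard in each kernel masks the excess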
+static void sin_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+ const int num_blocks = (k + CUDA_SIN_BLOCK_SIZE - 1) / CUDA_SIN_BLOCK_SIZE;
+ sin_f32<<<num_blocks, CUDA_SIN_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
+static void cos_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+ const int num_blocks = (k + CUDA_COS_BLOCK_SIZE - 1) / CUDA_COS_BLOCK_SIZE;
+ cos_f32<<<num_blocks, CUDA_COS_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *)src0->data;
    float * dst_d = (float *)dst->data;
    cudaStream_t stream = ctx.stream();

    sqrt_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}
+
+void ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const float * src0_d = (const float *)src0->data;
+ float * dst_d = (float *)dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ sin_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+}
+
+void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const float * src0_d = (const float *)src0->data;
+ float * dst_d = (float *)dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ cos_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
+}
#define CUDA_HARDSWISH_BLOCK_SIZE 256
#define CUDA_SQR_BLOCK_SIZE 256
#define CUDA_SQRT_BLOCK_SIZE 256
+#define CUDA_SIN_BLOCK_SIZE 256
+#define CUDA_COS_BLOCK_SIZE 256
void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
GGML_METAL_KERNEL_TYPE_CONCAT,
GGML_METAL_KERNEL_TYPE_SQR,
+ GGML_METAL_KERNEL_TYPE_SIN,
+ GGML_METAL_KERNEL_TYPE_COS,
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
GGML_METAL_KERNEL_TYPE_COUNT
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
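+        // GGML_METAL_ADD_KERNEL pastes "kernel_" onto its second argument, so
+        // these look up kernel_sin / kernel_cos from ggml-metal.metal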
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
}
case GGML_OP_REPEAT:
case GGML_OP_SCALE:
case GGML_OP_CLAMP:
+ return true;
case GGML_OP_SQR:
+ case GGML_OP_SIN:
+ case GGML_OP_COS:
+ return ggml_is_contiguous(op->src[0]);
case GGML_OP_SUM_ROWS:
- return true;
case GGML_OP_SOFT_MAX:
case GGML_OP_RMS_NORM:
case GGML_OP_GROUP_NORM:
const int64_t n = ggml_nelements(dst);
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_OP_SIN:
+ {
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SIN].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+ const int64_t n = ggml_nelements(dst);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_OP_COS:
+ {
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_COS].pipeline;
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+ const int64_t n = ggml_nelements(dst);
+
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_OP_SUM_ROWS:
dst[tpig] = src0[tpig] * src0[tpig];
}
+kernel void kernel_sin(
+ device const float * src0,
+ device float * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = sin(src0[tpig]);
+}
+
+kernel void kernel_cos(
+ device const float * src0,
+ device float * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = cos(src0[tpig]);
+}
+
kernel void kernel_sum_rows(
device const float * src0,
device float * dst,
vk_pipeline pipeline_upscale_f32;
vk_pipeline pipeline_scale_f32;
vk_pipeline pipeline_sqr_f32;
+ vk_pipeline pipeline_sin_f32;
+ vk_pipeline pipeline_cos_f32;
vk_pipeline pipeline_clamp_f32;
vk_pipeline pipeline_pad_f32;
vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16;
ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_sqr_f32, "sqr_f32", sqr_f32_len, sqr_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_sin_f32, "sin_f32", sin_f32_len, sin_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_cos_f32, "cos_f32", cos_f32_len, cos_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
return ctx->device->pipeline_sqr_f32;
}
return nullptr;
+ case GGML_OP_SIN:
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ return ctx->device->pipeline_sin_f32;
+ }
+ return nullptr;
+ case GGML_OP_COS:
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ return ctx->device->pipeline_cos_f32;
+ }
+ return nullptr;
case GGML_OP_CLAMP:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_clamp_f32;
case GGML_OP_UPSCALE:
case GGML_OP_SCALE:
case GGML_OP_SQR:
+ case GGML_OP_SIN:
+ case GGML_OP_COS:
case GGML_OP_CLAMP:
case GGML_OP_PAD:
return true;
case GGML_OP_MUL:
case GGML_OP_SCALE:
case GGML_OP_SQR:
+ case GGML_OP_SIN:
+ case GGML_OP_COS:
case GGML_OP_CLAMP:
case GGML_OP_PAD:
case GGML_OP_CPY:
});
}
+static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+ const uint32_t src0_type_size = ggml_type_size(src0->type);
+ const uint32_t dst_type_size = ggml_type_size(dst->type);
+
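+    // push constants: element count, then src/dst shapes and strides in units
+    // of elements, a zero d_offset, and two float params that sin/cos leave
+    // unused (ops like clamp and scale consume them)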
+ ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SIN, {
+ (uint32_t)ggml_nelements(src0),
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
+ 0,
+ 0.0f, 0.0f,
+ });
+}
+
+static void ggml_vk_cos(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
+ const uint32_t src0_type_size = ggml_type_size(src0->type);
+ const uint32_t dst_type_size = ggml_type_size(dst->type);
+
+ ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_COS, {
+ (uint32_t)ggml_nelements(src0),
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
+ 0,
+ 0.0f, 0.0f,
+ });
+}
+
static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
float * op_params = (float *)dst->op_params;
const uint32_t src0_type_size = ggml_type_size(src0->type);
case GGML_OP_ADD:
case GGML_OP_SCALE:
case GGML_OP_SQR:
+ case GGML_OP_SIN:
+ case GGML_OP_COS:
case GGML_OP_CLAMP:
case GGML_OP_PAD:
case GGML_OP_CPY:
case GGML_OP_UPSCALE:
case GGML_OP_SCALE:
case GGML_OP_SQR:
+ case GGML_OP_SIN:
+ case GGML_OP_COS:
case GGML_OP_CLAMP:
case GGML_OP_PAD:
case GGML_OP_CPY:
case GGML_OP_SQR:
ggml_vk_sqr(ctx, compute_ctx, src0, node);
+ break;
+ case GGML_OP_SIN:
+ ggml_vk_sin(ctx, compute_ctx, src0, node);
+
+ break;
+ case GGML_OP_COS:
+ ggml_vk_cos(ctx, compute_ctx, src0, node);
+
break;
case GGML_OP_CLAMP:
ggml_vk_clamp(ctx, compute_ctx, src0, node);
case GGML_OP_UPSCALE:
case GGML_OP_SCALE:
case GGML_OP_SQR:
+ case GGML_OP_SIN:
+ case GGML_OP_COS:
case GGML_OP_CLAMP:
case GGML_OP_PAD:
case GGML_OP_CPY:
case GGML_OP_UPSCALE:
case GGML_OP_SCALE:
case GGML_OP_SQR:
+ case GGML_OP_SIN:
+ case GGML_OP_COS:
case GGML_OP_CLAMP:
case GGML_OP_PAD:
case GGML_OP_CONT:
inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1); *s = sqrtf(*s); }
inline static void ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; }
inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); }
-inline static void ggml_vec_log_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = logf(x[i]); }
+inline static void ggml_vec_log_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = logf(x[i]); }
+inline static void ggml_vec_sin_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sinf(x[i]); }
+inline static void ggml_vec_cos_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = cosf(x[i]); }
inline static void ggml_vec_abs_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fabsf(x[i]); }
inline static void ggml_vec_sgn_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : ((x[i] < 0.f) ? -1.f : 0.f); }
inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : 0.f; }
"SQR",
"SQRT",
"LOG",
+ "SIN",
+ "COS",
"SUM",
"SUM_ROWS",
"MEAN",
"CROSS_ENTROPY_LOSS_BACK",
};
-static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
+static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
"x^2",
"√x",
"log(x)",
+ "sin(x)",
+ "cos(x)",
"Σx",
"Σx_k",
"Σx/n",
"cross_entropy_loss_back(x,y)",
};
-static_assert(GGML_OP_COUNT == 74, "GGML_OP_COUNT != 74");
+static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
return ggml_log_impl(ctx, a, true);
}
+// ggml_sin
+
+static struct ggml_tensor * ggml_sin_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ bool inplace) {
+ bool is_node = false;
+
+ if (!inplace && (a->grad)) {
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ result->op = GGML_OP_SIN;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_sin(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_sin_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_sin_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_sin_impl(ctx, a, true);
+}
+
+// ggml_cos
+
+static struct ggml_tensor * ggml_cos_impl(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ bool inplace) {
+ bool is_node = false;
+
+ if (!inplace && (a->grad)) {
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+
+ result->op = GGML_OP_COS;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+
+ return result;
+}
+
+struct ggml_tensor * ggml_cos(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_cos_impl(ctx, a, false);
+}
+
+struct ggml_tensor * ggml_cos_inplace(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a) {
+ return ggml_cos_impl(ctx, a, true);
+}
+
// ggml_sum
struct ggml_tensor * ggml_sum(
}
}
+// ggml_compute_forward_sin
+
+static void ggml_compute_forward_sin_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
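+    // cheap elementwise op: only the first thread does the work (same below for cos)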
+ if (params->ith != 0) {
+ return;
+ }
+
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+ const int n = ggml_nrows(src0);
+ const int nc = src0->ne[0];
+
+ GGML_ASSERT( dst->nb[0] == sizeof(float));
+ GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+ for (int i = 0; i < n; i++) {
+ ggml_vec_sin_f32(nc,
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
+ (float *) ((char *) src0->data + i*(src0->nb[1])));
+ }
+}
+
+static void ggml_compute_forward_sin(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_sin_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ABORT("fatal error");
+ }
+ }
+}
+
+// ggml_compute_forward_cos
+
+static void ggml_compute_forward_cos_f32(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ if (params->ith != 0) {
+ return;
+ }
+
+ GGML_ASSERT(ggml_are_same_shape(src0, dst));
+
+ const int n = ggml_nrows(src0);
+ const int nc = src0->ne[0];
+
+ GGML_ASSERT( dst->nb[0] == sizeof(float));
+ GGML_ASSERT(src0->nb[0] == sizeof(float));
+
+ for (int i = 0; i < n; i++) {
+ ggml_vec_cos_f32(nc,
+ (float *) ((char *) dst->data + i*( dst->nb[1])),
+ (float *) ((char *) src0->data + i*(src0->nb[1])));
+ }
+}
+
+static void ggml_compute_forward_cos(
+ const struct ggml_compute_params * params,
+ struct ggml_tensor * dst) {
+
+ const struct ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_cos_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ABORT("fatal error");
+ }
+ }
+}
+
// ggml_compute_forward_sum
static void ggml_compute_forward_sum_f32(
{
ggml_compute_forward_log(params, tensor);
} break;
+ case GGML_OP_SIN:
+ {
+ ggml_compute_forward_sin(params, tensor);
+ } break;
+ case GGML_OP_COS:
+ {
+ ggml_compute_forward_cos(params, tensor);
+ } break;
case GGML_OP_SUM:
{
ggml_compute_forward_sum(params, tensor);
zero_table);
}
} break;
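+        // for y = sin(x): dL/dx = dL/dy * cos(x), hence ggml_add_or_set;
+        // for y = cos(x): dL/dx = -(dL/dy * sin(x)), hence ggml_sub_or_set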
+ case GGML_OP_SIN:
+ {
+ if (src0->grad) {
+ src0->grad =
+ ggml_add_or_set(ctx,
+ src0->grad,
+ ggml_mul(ctx,
+ tensor->grad,
+ ggml_cos(ctx, src0)),
+ zero_table);
+ }
+ } break;
+ case GGML_OP_COS:
+ {
+ if (src0->grad) {
+ src0->grad =
+ ggml_sub_or_set(ctx,
+ src0->grad,
+ ggml_mul(ctx,
+ tensor->grad,
+ ggml_sin(ctx, src0)),
+ zero_table);
+ }
+ } break;
case GGML_OP_SUM:
{
if (src0->grad) {
case GGML_OP_SQR:
case GGML_OP_SQRT:
case GGML_OP_LOG:
+ case GGML_OP_SIN:
+ case GGML_OP_COS:
case GGML_OP_SUM:
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
--- /dev/null
+++ b/ggml/src/vulkan-shaders/cos.comp
+#version 450
+
+#include "types.comp"
+#include "generic_unary_head.comp"
+
+void main() {
+ const uint idx = get_idx();
+
+ if (idx >= p.ne) {
+ return;
+ }
+
+ const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(idx)]);
+ data_d[p.d_offset + dst_idx(idx)] = D_TYPE(cos(val));
+}
--- /dev/null
+++ b/ggml/src/vulkan-shaders/sin.comp
+#version 450
+
+#include "types.comp"
+#include "generic_unary_head.comp"
+
+void main() {
+ const uint idx = get_idx();
+
+ if (idx >= p.ne) {
+ return;
+ }
+
+ const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(idx)]);
+ data_d[p.d_offset + dst_idx(idx)] = D_TYPE(sin(val));
+}
string_to_spv("sqr_f32", "square.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
}));
+ tasks.push_back(std::async(std::launch::async, [] {
+ string_to_spv("sin_f32", "sin.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
+ }));
+
+ tasks.push_back(std::async(std::launch::async, [] {
+ string_to_spv("cos_f32", "cos.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
+ }));
+
tasks.push_back(std::async(std::launch::async, [] {
string_to_spv("clamp_f32", "clamp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
}));
}
};
+// GGML_OP_SIN
+struct test_sin : public test_case {
+ const ggml_type type;
+ const std::array<int64_t, 4> ne;
+
+ std::string vars() override {
+ return VARS_TO_STR2(type, ne);
+ }
+
+ test_sin(ggml_type type = GGML_TYPE_F32,
+ std::array<int64_t, 4> ne = {10, 10, 10, 10})
+ : type(type), ne(ne) {}
+
+ ggml_tensor * build_graph(ggml_context * ctx) override {
+ ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+ ggml_tensor * out = ggml_sin(ctx, a);
+ return out;
+ }
+
+ void initialize_tensors(ggml_context * ctx) override {
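+        // [-100, 100] spans many periods, exercising each backend's sin/cos argument reduction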
+ for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+ init_tensor_uniform(t, -100.0f, 100.0f);
+ }
+ }
+};
+
+// GGML_OP_COS
+struct test_cos : public test_case {
+ const ggml_type type;
+ const std::array<int64_t, 4> ne;
+
+ std::string vars() override {
+ return VARS_TO_STR2(type, ne);
+ }
+
+ test_cos(ggml_type type = GGML_TYPE_F32,
+ std::array<int64_t, 4> ne = {10, 10, 10, 10})
+ : type(type), ne(ne) {}
+
+ ggml_tensor * build_graph(ggml_context * ctx) override {
+ ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+ ggml_tensor * out = ggml_cos(ctx, a);
+ return out;
+ }
+
+ void initialize_tensors(ggml_context * ctx) override {
+ for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+ init_tensor_uniform(t, -100.0f, 100.0f);
+ }
+ }
+};
+
// GGML_OP_CLAMP
struct test_clamp : public test_case {
const ggml_type type;
test_cases.emplace_back(new test_sqr());
test_cases.emplace_back(new test_sqrt());
+ test_cases.emplace_back(new test_sin());
+ test_cases.emplace_back(new test_cos());
test_cases.emplace_back(new test_clamp());
test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 1, 1}, 5));
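With the cases registered, the new ops can be exercised in isolation per backend. Assuming the harness's existing -o filter for selecting a single operator:

    ./test-backend-ops test -o SIN
    ./test-backend-ops test -o COS

The names match the "SIN"/"COS" entries added to GGML_OP_NAME above.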