GGML_UNARY_OP_HARDSIGMOID,
GGML_UNARY_OP_EXP,
GGML_UNARY_OP_GELU_ERF,
+ GGML_UNARY_OP_XIELU,
GGML_UNARY_OP_COUNT,
};
struct ggml_context * ctx,
struct ggml_tensor * a);
+ // xIELU activation function
+ // f(x) = c_a(alpha_p) * x^2 + beta * x                             for x > 0
+ // f(x) = c_b(alpha_n, beta) * (expm1(min(x, eps)) - x) + beta * x  for x <= 0
+ // where c_a(a) = softplus(a) and c_b(a, b) = softplus(a) + b are constraining functions
+ // applied to the positive and negative source alpha values respectively
+ GGML_API struct ggml_tensor * ggml_xielu(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ float alpha_n,
+ float alpha_p,
+ float beta,
+ float eps);
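
A minimal usage sketch of the declaration above; the context size, tensor shape, and the raw alpha/beta/eps values are illustrative assumptions, not taken from this change:

    // sketch only: ggml_xielu() constrains the raw alpha values internally
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
    struct ggml_tensor * y = ggml_xielu(ctx, x, /*alpha_n =*/ 0.8f, /*alpha_p =*/ 0.8f, /*beta =*/ 0.5f, /*eps =*/ 1e-6f);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, y);
    // ... fill x->data, compute the graph with the chosen backend, then ggml_free(ctx)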
+
// gated linear unit ops
// A: n columns, r rows,
// result is n / 2 columns, r rows,
case GGML_UNARY_OP_GELU_ERF:
case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_SILU:
+ case GGML_UNARY_OP_XIELU:
{
n_tasks = n_threads;
} break;
// n_head
for (int h = ih0; h < ih1; ++h) {
// ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
- const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
+ const float dt_soft_plus = ggml_softplus(dt[h]);
const float dA = expf(dt_soft_plus * A[h]);
const int g = h / (nh / ng); // repeat_interleave
// n_head
for (int h = ih0; h < ih1; ++h) {
// ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
- const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
+ const float dt_soft_plus = ggml_softplus(dt[h]);
const int g = h / (nh / ng); // repeat_interleave
// dim
{
ggml_compute_forward_exp(params, dst);
} break;
+ case GGML_UNARY_OP_XIELU:
+ {
+ ggml_compute_forward_xielu(params, dst);
+ } break;
default:
{
GGML_ABORT("fatal error");
return sqrtf(x);
}
+static inline float op_xielu(float x, float alpha_n, float alpha_p, float beta, float eps) {
+ if (x > 0.0f) {
+ return alpha_p * x * x + beta * x;
+ } else {
+ const float min_x_eps = fminf(x, eps);
+ return (expm1f(min_x_eps) - x) * alpha_n + beta * x;
+ }
+}
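
Two hand-computed spot checks on the branches above (approximate values; the parameters are assumptions chosen for readability and correspond to the already-constrained alphas this helper receives):

    // assuming alpha_n = 1.0f, alpha_p = 0.5f, beta = 1.0f, eps = 1e-6f:
    //   op_xielu( 2.0f, ...) = 0.5f*2*2 + 1.0f*2                  =  4.0f     (positive branch)
    //   op_xielu(-1.0f, ...) = (expm1f(-1.0f) + 1.0f)*1.0f - 1.0f ~= -0.6321f (negative branch)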
+
static inline float op_sin(float x) {
return sinf(x);
}
}
}
+// generic dispatcher for unary ops given as compile-time function pointers
+template <float (*op)(float)>
+static void unary_op_params(const ggml_compute_params * params, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+
+ /* */ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { // all f32
+ apply_unary_op<op, float, float>(params, dst);
+ } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { // all f16
+ apply_unary_op<op, ggml_fp16_t, ggml_fp16_t>(params, dst);
+ } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_BF16) { // all bf16
+ apply_unary_op<op, ggml_bf16_t, ggml_bf16_t>(params, dst);
+ } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_F32) {
+ apply_unary_op<op, ggml_bf16_t, float>(params, dst);
+ } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
+ apply_unary_op<op, ggml_fp16_t, float>(params, dst);
+ } else {
+ fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s\n", __func__,
+ ggml_type_name(dst->type), ggml_type_name(src0->type));
+ GGML_ABORT("fatal error");
+ }
+}
+
+// Extend vec_unary_op to support functors
+template <typename Op, typename src0_t, typename dst_t>
+static inline void vec_unary_op_functor(int64_t n, dst_t * y, const src0_t * x, Op op) {
+ constexpr auto src0_to_f32 = type_conversion_table<src0_t>::to_f32;
+ constexpr auto f32_to_dst = type_conversion_table<dst_t >::from_f32;
+
+ for (int i = 0; i < n; i++) {
+ y[i] = f32_to_dst(op(src0_to_f32(x[i])));
+ }
+}
+
+// Extend apply_unary_op to support functors
+template <typename Op, typename src0_t, typename dst_t>
+static void apply_unary_op_functor(const ggml_compute_params * params, ggml_tensor * dst, Op op) {
+ const ggml_tensor * src0 = dst->src[0];
+
+ GGML_ASSERT(ggml_is_contiguous_1(src0) && ggml_is_contiguous_1(dst) && ggml_are_same_shape(src0, dst));
+
+ GGML_TENSOR_UNARY_OP_LOCALS
+
+ GGML_ASSERT( nb0 == sizeof(dst_t));
+ GGML_ASSERT(nb00 == sizeof(src0_t));
+
+ const auto [ir0, ir1] = get_thread_range(params, src0);
+
+ for (int64_t ir = ir0; ir < ir1; ++ir) {
+ const int64_t i03 = ir/(ne02*ne01);
+ const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+ const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+ dst_t * dst_ptr = (dst_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
+ const src0_t * src0_ptr = (const src0_t *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+
+ vec_unary_op_functor(ne0, dst_ptr, src0_ptr, op);
+ }
+}
+
+// Generic dispatcher for functors
+template <typename Op>
+static void unary_op_functor(const ggml_compute_params * params, ggml_tensor * dst, Op op) {
+ const ggml_tensor * src0 = dst->src[0];
+
+ /* */ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { // all f32
+ apply_unary_op_functor<Op, float, float>(params, dst, op);
+ } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { // all f16
+ apply_unary_op_functor<Op, ggml_fp16_t, ggml_fp16_t>(params, dst, op);
+ } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_BF16) { // all bf16
+ apply_unary_op_functor<Op, ggml_bf16_t, ggml_bf16_t>(params, dst, op);
+ } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_F32) {
+ apply_unary_op_functor<Op, ggml_bf16_t, float>(params, dst, op);
+ } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
+ apply_unary_op_functor<Op, ggml_fp16_t, float>(params, dst, op);
+ } else {
+ fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s\n", __func__,
+ ggml_type_name(dst->type), ggml_type_name(src0->type));
+ GGML_ABORT("fatal error");
+ }
+}
+
void ggml_compute_forward_abs(const ggml_compute_params * params, ggml_tensor * dst) {
unary_op<op_abs>(params, dst);
}
void ggml_compute_forward_log(const ggml_compute_params * params, ggml_tensor * dst) {
unary_op<op_log>(params, dst);
}
+
+void ggml_compute_forward_xielu(const ggml_compute_params * params, ggml_tensor * dst) {
+ const float alpha_n = ggml_get_op_params_f32(dst, 1);
+ const float alpha_p = ggml_get_op_params_f32(dst, 2);
+ const float beta = ggml_get_op_params_f32(dst, 3);
+ const float eps = ggml_get_op_params_f32(dst, 4);
+
+ const auto xielu_op_params = [alpha_n, alpha_p, beta, eps](float f) {
+ return op_xielu(f, alpha_n, alpha_p, beta, eps);
+ };
+
+ unary_op_functor(params, dst, xielu_op_params);
+}
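
Nothing in unary_op_functor is specific to xIELU: any unary op whose parameters live in dst->op_params can be forwarded through a capturing lambda in the same way. A hypothetical sketch (the op name and parameter slot are invented for illustration):

    // hypothetical hard-shrink style op reusing the same functor dispatcher
    static void ggml_compute_forward_hardshrink_example(const ggml_compute_params * params, ggml_tensor * dst) {
        const float lambda = ggml_get_op_params_f32(dst, 1); // slot 0 would hold the unary op id

        unary_op_functor(params, dst, [lambda](float f) {
            return fabsf(f) > lambda ? f : 0.0f;
        });
    }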
+
void ggml_compute_forward_sin(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_cos(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_log(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_xielu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
#ifdef __cplusplus
}
case GGML_UNARY_OP_ELU:
ggml_cuda_op_elu(ctx, dst);
break;
+ case GGML_UNARY_OP_XIELU:
+ ggml_cuda_op_xielu(ctx, dst);
+ break;
default:
return false;
}
#include "unary.cuh"
+#include "convert.cuh"
static __device__ __forceinline__ float op_abs(float x) {
return fabsf(x);
swiglu_oai_cuda(src0_p, src1_p, (float *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(float), src1_o / sizeof(float), alpha, limit, stream);
}
+/* CUDA kernel + launcher for xIELU */
+
+template <typename T>
+static __global__ void xielu_kernel(const T * x, T * dst, const int k, float alpha_n, float alpha_p, float beta, float eps) {
+ const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+ if (i >= k) {
+ return;
+ }
+
+ const float xi = ggml_cuda_cast<float>(x[i]);
+
+ const float gate_pos = (xi > 0.0f);
+ const float y_pos = alpha_p * xi * xi + beta * xi;
+ const float min_v_eps = fminf(xi, eps);
+ const float y_neg = (expm1f(min_v_eps) - xi) * alpha_n + beta * xi;
+ const float out = gate_pos * y_pos + (1.0f - gate_pos) * y_neg;
+
+ dst[i] = ggml_cuda_cast<T>(out);
+}
+
+template <typename T>
+static void xielu_cuda(const T * x, T * dst, const int k, float alpha_n, float alpha_p, float beta, float eps, cudaStream_t stream) {
+ const int num_blocks = (k + CUDA_XIELU_BLOCK_SIZE - 1) / CUDA_XIELU_BLOCK_SIZE;
+ xielu_kernel<<<num_blocks, CUDA_XIELU_BLOCK_SIZE, 0, stream>>>(x, dst, k, alpha_n, alpha_p, beta, eps);
+}
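
For a concrete feel of the launch configuration (the element count is illustrative): with k = 1000 and CUDA_XIELU_BLOCK_SIZE = 256, the launcher starts ceil(1000/256) = 4 blocks of 256 threads, and the 24 surplus threads return early at the bounds check in the kernel.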
+
+void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const void * src0_d = src0->data;
+ void * dst_d = dst->data;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
+ GGML_ASSERT(src0->type == dst->type);
+
+ const float alpha_n = ggml_get_op_params_f32(dst, 1);
+ const float alpha_p = ggml_get_op_params_f32(dst, 2);
+ const float beta = ggml_get_op_params_f32(dst, 3);
+ const float eps = ggml_get_op_params_f32(dst, 4);
+
+ if (src0->type == GGML_TYPE_F16) {
+ xielu_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), alpha_n, alpha_p, beta, eps, stream);
+ } else {
+ xielu_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), alpha_n, alpha_p, beta, eps, stream);
+ }
+}
+
/* silu_back */
static __device__ __forceinline__ float op_silu_back(float grad, float x) {
#define CUDA_SIN_BLOCK_SIZE 256
#define CUDA_COS_BLOCK_SIZE 256
#define CUDA_GLU_BLOCK_SIZE 256
+#define CUDA_XIELU_BLOCK_SIZE 256
void ggml_cuda_op_abs(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
}
}
+// numerically stable softplus: log(1 + e^x), with a pass-through for large x to avoid overflow in expf
+static inline float ggml_softplus(float input) {
+    return (input > 20.0f) ? input : log1pf(expf(input));
+}
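
Two quick reference values for the helper (approximate):

    // ggml_softplus( 0.0f) = log(1 + exp(0)) ~= 0.6931f
    // ggml_softplus(25.0f) = 25.0f           (pass-through branch; avoids overflowing expf)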
//
// logging
//
"HARDSIGMOID",
"EXP",
"GELU_ERF",
+ "XIELU",
};
-static_assert(GGML_UNARY_OP_COUNT == 15, "GGML_UNARY_OP_COUNT != 15");
-
+static_assert(GGML_UNARY_OP_COUNT == 16, "GGML_UNARY_OP_COUNT != 16");
static const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = {
"REGLU",
return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SILU);
}
+// ggml_xielu
+
+struct ggml_tensor * ggml_xielu(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ float alpha_n,
+ float alpha_p,
+ float beta,
+ float eps) {
+ struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
+
+ ggml_set_op_params_i32(result, 0, (int32_t) GGML_UNARY_OP_XIELU);
+ ggml_set_op_params_f32(result, 1, beta + ggml_softplus(alpha_n));
+ ggml_set_op_params_f32(result, 2, ggml_softplus(alpha_p));
+ ggml_set_op_params_f32(result, 3, beta);
+ ggml_set_op_params_f32(result, 4, eps);
+
+ result->op = GGML_OP_UNARY;
+ result->src[0] = a;
+
+ return result;
+}
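
To make the constraining concrete, a hand-worked example of what lands in the op params (values approximate; the raw inputs are illustrative assumptions):

    // with raw alpha_n = 0.8f, alpha_p = 0.8f, beta = 0.5f:
    //   softplus(0.8f)                   ~= 1.1711f
    //   op param 1 (constrained alpha_n) ~= 0.5f + 1.1711f = 1.6711f
    //   op param 2 (constrained alpha_p) ~= 1.1711f
    //   op param 3 = beta, op param 4 = eps (stored unchanged)
    // the CPU and CUDA forward paths read these constrained values; the raw alphas are never stored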
+
// ggml_silu_back
struct ggml_tensor * ggml_silu_back(