]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
model : Apertus model implementation (llama/15852)
authorPiotr Wilkin (ilintar) <redacted>
Thu, 2 Oct 2025 17:43:22 +0000 (19:43 +0200)
committerGeorgi Gerganov <redacted>
Sun, 12 Oct 2025 08:16:23 +0000 (11:16 +0300)
* First attempt

* No permute during convert (fixes qk tensors), proper norm application.

* RoPE = NeoX

* Coherence!

* Migrate xielu params from tensors to hyperparameters

* Simple CUDA kernel

* Revert stupid LLM refactorings

* Chat template support

* configchecker / flake8 errors

* Reorder unary.cu

* I do conclude that LLMs are, in fact, stupid.

* Fix after merge

* Final newline

* Make xIELU an UNARY_OP

* Final newline

* Correctly account for parameter shift

* Argh.

* Update ggml/src/ggml-cpu/unary-ops.cpp

Co-authored-by: Georgi Gerganov <redacted>
* Refactor: remove unused methods, inline and factorize softplus, add const modifiers

* Revert CUDA changes, implement xIELU as a separate OP

* Pesky newline

* Add float2half / half2float for F16 inputs/outputs

* CUDA variants, attempt 2

* Actually, attempt 3

* Update ggml/src/ggml-cuda/unary.cu

Co-authored-by: Johannes Gäßler <redacted>
* Missing convert header

* Proper formula and reference for xIELU in the comments.

* Modify unary-ops.cpp to add the functor-based logic besides the template system to retain optimizations

* Apply suggestions from code review

Co-authored-by: Sigbjørn Skjæret <redacted>
* Add tensor mappings for Apertus to global list instead

* Fix lazy on scalars

* Update ggml/src/ggml-cuda/unary.cu

Co-authored-by: Johannes Gäßler <redacted>
* Add comment about the constraints on positive/negative alpha

* Change `softplus` to `ggml_softplus`

---------

Co-authored-by: Georgi Gerganov <redacted>
Co-authored-by: Johannes Gäßler <redacted>
Co-authored-by: Sigbjørn Skjæret <redacted>
ggml/include/ggml.h
ggml/src/ggml-cpu/ggml-cpu.c
ggml/src/ggml-cpu/ops.cpp
ggml/src/ggml-cpu/unary-ops.cpp
ggml/src/ggml-cpu/unary-ops.h
ggml/src/ggml-cuda/ggml-cuda.cu
ggml/src/ggml-cuda/unary.cu
ggml/src/ggml-cuda/unary.cuh
ggml/src/ggml-impl.h
ggml/src/ggml.c

index 5028a9cebf260b657f199fe80e3e2c7dcb630cd3..f65eb75e298fa99dc655f3852606825ea2c81a06 100644 (file)
@@ -576,6 +576,7 @@ extern "C" {
         GGML_UNARY_OP_HARDSIGMOID,
         GGML_UNARY_OP_EXP,
         GGML_UNARY_OP_GELU_ERF,
+        GGML_UNARY_OP_XIELU,
 
         GGML_UNARY_OP_COUNT,
     };
@@ -1150,6 +1151,18 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    // xIELU activation function
+    // x = x * (c_a(alpha_n) + c_b(alpha_p, beta) * sigmoid(beta * x)) + eps * (x > 0)
+    // where c_a = softplus and c_b(a, b) = softplus(a) + b are constraining functions
+    // that constrain the positive and negative source alpha values respectively
+    GGML_API struct ggml_tensor * ggml_xielu(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            float alpha_n,
+            float alpha_p,
+            float beta,
+            float eps);
+
     // gated linear unit ops
     // A: n columns, r rows,
     // result is n / 2 columns, r rows,
index dbc07301b296ed589056819a0ef884d6e60f6187..eded6eb77ed69e64b9c075b94c96c080a997e5f7 100644 (file)
@@ -2187,6 +2187,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                 case GGML_UNARY_OP_GELU_ERF:
                 case GGML_UNARY_OP_GELU_QUICK:
                 case GGML_UNARY_OP_SILU:
+                case GGML_UNARY_OP_XIELU:
                     {
                         n_tasks = n_threads;
                     } break;
index 14f7dcf4f41ad244f013fd18220461767c1f6228..6275c8305a971336a0eda39daacac61a75530382 100644 (file)
@@ -8637,7 +8637,7 @@ static void ggml_compute_forward_ssm_scan_f32(
                 // n_head
                 for (int h = ih0; h < ih1; ++h) {
                     // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
-                    const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
+                    const float dt_soft_plus = ggml_softplus(dt[h]);
                     const float dA = expf(dt_soft_plus * A[h]);
                     const int g = h / (nh / ng); // repeat_interleave
 
@@ -8734,7 +8734,7 @@ static void ggml_compute_forward_ssm_scan_f32(
                 // n_head
                 for (int h = ih0; h < ih1; ++h) {
                     // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16
-                    const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h];
+                    const float dt_soft_plus = ggml_softplus(dt[h]);
                     const int g = h / (nh / ng); // repeat_interleave
 
                     // dim
@@ -8997,6 +8997,10 @@ void ggml_compute_forward_unary(
             {
                 ggml_compute_forward_exp(params, dst);
             } break;
+        case GGML_UNARY_OP_XIELU:
+            {
+                ggml_compute_forward_xielu(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("fatal error");
index 4fce569b3bfc89e976539113cbf72ea933967a35..cf1a4615d042cd27001f3e3450cd3aa47c22ce53 100644 (file)
@@ -52,6 +52,15 @@ static inline float op_sqrt(float x) {
     return sqrtf(x);
 }
 
+static inline float op_xielu(float x, float alpha_n, float alpha_p, float beta, float eps) {
+    if (x > 0.0f) {
+        return alpha_p * x * x + beta * x;
+    } else {
+        const float min_x_eps = fminf(x, eps);
+        return (expm1f(min_x_eps) - x) * alpha_n + beta * x;
+    }
+}
+
 static inline float op_sin(float x) {
     return sinf(x);
 }
@@ -121,6 +130,86 @@ static void unary_op(const ggml_compute_params * params, ggml_tensor * dst) {
     }
 }
 
+template <float (*op)(float, ggml_tensor *)>
+static void unary_op_params(const ggml_compute_params * params, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    /*  */ if (src0->type == GGML_TYPE_F32  && dst->type == GGML_TYPE_F32) { // all f32
+        apply_unary_op<op, float, float>(params, dst);
+    } else if (src0->type == GGML_TYPE_F16  && dst->type == GGML_TYPE_F16) { // all f16
+        apply_unary_op<op, ggml_fp16_t, ggml_fp16_t>(params, dst);
+    } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_BF16) { // all bf16
+        apply_unary_op<op, ggml_bf16_t, ggml_bf16_t>(params, dst);
+    } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_F32) {
+        apply_unary_op<op, ggml_bf16_t, float>(params, dst);
+    } else if (src0->type == GGML_TYPE_F16  && dst->type == GGML_TYPE_F32) {
+        apply_unary_op<op, ggml_fp16_t, float>(params, dst);
+    } else {
+        fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s\n", __func__,
+            ggml_type_name(dst->type), ggml_type_name(src0->type));
+        GGML_ABORT("fatal error");
+    }
+}
+
+// Extend vec_unary_op to support functors
+template <typename Op, typename src0_t, typename dst_t>
+static inline void vec_unary_op_functor(int64_t n, dst_t * y, const src0_t * x, Op op) {
+    constexpr auto src0_to_f32 = type_conversion_table<src0_t>::to_f32;
+    constexpr auto f32_to_dst  = type_conversion_table<dst_t >::from_f32;
+
+    for (int i = 0; i < n; i++) {
+        y[i] = f32_to_dst(op(src0_to_f32(x[i])));
+    }
+}
+
+// Extend apply_unary_op to support functors
+template <typename Op, typename src0_t, typename dst_t>
+static void apply_unary_op_functor(const ggml_compute_params * params, ggml_tensor * dst, Op op) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0) && ggml_is_contiguous_1(dst) && ggml_are_same_shape(src0, dst));
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    GGML_ASSERT( nb0 == sizeof(dst_t));
+    GGML_ASSERT(nb00 == sizeof(src0_t));
+
+    const auto [ir0, ir1] = get_thread_range(params, src0);
+
+    for (int64_t ir = ir0; ir < ir1; ++ir) {
+        const int64_t i03 = ir/(ne02*ne01);
+        const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
+        const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
+
+        dst_t        * dst_ptr  = (dst_t  *)       ((char *)       dst->data  + i03*nb3  + i02*nb2  + i01*nb1 );
+        const src0_t * src0_ptr = (const src0_t *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
+
+        vec_unary_op_functor(ne0, dst_ptr, src0_ptr, op);
+    }
+}
+
+// Generic dispatcher for functors
+template <typename Op>
+static void unary_op_functor(const ggml_compute_params * params, ggml_tensor * dst, Op op) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    /*  */ if (src0->type == GGML_TYPE_F32  && dst->type == GGML_TYPE_F32) { // all f32
+        apply_unary_op_functor<Op, float, float>(params, dst, op);
+    } else if (src0->type == GGML_TYPE_F16  && dst->type == GGML_TYPE_F16) { // all f16
+        apply_unary_op_functor<Op, ggml_fp16_t, ggml_fp16_t>(params, dst, op);
+    } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_BF16) { // all bf16
+        apply_unary_op_functor<Op, ggml_bf16_t, ggml_bf16_t>(params, dst, op);
+    } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_F32) {
+        apply_unary_op_functor<Op, ggml_bf16_t, float>(params, dst, op);
+    } else if (src0->type == GGML_TYPE_F16  && dst->type == GGML_TYPE_F32) {
+        apply_unary_op_functor<Op, ggml_fp16_t, float>(params, dst, op);
+    } else {
+        fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s\n", __func__,
+            ggml_type_name(dst->type), ggml_type_name(src0->type));
+        GGML_ABORT("fatal error");
+    }
+}
+
 void ggml_compute_forward_abs(const ggml_compute_params * params, ggml_tensor * dst) {
     unary_op<op_abs>(params, dst);
 }
@@ -184,3 +273,17 @@ void ggml_compute_forward_cos(const ggml_compute_params * params, ggml_tensor *
 void ggml_compute_forward_log(const ggml_compute_params * params, ggml_tensor * dst) {
     unary_op<op_log>(params, dst);
 }
+
+void ggml_compute_forward_xielu(const ggml_compute_params * params, ggml_tensor * dst) {
+    const float alpha_n = ggml_get_op_params_f32(dst, 1);
+    const float alpha_p = ggml_get_op_params_f32(dst, 2);
+    const float beta = ggml_get_op_params_f32(dst, 3);
+    const float eps = ggml_get_op_params_f32(dst, 4);
+
+    const auto xielu_op_params = [alpha_n, alpha_p, beta, eps](float f) {
+        return op_xielu(f, alpha_n, alpha_p, beta, eps);
+    };
+
+    unary_op_functor(params, dst, xielu_op_params);
+}
+
index b1ade2c8e341fead4218862fd4bf63b2f98d191e..697c1e0da0ace50739973cd1bc91e430273bff7d 100644 (file)
@@ -22,6 +22,7 @@ void ggml_compute_forward_sqrt(const struct ggml_compute_params * params, struct
 void ggml_compute_forward_sin(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_cos(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_log(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_xielu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 
 #ifdef __cplusplus
 }
index b7e81b21bcb2844ca7f016b37b889599ca3b64ee..26e72bbc2b94209da7a71484302afde3368474f1 100644 (file)
@@ -2334,6 +2334,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
                 case GGML_UNARY_OP_ELU:
                     ggml_cuda_op_elu(ctx, dst);
                     break;
+                case GGML_UNARY_OP_XIELU:
+                    ggml_cuda_op_xielu(ctx, dst);
+                    break;
                 default:
                     return false;
             }
index 5aff8a876af2ce1f73f603f2ab81854e22f1a9de..3c564566a51ff8a76ba66a16d90ecba7dd606037 100644 (file)
@@ -1,4 +1,5 @@
 #include "unary.cuh"
+#include "convert.cuh"
 
 static __device__ __forceinline__ float op_abs(float x) {
     return fabsf(x);
@@ -375,6 +376,59 @@ void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
     swiglu_oai_cuda(src0_p, src1_p, (float *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(float), src1_o / sizeof(float), alpha, limit, stream);
 }
 
+/* CUDA kernel + launcher for xIELU */
+
+template <typename T>
+static __global__ void xielu_kernel(const T * x, T * dst, const int k, float alpha_n, float alpha_p, float beta, float eps) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+
+    const float xi = ggml_cuda_cast<float>(x[i]);
+
+    const float gate_pos = (xi > 0.0f);
+    const float y_pos = alpha_p * xi * xi + beta * xi;
+    const float min_v_eps = fminf(xi, eps);
+    const float y_neg = (expm1f(min_v_eps) - xi) * alpha_n + beta * xi;
+    const float out = gate_pos * y_pos + (1.0f - gate_pos) * y_neg;
+
+    dst[i] = ggml_cuda_cast<T>(out);
+}
+
+template <typename T>
+static void xielu_cuda(const T * x, T * dst, const int k, float alpha_n, float alpha_p, float beta, float eps, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_XIELU_BLOCK_SIZE) / CUDA_XIELU_BLOCK_SIZE;
+    xielu_kernel<<<num_blocks, CUDA_XIELU_BLOCK_SIZE, 0, stream>>>(x, dst, k, alpha_n, alpha_p, beta, eps);
+}
+
+void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const void * src0_d = src0->data;
+    void * dst_d = dst->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous(src0));
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
+    GGML_ASSERT(src0->type == dst->type);
+
+    const float alpha_n = ggml_get_op_params_f32(dst, 1);
+    const float alpha_p = ggml_get_op_params_f32(dst, 2);
+    const float beta    = ggml_get_op_params_f32(dst, 3);
+    const float eps     = ggml_get_op_params_f32(dst, 4);
+
+    if (src0->type == GGML_TYPE_F16) {
+        xielu_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), alpha_n, alpha_p, beta, eps, stream);
+    } else {
+        xielu_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), alpha_n, alpha_p, beta, eps, stream);
+    }
+}
+
+
+
 /* silu_back */
 
 static __device__ __forceinline__ float op_silu_back(float grad, float x) {
index da3caf1d8962e35f23fa9802d88f9e20050ecb39..8e7644fcd9a484492283ffaa50113b54e09d61dd 100644 (file)
@@ -16,6 +16,7 @@
 #define CUDA_SIN_BLOCK_SIZE 256
 #define CUDA_COS_BLOCK_SIZE 256
 #define CUDA_GLU_BLOCK_SIZE 256
+#define CUDA_XIELU_BLOCK_SIZE 256
 
 void ggml_cuda_op_abs(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
@@ -72,3 +73,5 @@ void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
 void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
index 86a1ebf62b70d408ac0cde6d087d20fc122ddf09..d0fb3bccad2250d99d97f70484a0577d5aef4176 100644 (file)
@@ -102,6 +102,9 @@ static bool ggml_op_is_empty(enum ggml_op op) {
     }
 }
 
+static inline float ggml_softplus(float input) {
+    return (input > 20.0f) ? input : logf(1 + expf(input));
+}
 //
 // logging
 //
index aecbdad5a3d2578852445e64b865a11d69f2f73a..7d50b42a37d45d4508bfb261e619a74d391d0bbf 100644 (file)
@@ -1143,10 +1143,10 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
     "HARDSIGMOID",
     "EXP",
     "GELU_ERF",
+    "XIELU",
 };
 
-static_assert(GGML_UNARY_OP_COUNT == 15, "GGML_UNARY_OP_COUNT != 15");
-
+static_assert(GGML_UNARY_OP_COUNT == 16, "GGML_UNARY_OP_COUNT != 16");
 
 static const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = {
     "REGLU",
@@ -2652,6 +2652,29 @@ struct ggml_tensor * ggml_silu_inplace(
     return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SILU);
 }
 
+// ggml_xielu
+
+struct ggml_tensor * ggml_xielu(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        float alpha_n,
+        float alpha_p,
+        float beta,
+        float eps) {
+    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
+
+    ggml_set_op_params_i32(result, 0, (int32_t) GGML_UNARY_OP_XIELU);
+    ggml_set_op_params_f32(result, 1, beta + ggml_softplus(alpha_n));
+    ggml_set_op_params_f32(result, 2, ggml_softplus(alpha_p));
+    ggml_set_op_params_f32(result, 3, beta);
+    ggml_set_op_params_f32(result, 4, eps);
+
+    result->op     = GGML_OP_UNARY;
+    result->src[0] = a;
+
+    return result;
+}
+
 // ggml_silu_back
 
 struct ggml_tensor * ggml_silu_back(