]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
CANN: Implement GLU ops (llama/14884)
authorhipudding <redacted>
Sat, 26 Jul 2025 09:56:18 +0000 (17:56 +0800)
committerGeorgi Gerganov <redacted>
Mon, 28 Jul 2025 05:43:21 +0000 (08:43 +0300)
Implement REGLU, GEGLU, SWIGLU ops according to #14158

src/ggml-cann/acl_tensor.cpp
src/ggml-cann/aclnn_ops.cpp
src/ggml-cann/aclnn_ops.h
src/ggml-cann/ggml-cann.cpp

index f311864d486f7936843499b6ed7ae75e7a5d0dce..8ffac31dd661ae82c6ca8a25dd40b90bb4de4d7a 100755 (executable)
@@ -77,6 +77,8 @@ aclTensor* ggml_cann_create_tensor(const ggml_tensor* tensor, int64_t* ne,
     for (int i = 0; i < final_dims; i++) {
         acl_storage_len += (acl_ne[i] - 1) * acl_stride[i];
     }
+    size_t elem_offset = offset / ggml_element_size(tensor);
+    acl_storage_len += elem_offset;
 
     // Reverse ne and stride.
     std::reverse(acl_ne, acl_ne + final_dims);
@@ -84,7 +86,7 @@ aclTensor* ggml_cann_create_tensor(const ggml_tensor* tensor, int64_t* ne,
 
     aclTensor* acl_tensor = aclCreateTensor(
         acl_ne, final_dims, ggml_cann_type_mapping(tensor->type), acl_stride,
-        offset / ggml_element_size(tensor), format, &acl_storage_len, 1,
+        elem_offset, format, &acl_storage_len, 1,
         tensor->data);
 
     return acl_tensor;
index 76bed4e8cd0fc774e4e617e535dc0369f8656299..d616c491ae9fe59bf911649cad7540d492de664d 100755 (executable)
@@ -99,7 +99,7 @@ void bcast_shape(ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst, aclT
     }
 }
 
-void ggml_cann_unary_op(
+void ggml_cann_op_unary(
     std::function<void(ggml_backend_cann_context&, aclTensor*, aclTensor*)> unary_op,
     ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     ggml_tensor* src = dst->src[0];
@@ -111,6 +111,42 @@ void ggml_cann_unary_op(
     ggml_cann_release_resources(ctx, acl_src, acl_dst);
 }
 
+void ggml_cann_op_unary_gated(
+    std::function<void(ggml_backend_cann_context&, aclTensor*, aclTensor*)> unary_op,
+    ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+    ggml_tensor* src0 = dst->src[0];
+    ggml_tensor* src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+
+    aclTensor* acl_dst = ggml_cann_create_tensor(dst);
+    aclTensor *acl_src0 = nullptr, *acl_src1 = nullptr;
+    if(src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+
+        acl_src0 = ggml_cann_create_tensor(src0);
+        acl_src1 = ggml_cann_create_tensor(src1);
+    } else {
+        int64_t ne[] = {src0->ne[0] / 2, src0->ne[1], src0->ne[2], src0->ne[3]};
+        size_t nb[] = {src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]};
+        acl_src0 = ggml_cann_create_tensor(src0, ne, nb, GGML_MAX_DIMS, ACL_FORMAT_ND, 0);
+        acl_src1 = ggml_cann_create_tensor(src0, ne, nb, GGML_MAX_DIMS, ACL_FORMAT_ND, ne[0] * ggml_element_size(src0));
+        if (swapped) {
+            std::swap(acl_src0, acl_src1);
+        }
+    }
+
+    unary_op(ctx, acl_src0, acl_dst);
+    GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, acl_dst, acl_src1);
+
+    ggml_cann_release_resources(ctx, acl_src0, acl_dst);
+    if(src1)
+        ggml_cann_release_resources(ctx, acl_src1);
+}
+
 /**
  * @brief Repeats elements of a tensor along each dimension according to the
  * specified repeat array.
index 924da66ed6862fc6f1af0c7a2cfc12ec5b3e8382..8deaf7ea1db1b7e362863cfb272e8c696af981a4 100755 (executable)
@@ -1098,7 +1098,7 @@ void ggml_cann_binary_op(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
  * @param dst The destination tensor. Its src[0] is treated as the input tensor.
  */
 template <void unary_op(ggml_backend_cann_context&, aclTensor*, aclTensor*)>
-    void ggml_cann_unary_op(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+    void ggml_cann_op_unary(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     ggml_tensor* src = dst->src[0];
 
     aclTensor* acl_src = ggml_cann_create_tensor(src);
@@ -1109,49 +1109,125 @@ template <void unary_op(ggml_backend_cann_context&, aclTensor*, aclTensor*)>
 }
 
 /**
- * @brief   Applies a unary operation to a ggml tensor using the CANN backend.
+ * @brief Applies a unary operation to a ggml tensor using the CANN backend.
  *
- * @details This function performs a unary operation on the input tensor using
- * a user-provided lambda or callable object `unary_op`, which accepts the CANN
- * context and two ACL tensors (source and destination). Internally, this function
- * creates ACL representations of the ggml tensors and invokes the unary operation.
- * The result is stored in the destination tensor `dst`. This utility abstracts the
- * common boilerplate of tensor conversion and cleanup when implementing unary ops.
+ * @details This function applies a unary operation to the input tensor using
+ * a user-provided lambda or callable `unary_op`. The lambda receives the
+ * CANN backend context and two ACL tensors: the source and the destination.
  *
- * @param unary_op A callable that performs the unary operation using CANN APIs.
- * @param ctx The CANN context used for operations.
- * @param dst The destination tensor where the result will be stored.
- *            The source tensor is retrieved from `dst->src[0]`.
+ * Internally, this function handles the conversion from GGML tensors to ACL tensors,
+ * calls the provided unary op, and manages resource cleanup. The input is assumed
+ * to be `dst->src[0]`, and the result is written to `dst`.
+ *
+ * This utility simplifies writing unary op wrappers by abstracting tensor preparation.
+ *
+ * @param unary_op A callable that performs the unary operation using CANN ACL APIs.
+ * @param ctx The CANN context for operation execution.
+ * @param dst The destination ggml_tensor where the result will be stored.
+ *            The input tensor is assumed to be `dst->src[0]`.
+ *
+ * @see GGML_CANN_CALL_OP_UNARY
  */
-void ggml_cann_unary_op(
+void ggml_cann_op_unary(
     std::function<void(ggml_backend_cann_context&, aclTensor*, aclTensor*)> unary_op,
     ggml_backend_cann_context& ctx, ggml_tensor* dst);
 
 /**
- * @brief Helper macro to invoke a unary ACL operation using ggml_cann_unary_op.
+ * @brief Applies a gated (GLU-style) unary operation using the CANN backend.
+ *
+ * @details This function performs a gated activation such as GEGLU or ReGLU.
+ * It supports two input modes:
+ *
+ * 1. **Dual input mode**: `dst->src[0]` and `dst->src[1]` are both valid tensors.
+ *    These are used directly as the value and gate tensors.
+ *
+ * 2. **Packed input mode**: Only `dst->src[0]` is valid, and it is assumed to
+ *    contain a concatenation of value and gate along the first dimension. This tensor
+ *    will be split into two equal halves to form the value and gate inputs.
+ *
+ * The function applies a user-provided unary operation (e.g., GELU) to the value tensor,
+ * then multiplies the result in-place with the gate tensor:
+ *
+ * @code
+ * dst = unary_op(value) * gate;
+ * @endcode
+ *
+ * The `swapped` parameter (from `dst->op_params[1]`) allows flipping the
+ * order of value/gate in the packed input case.
+ *
+ * @param unary_op A callable that performs the unary operation using CANN ACL APIs.
+ *                 It receives (ctx, acl_value_tensor, acl_output_tensor).
+ * @param ctx      The CANN context used for execution.
+ * @param dst      The destination ggml_tensor. Source tensors are in `dst->src[0]` and optionally `src[1]`.
+ *
+ * @see GGML_CANN_CALL_OP_UNARY_GATED
+ */
+void ggml_cann_op_unary_gated(
+    std::function<void(ggml_backend_cann_context&, aclTensor*, aclTensor*)> unary_op,
+    ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
+/**
+ * @brief Helper macro to call a unary ACL operator via ggml_cann_op_unary.
+ *
+ * This macro wraps the specified ACLNN unary operator name into a lambda expression,
+ * and passes it to `ggml_cann_op_unary`, which handles the common logic for executing
+ * unary ops in the CANN backend.
+ *
+ * Internally, this macro expands to a lambda like:
+ * @code
+ * [](ggml_backend_cann_context& ctx, aclTensor* acl_src, aclTensor* acl_dst) {
+ *     GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst);
+ * };
+ * @endcode
+ *
+ * This lambda is then passed to `ggml_cann_op_unary`, which applies the operation.
+ *
+ * @param OP_NAME The name of the ACL unary operator to invoke via GGML_CANN_CALL_ACLNN_OP.
+ *
+ * @see ggml_cann_op_unary
+ * @see GGML_CANN_CALL_ACLNN_OP
+ */
+#define GGML_CANN_CALL_OP_UNARY(OP_NAME)                              \
+    do {                                                              \
+        auto lambda = [](ggml_backend_cann_context& ctx,              \
+            aclTensor* acl_src,                                       \
+            aclTensor* acl_dst) {                                     \
+            GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst);  \
+        };                                                            \
+        ggml_cann_op_unary(lambda, ctx, dst);                         \
+    }                                                                 \
+    while (0)
+
+/**
+ * @brief Helper macro to call a gated unary ACL operator via ggml_cann_op_unary_gated.
  *
- * This macro defines an inline lambda wrapping a specific ACL operation name,
- * and passes it to the templated ggml_cann_unary_op function. It simplifies
- * calling unary ops by hiding the lambda boilerplate.
+ * This macro wraps the specified ACLNN unary operator name into a lambda expression,
+ * and passes it to `ggml_cann_op_unary_gated`, which handles the common logic for
+ * executing gated unary ops in the CANN backend.
  *
- * Internally, the lambda will call:
+ * Internally, this macro expands to a lambda like:
  * @code
- * GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst);
+ * [](ggml_backend_cann_context& ctx, aclTensor* acl_src, aclTensor* acl_dst) {
+ *     GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst);
+ * };
  * @endcode
  *
+ * This lambda is then passed to `ggml_cann_op_unary_gated`, which applies the operation.
+ *
  * @param OP_NAME The name of the ACL unary operator to invoke via GGML_CANN_CALL_ACLNN_OP.
  *
- * @see ggml_cann_unary_op
+ * @see ggml_cann_op_unary_gated
  * @see GGML_CANN_CALL_ACLNN_OP
  */
-#define GGML_CANN_CALL_UNARY_OP(OP_NAME)                              \
+#define GGML_CANN_CALL_OP_UNARY_GATED(OP_NAME)                        \
     do {                                                              \
         auto lambda = [](ggml_backend_cann_context& ctx,              \
             aclTensor* acl_src,                                       \
             aclTensor* acl_dst) {                                     \
             GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst);  \
         };                                                            \
-        ggml_cann_unary_op(lambda, ctx, dst);                         \
+        ggml_cann_op_unary_gated(lambda, ctx, dst);                   \
     }                                                                 \
     while (0)
+
 #endif  // CANN_ACLNN_OPS
index f30241aca4046a13546455b8675d6919d3cd3d84..c6edb6b61bbbffb087a782a9ba27f7abba171672 100755 (executable)
@@ -1681,16 +1681,18 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(dst)) {
                 case GGML_UNARY_OP_ABS:
-                    GGML_CANN_CALL_UNARY_OP(Abs);
+                    GGML_CANN_CALL_OP_UNARY(Abs);
                     break;
                 case GGML_UNARY_OP_NEG:
-                    GGML_CANN_CALL_UNARY_OP(Neg);
+                    GGML_CANN_CALL_OP_UNARY(Neg);
                     break;
                 case GGML_UNARY_OP_GELU:
-                    GGML_CANN_CALL_UNARY_OP(Gelu);
+                case GGML_UNARY_OP_GELU_ERF:
+                    // aclnnGelu internally uses the erf-based approximation.
+                    GGML_CANN_CALL_OP_UNARY(Gelu);
                     break;
                 case GGML_UNARY_OP_SILU:
-                    GGML_CANN_CALL_UNARY_OP(Silu);
+                    GGML_CANN_CALL_OP_UNARY(Silu);
                     break;
                 case GGML_UNARY_OP_GELU_QUICK: {
                     auto lambda = [](ggml_backend_cann_context& ctx,
@@ -1698,31 +1700,31 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
                         aclTensor* acl_dst) {
                         GGML_CANN_CALL_ACLNN_OP(ctx, GeluV2, acl_src, 0, acl_dst);
                     };
-                    ggml_cann_unary_op(lambda, ctx, dst);
+                    ggml_cann_op_unary(lambda, ctx, dst);
                 } break;
                 case GGML_UNARY_OP_TANH:
-                    GGML_CANN_CALL_UNARY_OP(Tanh);
+                    GGML_CANN_CALL_OP_UNARY(Tanh);
                     break;
                 case GGML_UNARY_OP_RELU:
-                    GGML_CANN_CALL_UNARY_OP(Relu);
+                    GGML_CANN_CALL_OP_UNARY(Relu);
                     break;
                 case GGML_UNARY_OP_SIGMOID:
-                    GGML_CANN_CALL_UNARY_OP(Sigmoid);
+                    GGML_CANN_CALL_OP_UNARY(Sigmoid);
                     break;
                 case GGML_UNARY_OP_HARDSIGMOID:
-                    GGML_CANN_CALL_UNARY_OP(Hardsigmoid);
+                    GGML_CANN_CALL_OP_UNARY(Hardsigmoid);
                     break;
                 case GGML_UNARY_OP_HARDSWISH:
-                    GGML_CANN_CALL_UNARY_OP(Hardswish);
+                    GGML_CANN_CALL_OP_UNARY(Hardswish);
                     break;
                 case GGML_UNARY_OP_EXP:
-                    GGML_CANN_CALL_UNARY_OP(Exp);
+                    GGML_CANN_CALL_OP_UNARY(Exp);
                     break;
                 case GGML_UNARY_OP_ELU:
                     ggml_cann_elu(ctx, dst);
                     break;
                 case GGML_UNARY_OP_SGN:
-                    GGML_CANN_CALL_UNARY_OP(Sign);
+                    GGML_CANN_CALL_OP_UNARY(Sign);
                     break;
                 case GGML_UNARY_OP_STEP:
                     ggml_cann_step(ctx, dst);
@@ -1731,6 +1733,31 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
                     return false;
             }
             break;
+        case GGML_OP_GLU:
+            switch (ggml_get_glu_op(dst)) {
+                case GGML_GLU_OP_REGLU:
+                    GGML_CANN_CALL_OP_UNARY_GATED(Relu);
+                    break;
+                case GGML_GLU_OP_GEGLU:
+                case GGML_GLU_OP_GEGLU_ERF:
+                    // aclnnGelu internally uses the erf-based approximation.
+                    GGML_CANN_CALL_OP_UNARY_GATED(Gelu);
+                    break;
+                case GGML_GLU_OP_SWIGLU:
+                    GGML_CANN_CALL_OP_UNARY_GATED(Silu);
+                    break;
+                case GGML_GLU_OP_GEGLU_QUICK: {
+                    auto lambda = [](ggml_backend_cann_context& ctx,
+                        aclTensor* acl_src,
+                        aclTensor* acl_dst) {
+                        GGML_CANN_CALL_ACLNN_OP(ctx, GeluV2, acl_src, 0, acl_dst);
+                    };
+                    ggml_cann_op_unary_gated(lambda, ctx, dst);
+                } break;
+                default:
+                    return false;
+            }
+            break;
         case GGML_OP_NORM:
             ggml_cann_norm(ctx, dst);
             break;
@@ -1773,7 +1800,7 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
             ggml_cann_binary_op<aclnn_mul>(ctx, dst);
             break;
         case GGML_OP_SQRT:
-            GGML_CANN_CALL_UNARY_OP(Sqrt);
+            GGML_CANN_CALL_OP_UNARY(Sqrt);
             break;
         case GGML_OP_CLAMP:
             ggml_cann_clamp(ctx, dst);
@@ -1818,16 +1845,16 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
             ggml_cann_argmax(ctx, dst);
             break;
         case GGML_OP_COS:
-            ggml_cann_unary_op<aclnn_cos>(ctx, dst);
+            ggml_cann_op_unary<aclnn_cos>(ctx, dst);
             break;
         case GGML_OP_SIN:
-            ggml_cann_unary_op<aclnn_sin>(ctx, dst);
+            ggml_cann_op_unary<aclnn_sin>(ctx, dst);
             break;
         case GGML_OP_CONV_TRANSPOSE_1D:
             ggml_cann_conv_transpose_1d(ctx, dst);
             break;
         case GGML_OP_LOG:
-            GGML_CANN_CALL_UNARY_OP(Log);
+            GGML_CANN_CALL_OP_UNARY(Log);
             break;
         case GGML_OP_MEAN:
             ggml_cann_mean(ctx, dst);
@@ -2101,10 +2128,23 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
                 case GGML_UNARY_OP_ELU:
                 case GGML_UNARY_OP_SGN:
                 case GGML_UNARY_OP_STEP:
+                case GGML_UNARY_OP_GELU_ERF:
                     return true;
                 default:
                     return false;
             }
+        case GGML_OP_GLU:
+            switch (ggml_get_glu_op(op)) {
+                case GGML_GLU_OP_REGLU:
+                case GGML_GLU_OP_GEGLU:
+                case GGML_GLU_OP_SWIGLU:
+                case GGML_GLU_OP_GEGLU_ERF:
+                case GGML_GLU_OP_GEGLU_QUICK:
+                    return true;
+                default:
+                    return false;
+            }
+            break;
         case GGML_OP_MUL_MAT: {
             switch (op->src[0]->type) {
                 case GGML_TYPE_F16: