]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
CANN: Support operator SIN COS ARGMAX (llama/12709)
authorChenguang Li <redacted>
Thu, 3 Apr 2025 07:18:08 +0000 (15:18 +0800)
committerGeorgi Gerganov <redacted>
Thu, 24 Apr 2025 17:39:16 +0000 (20:39 +0300)
* [CANN]support sin cos argmax

Signed-off-by: noemotiovon <redacted>
* [CANN]codestyle adjustment

Signed-off-by: noemotiovon <redacted>
* [CANN]Remove redundant code

Signed-off-by: noemotiovon <redacted>
---------

Signed-off-by: noemotiovon <redacted>
Co-authored-by: noemotiovon <redacted>
ggml/src/ggml-cann/aclnn_ops.cpp
ggml/src/ggml-cann/aclnn_ops.h
ggml/src/ggml-cann/ggml-cann.cpp

index ae13730c0c32d173013d2570ad1d3f305a4d745a..f5734cbabb6c57c6e8a3f0401b6d5c4359c0f064 100644 (file)
@@ -51,6 +51,7 @@
 #include <aclnnop/aclnn_triu.h>
 #include <aclnnop/aclnn_upsample_nearest_2d.h>
 #include <aclnnop/aclnn_weight_quant_batch_matmul_v2.h>
+#include <aclnnop/aclnn_argmax.h>
 #include <float.h>
 
 #include <cmath>
@@ -3440,3 +3441,46 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
     ACL_CHECK(aclDestroyTensor(acl_sin_reshape_tensor));
     ACL_CHECK(aclDestroyTensor(acl_dst));
 }
+
+
+ void ggml_cann_argmax(ggml_backend_cann_context& ctx, ggml_tensor* dst){
+    ggml_tensor * src0 = dst->src[0];
+
+    aclTensor* acl_src = ggml_cann_create_tensor(src0);
+    aclTensor* acl_dst = ggml_cann_create_tensor(dst, dst->ne, dst->nb, 3);
+
+    uint64_t workspaceSize = 0;
+    aclOpExecutor* executor;
+    void* workspaceAddr = nullptr;
+
+    ACL_CHECK(aclnnArgMaxGetWorkspaceSize(acl_src, 3, false, acl_dst,
+                     &workspaceSize, &executor));
+    if (workspaceSize > 0) {
+        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+        workspaceAddr = workspace_allocator.get();
+    }
+    ACL_CHECK(aclnnArgMax(workspaceAddr, workspaceSize, executor, ctx.stream()));
+
+    ACL_CHECK(aclDestroyTensor(acl_src));
+    ACL_CHECK(aclDestroyTensor(acl_dst));
+}
+
+void ggml_cann_cos(ggml_backend_cann_context& ctx, ggml_tensor* dst){
+    ggml_tensor * src0 = dst->src[0];
+
+    aclTensor* acl_src = ggml_cann_create_tensor(src0);
+    aclTensor* acl_dst = ggml_cann_create_tensor(dst);
+    aclnn_cos(ctx, acl_src, acl_dst);
+    ACL_CHECK(aclDestroyTensor(acl_src));
+    ACL_CHECK(aclDestroyTensor(acl_dst));
+}
+
+void ggml_cann_sin(ggml_backend_cann_context& ctx, ggml_tensor* dst){
+    ggml_tensor * src0 = dst->src[0];
+
+    aclTensor* acl_src = ggml_cann_create_tensor(src0);
+    aclTensor* acl_dst = ggml_cann_create_tensor(dst);
+    aclnn_sin(ctx, acl_src, acl_dst);
+    ACL_CHECK(aclDestroyTensor(acl_src));
+    ACL_CHECK(aclDestroyTensor(acl_dst));
+}
index 51a5cf92f016d2771dd63a3d62fb5662639b1aad..1327905032944f73fc4f372b0949eedc9d45e9f8 100644 (file)
@@ -484,6 +484,47 @@ void ggml_cann_mul_mat(ggml_backend_cann_context& ctx, ggml_tensor* dst);
  */
 void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst);
 
+/**
+ * @brief   Computes the index of the maximum value along the specified dimension
+ *          of a ggml tensor using the CANN backend.
+ *
+ * @details This function performs an argmax operation on the input tensor.
+ *          It finds the index of the maximum value along the specified axis
+ *          and stores these indices in the destination tensor `dst`. The
+ *          operation is executed using the CANN backend for optimized performance.
+ *
+ * @param ctx The CANN context used for operations.
+ * @param dst The destination tensor where the indices of the maximum values will be stored.
+ *            dst->op is `GGML_OP_ARGMAX`.
+ */
+void ggml_cann_argmax(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
+/**
+ * @brief   Computes the cosine of each element in a ggml tensor using the CANN backend.
+ *
+ * @details This function applies the cosine function element-wise to the input tensor.
+ *          The computed cosine values are stored in the destination tensor `dst`.
+ *          The operation is optimized using the CANN backend for improved performance.
+ *
+ * @param ctx The CANN context used for operations.
+ * @param dst The destination tensor where the cosine values will be stored.
+ *            dst->op is `GGML_OP_COS`.
+ */
+void ggml_cann_cos(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
+/**
+ * @brief   Computes the sine of each element in a ggml tensor using the CANN backend.
+ *
+ * @details This function applies the sine function element-wise to the input tensor.
+ *          The computed sine values are stored in the destination tensor `dst`.
+ *          The operation is optimized using the CANN backend for improved performance.
+ *
+ * @param ctx The CANN context used for operations.
+ * @param dst The destination tensor where the sine values will be stored.
+ *            dst->op is `GGML_OP_SIN`.
+ */
+void ggml_cann_sin(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
 template <aclnnStatus getWorkspaceSize(const aclTensor*, const aclTensor*,
                                        aclTensor*, uint64_t*, aclOpExecutor**),
           aclnnStatus execute(void*, uint64_t, aclOpExecutor*, aclrtStream)>
index 3527bd298a370116825c557637ad79d5d40861b0..5e790f05fbb672a10abd3a6f820cdb5f3b80ef90 100644 (file)
@@ -1420,6 +1420,15 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
         case GGML_OP_ARGSORT:
             ggml_cann_argsort(ctx, dst);
             break;
+        case GGML_OP_ARGMAX:
+            ggml_cann_argmax(ctx, dst);
+            break;
+        case GGML_OP_COS:
+            ggml_cann_cos(ctx, dst);
+            break;
+        case GGML_OP_SIN:
+            ggml_cann_sin(ctx, dst);
+            break;
         default:
             return false;
     }
@@ -1802,6 +1811,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
         case GGML_OP_ARANGE:
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_LEAKY_RELU:
+        case GGML_OP_ARGMAX:
+        case GGML_OP_COS:
+        case GGML_OP_SIN:
             return true;
         default:
             return false;