]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : add CANN backend (llama/0)
authorhipudding <redacted>
Thu, 8 Aug 2024 11:48:06 +0000 (14:48 +0300)
committerGeorgi Gerganov <redacted>
Thu, 8 Aug 2024 11:55:23 +0000 (14:55 +0300)
ggml-ci

22 files changed:
include/ggml-cann.h [new file with mode: 0644]
scripts/sync-llama-am.sh
scripts/sync-llama.sh
scripts/sync-whisper-am.sh
scripts/sync-whisper.sh
src/ggml-cann.cpp [new file with mode: 0644]
src/ggml-cann/Doxyfile [new file with mode: 0644]
src/ggml-cann/acl_tensor.cpp [new file with mode: 0644]
src/ggml-cann/acl_tensor.h [new file with mode: 0644]
src/ggml-cann/aclnn_ops.cpp [new file with mode: 0644]
src/ggml-cann/aclnn_ops.h [new file with mode: 0644]
src/ggml-cann/common.h [new file with mode: 0644]
src/ggml-cann/kernels/CMakeLists.txt [new file with mode: 0644]
src/ggml-cann/kernels/ascendc_kernels.h [new file with mode: 0644]
src/ggml-cann/kernels/dup.cpp [new file with mode: 0644]
src/ggml-cann/kernels/get_row_f16.cpp [new file with mode: 0644]
src/ggml-cann/kernels/get_row_f32.cpp [new file with mode: 0644]
src/ggml-cann/kernels/get_row_q4_0.cpp [new file with mode: 0644]
src/ggml-cann/kernels/get_row_q8_0.cpp [new file with mode: 0644]
src/ggml-cann/kernels/quantize_f16_q8_0.cpp [new file with mode: 0644]
src/ggml-cann/kernels/quantize_f32_q8_0.cpp [new file with mode: 0644]
src/ggml-cann/kernels/quantize_float_to_q4_0.cpp [new file with mode: 0644]

diff --git a/include/ggml-cann.h b/include/ggml-cann.h
new file mode 100644 (file)
index 0000000..ca73211
--- /dev/null
@@ -0,0 +1,125 @@
+/*
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+ * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
+ * IN THE SOFTWARE.
+ */
+
+#pragma once
+
+#include "ggml-backend.h"
+#include "ggml.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/**
+ * @brief Maximum number of CANN devices supported.
+ */
+#define GGML_CANN_MAX_DEVICES 16
+
+/**
+ * @brief Initializes the CANN backend for a specified device.
+ *
+ * This function initializes the CANN backend for the given device.
+ * It verifies the device index, allocates a context, and creates a backend
+ * instance.
+ *
+ * @param device The index of the device to initialize.
+ * @return A pointer to the initialized backend instance, or nullptr on failure.
+ */
+GGML_API GGML_CALL ggml_backend_t ggml_backend_cann_init(int32_t device);
+
+/**
+ * @brief Checks if a given backend is a CANN backend.
+ *
+ * This function verifies if the provided backend is a CANN backend by comparing
+ * its GUID with the CANN backend's GUID.
+ *
+ * @param backend The backend instance to check.
+ * @return True if the backend is a CANN backend, false otherwise.
+ */
+GGML_API GGML_CALL bool ggml_backend_is_cann(ggml_backend_t backend);
+
+/**
+ * @brief Retrieves the CANN buffer type for a specified device.
+ *
+ * This function initializes and returns the buffer type interface associated
+ * with the given device. It ensures thread-safe access using a mutex.
+ *
+ * @param device The device index for which to retrieve the buffer type.
+ * @return A pointer to the buffer type interface for the specified device, or
+ * nullptr if the device index is out of range.
+ */
+GGML_API GGML_CALL ggml_backend_buffer_type_t
+ggml_backend_cann_buffer_type(int32_t device);
+
+/**
+ * @brief Retrieves the number of CANN devices available.
+ *
+ * This function returns the number of CANN devices available based on
+ * information obtained from `ggml_cann_info()`.
+ *
+ * @return The number of CANN devices available.
+ */
+GGML_API GGML_CALL int32_t ggml_backend_cann_get_device_count(void);
+
+/**
+ * @brief Retrieves the description of a specific CANN device.
+ *
+ * This function sets the specified device, retrieves the SoC name,
+ * and writes it into the provided description buffer.
+ *
+ * @param device The device index to retrieve the description for.
+ * @param description Pointer to a buffer where the description will be written.
+ * @param description_size Size of the description buffer.
+ */
+GGML_API GGML_CALL void ggml_backend_cann_get_device_description(
+    int32_t device, char* description, size_t description_size);
+
+/**
+ * @brief Retrieves the memory information of a specific CANN device.
+ *
+ * This function sets the specified device, retrieves the free and total
+ * memory information of the specified type (ACL_HBM_MEM), and stores them
+ * in the provided pointers.
+ *
+ * @param device The device index to retrieve memory information for.
+ * @param free Pointer to a variable where the free memory size will be stored.
+ * @param total Pointer to a variable where the total memory size will be
+ * stored.
+ */
+GGML_API GGML_CALL void ggml_backend_cann_get_device_memory(int32_t device,
+                                                            size_t* free,
+                                                            size_t* total);
+
+/**
+ * @brief Set the logging callback for GGML.
+ *
+ * This function sets the logging callback and user data for logging.
+ *
+ * @param log_callback The logging callback to set.
+ * @param user_data User data to pass to the logging callback.
+ */
+GGML_API void ggml_backend_cann_log_set_callback(ggml_log_callback log_callback,
+                                                 void* user_data);
+
+#ifdef __cplusplus
+}
+#endif
index 3370f91f917648a6850a06c01db81b8eeb115351..378ad855d9494b32eb64f3473b27f41925dad2c3 100755 (executable)
@@ -62,6 +62,7 @@ while read c; do
         ggml/src/ggml*.m \
         ggml/src/ggml*.metal \
         ggml/src/ggml*.cu \
+        ggml/src/ggml-cann/* \
         ggml/src/ggml-cuda/* \
         ggml/src/ggml-sycl/* \
         ggml/include/ggml*.h \
@@ -108,6 +109,8 @@ if [ -f $SRC_GGML/llama-src.patch ]; then
     # ggml/src/ggml-backend-impl.h -> src/ggml-backend-impl.h
     # ggml/src/ggml-backend.c      -> src/ggml-backend.c
     # ggml/src/ggml-blas.cpp       -> src/ggml-blas.cpp
+    # ggml/src/ggml-cann/*         -> src/ggml-cann/*
+    # ggml/src/ggml-cann.cpp       -> src/ggml-cann.cpp
     # ggml/src/ggml-common.h       -> src/ggml-common.h
     # ggml/src/ggml-cuda/*         -> src/ggml-cuda/*
     # ggml/src/ggml-cuda.cu        -> src/ggml-cuda.cu
@@ -126,6 +129,7 @@ if [ -f $SRC_GGML/llama-src.patch ]; then
     # ggml/include/ggml-alloc.h   -> include/ggml-alloc.h
     # ggml/include/ggml-backend.h -> include/ggml-backend.h
     # ggml/include/ggml-blas.h    -> include/ggml-blas.h
+    # ggml/include/ggml-cann.h    -> include/ggml-cann.h
     # ggml/include/ggml-cuda.h    -> include/ggml-cuda.h
     # ggml/include/ggml-kompute.h -> include/ggml-kompute.h
     # ggml/include/ggml-metal.h   -> include/ggml-metal.h
@@ -153,6 +157,8 @@ if [ -f $SRC_GGML/llama-src.patch ]; then
         -e 's/\/ggml\/src\/ggml-backend-impl\.h/\/src\/ggml-backend-impl.h/g' \
         -e 's/\/ggml\/src\/ggml-backend\.c/\/src\/ggml-backend.c/g' \
         -e 's/\/ggml\/src\/ggml-blas\.cpp/\/src\/ggml-blas.cpp/g' \
+        -e 's/\/ggml\/src\/ggml-cann\//\/src\/ggml-cann\//g' \
+        -e 's/\/ggml\/src\/ggml-cann\.cpp/\/src\/ggml-cann.cpp/g' \
         -e 's/\/ggml\/src\/ggml-common\.h/\/src\/ggml-common.h/g' \
         -e 's/\/ggml\/src\/ggml-cuda\//\/src\/ggml-cuda\//g' \
         -e 's/\/ggml\/src\/ggml-cuda\.cu/\/src\/ggml-cuda.cu/g' \
@@ -170,6 +176,7 @@ if [ -f $SRC_GGML/llama-src.patch ]; then
         -e 's/\/ggml\/include\/ggml-alloc\.h/\/include\/ggml-alloc.h/g' \
         -e 's/\/ggml\/include\/ggml-backend\.h/\/include\/ggml-backend.h/g' \
         -e 's/\/ggml\/include\/ggml-blas\.h/\/include\/ggml-blas.h/g' \
+        -e 's/\/ggml\/include\/ggml-cann\.h/\/include\/ggml-cann.h/g' \
         -e 's/\/ggml\/include\/ggml-cuda\.h/\/include\/ggml-cuda.h/g' \
         -e 's/\/ggml\/include\/ggml-kompute\.h/\/include\/ggml-kompute.h/g' \
         -e 's/\/ggml\/include\/ggml-metal\.h/\/include\/ggml-metal.h/g' \
index fd00faa6ca6ff645e59b92f56101e6cdec6e5e80..b148ee2e9f0c4cfded3271ef57b3e7c69ccd2751 100755 (executable)
@@ -11,6 +11,8 @@ cp -rpv ../llama.cpp/ggml/src/ggml-alloc.c        src/ggml-alloc.c
 cp -rpv ../llama.cpp/ggml/src/ggml-backend-impl.h src/ggml-backend-impl.h
 cp -rpv ../llama.cpp/ggml/src/ggml-backend.c      src/ggml-backend.c
 cp -rpv ../llama.cpp/ggml/src/ggml-blas.cpp       src/ggml-blas.cpp
+cp -rpv ../llama.cpp/ggml/src/ggml-cann/*         src/ggml-cann/
+cp -rpv ../llama.cpp/ggml/src/ggml-cann.cpp       src/ggml-cann.cpp
 cp -rpv ../llama.cpp/ggml/src/ggml-common.h       src/ggml-common.h
 cp -rpv ../llama.cpp/ggml/src/ggml-cuda/*         src/ggml-cuda/
 cp -rpv ../llama.cpp/ggml/src/ggml-cuda.cu        src/ggml-cuda.cu
@@ -30,6 +32,7 @@ cp -rpv ../llama.cpp/ggml/include/ggml.h         include/ggml.h
 cp -rpv ../llama.cpp/ggml/include/ggml-alloc.h   include/ggml-alloc.h
 cp -rpv ../llama.cpp/ggml/include/ggml-backend.h include/ggml-backend.h
 cp -rpv ../llama.cpp/ggml/include/ggml-blas.h    include/ggml-blas.h
+cp -rpv ../llama.cpp/ggml/include/ggml-cann.h    include/ggml-cann.h
 cp -rpv ../llama.cpp/ggml/include/ggml-cuda.h    include/ggml-cuda.h
 cp -rpv ../llama.cpp/ggml/include/ggml-kompute.h include/ggml-kompute.h
 cp -rpv ../llama.cpp/ggml/include/ggml-metal.h   include/ggml-metal.h
index 3be6aa3026397e7834dcab19140cac524796e462..2f8a7627c28434136f6809f0b62f14f9a55e9c9f 100755 (executable)
@@ -62,6 +62,7 @@ while read c; do
         ggml/src/ggml*.m \
         ggml/src/ggml*.metal \
         ggml/src/ggml*.cu \
+        ggml/src/ggml-cann/* \
         ggml/src/ggml-cuda/* \
         ggml/src/ggml-sycl/* \
         ggml/include/ggml*.h \
@@ -107,6 +108,8 @@ if [ -f $SRC_GGML/whisper-src.patch ]; then
     # ggml/src/ggml-backend-impl.h -> src/ggml-backend-impl.h
     # ggml/src/ggml-backend.c      -> src/ggml-backend.c
     # ggml/src/ggml-blas.cpp       -> src/ggml-blas.cpp
+    # ggml/src/ggml-cann/*         -> src/ggml-cann/*
+    # ggml/src/ggml-cann.cpp       -> src/ggml-cann.cpp
     # ggml/src/ggml-common.h       -> src/ggml-common.h
     # ggml/src/ggml-cuda/*         -> src/ggml-cuda/*
     # ggml/src/ggml-cuda.cu        -> src/ggml-cuda.cu
@@ -125,6 +128,7 @@ if [ -f $SRC_GGML/whisper-src.patch ]; then
     # ggml/include/ggml-alloc.h   -> include/ggml-alloc.h
     # ggml/include/ggml-backend.h -> include/ggml-backend.h
     # ggml/include/ggml-blas.h    -> include/ggml-blas.h
+    # ggml/include/ggml-cann.h    -> include/ggml-cann.h
     # ggml/include/ggml-cuda.h    -> include/ggml-cuda.h
     # ggml/include/ggml-kompute.h -> include/ggml-kompute.h
     # ggml/include/ggml-metal.h   -> include/ggml-metal.h
@@ -151,6 +155,8 @@ if [ -f $SRC_GGML/whisper-src.patch ]; then
         -e 's/\/ggml\/src\/ggml-backend-impl\.h/\/src\/ggml-backend-impl.h/g' \
         -e 's/\/ggml\/src\/ggml-backend\.c/\/src\/ggml-backend.c/g' \
         -e 's/\/ggml\/src\/ggml-blas\.cpp/\/src\/ggml-blas.cpp/g' \
+        -e 's/\/ggml\/src\/ggml-cann\//\/src\/ggml-cann\//g' \
+        -e 's/\/ggml\/src\/ggml-cann\.cpp/\/src\/ggml-cann.cpp/g' \
         -e 's/\/ggml\/src\/ggml-common\.h/\/src\/ggml-common.h/g' \
         -e 's/\/ggml\/src\/ggml-cuda\//\/src\/ggml-cuda\//g' \
         -e 's/\/ggml\/src\/ggml-cuda\.cu/\/src\/ggml-cuda.cu/g' \
@@ -168,6 +174,7 @@ if [ -f $SRC_GGML/whisper-src.patch ]; then
         -e 's/\/ggml\/include\/ggml-alloc\.h/\/include\/ggml-alloc.h/g' \
         -e 's/\/ggml\/include\/ggml-backend\.h/\/include\/ggml-backend.h/g' \
         -e 's/\/ggml\/include\/ggml-blas\.h/\/include\/ggml-blas.h/g' \
+        -e 's/\/ggml\/include\/ggml-cann\.h/\/include\/ggml-cann.h/g' \
         -e 's/\/ggml\/include\/ggml-cuda\.h/\/include\/ggml-cuda.h/g' \
         -e 's/\/ggml\/include\/ggml-kompute\.h/\/include\/ggml-kompute.h/g' \
         -e 's/\/ggml\/include\/ggml-metal\.h/\/include\/ggml-metal.h/g' \
index 27dae2ce5824013b1c9250c6fb874653f206be5c..6a3b61aab92da6213100d9b8a174c933885ad55d 100755 (executable)
@@ -11,6 +11,8 @@ cp -rpv ../whisper.cpp/ggml/src/ggml-alloc.c        src/ggml-alloc.c
 cp -rpv ../whisper.cpp/ggml/src/ggml-backend-impl.h src/ggml-backend-impl.h
 cp -rpv ../whisper.cpp/ggml/src/ggml-backend.c      src/ggml-backend.c
 cp -rpv ../whisper.cpp/ggml/src/ggml-blas.cpp       src/ggml-blas.cpp
+cp -rpv ../whisper.cpp/ggml/src/ggml-cann/*         src/ggml-cann/
+cp -rpv ../whisper.cpp/ggml/src/ggml-cann.cpp       src/ggml-cann.cpp
 cp -rpv ../whisper.cpp/ggml/src/ggml-common.h       src/ggml-common.h
 cp -rpv ../whisper.cpp/ggml/src/ggml-cuda/*         src/ggml-cuda/
 cp -rpv ../whisper.cpp/ggml/src/ggml-cuda.cu        src/ggml-cuda.cu
@@ -30,6 +32,7 @@ cp -rpv ../whisper.cpp/ggml/include/ggml.h         include/ggml.h
 cp -rpv ../whisper.cpp/ggml/include/ggml-alloc.h   include/ggml-alloc.h
 cp -rpv ../whisper.cpp/ggml/include/ggml-backend.h include/ggml-backend.h
 cp -rpv ../whisper.cpp/ggml/include/ggml-blas.h    include/ggml-blas.h
+cp -rpv ../whisper.cpp/ggml/include/ggml-cann.h    include/ggml-cann.h
 cp -rpv ../whisper.cpp/ggml/include/ggml-cuda.h    include/ggml-cuda.h
 cp -rpv ../whisper.cpp/ggml/include/ggml-kompute.h include/ggml-kompute.h
 cp -rpv ../whisper.cpp/ggml/include/ggml-metal.h   include/ggml-metal.h
diff --git a/src/ggml-cann.cpp b/src/ggml-cann.cpp
new file mode 100644 (file)
index 0000000..06930ba
--- /dev/null
@@ -0,0 +1,2020 @@
+/*
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+ * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
+ * IN THE SOFTWARE.
+ */
+
+#include "ggml-cann.h"
+
+#include <acl/acl.h>
+#include <stdarg.h>
+
+#include <cmath>
+#include <cstdio>
+#include <cstring>
+#include <mutex>
+
+#include "ggml-backend-impl.h"
+#include "ggml-cann/aclnn_ops.h"
+#include "ggml-cann/common.h"
+
+#define GGML_COMMON_DECL_C
+
+#include "ggml-common.h"
+
+/**
+ * @brief Default logging callback for GGML.
+ *
+ * This function is the default logging callback that logs messages to stderr.
+ *
+ * @param level The log level.
+ * @param msg The log message.
+ * @param user_data User data passed to the callback.
+ */
+static void ggml_cann_default_log_callback(enum ggml_log_level level,
+                                           const char* msg, void* user_data) {
+    GGML_UNUSED(level);
+    GGML_UNUSED(user_data);
+    fprintf(stderr, "%s", msg);
+}
+
+ggml_log_callback ggml_cann_log_callback = ggml_cann_default_log_callback;
+void* ggml_cann_log_user_data = NULL;
+
+GGML_API void ggml_backend_cann_log_set_callback(ggml_log_callback log_callback,
+                                                 void* user_data) {
+    ggml_cann_log_callback = log_callback;
+    ggml_cann_log_user_data = user_data;
+}
+
+#define GGML_CANN_LOG_INFO(...) ggml_cann_log(GGML_LOG_LEVEL_INFO, __VA_ARGS__)
+#define GGML_CANN_LOG_WARN(...) ggml_cann_log(GGML_LOG_LEVEL_WARN, __VA_ARGS__)
+#define GGML_CANN_LOG_ERROR(...) \
+    ggml_cann_log(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
+
+GGML_ATTRIBUTE_FORMAT(2, 3)
+
+/**
+ * @brief Log a message using the current logging callback.
+ *
+ * This function formats a log message and passes it to the current logging
+ * callback.
+ *
+ * @param level The log level.
+ * @param format The format string for the log message.
+ * @param ... The arguments for the format string.
+ */
+static void ggml_cann_log(enum ggml_log_level level, const char* format, ...) {
+    if (ggml_cann_log_callback != NULL) {
+        va_list args;
+        va_start(args, format);
+        char buffer[128];
+        int len = vsnprintf(buffer, 128, format, args);
+        if (len < 128) {
+            ggml_cann_log_callback(level, buffer, ggml_cann_log_user_data);
+        } else {
+             // vsnprintf adds a null terminator
+            std::vector<char> buffer2(len + 1);
+            va_end(args);
+            va_start(args, format);
+            vsnprintf(&buffer2[0], buffer2.size(), format, args);
+            ggml_cann_log_callback(level, buffer2.data(),
+                                   ggml_cann_log_user_data);
+        }
+        va_end(args);
+    }
+}
+
+/**
+ * @brief Handles CANN errors by printing an error message and aborting.
+ *
+ * @param stmt The statement that caused the error.
+ * @param func The function in which the error occurred.
+ * @param file The file in which the error occurred.
+ * @param line The line number where the error occurred.
+ * @param msg The error message.
+ */
+[[noreturn]] void ggml_cann_error(const char* stmt, const char* func,
+                                  const char* file, int line, const char* msg) {
+    int32_t id = -1;
+    aclrtGetDevice(&id);
+
+    GGML_CANN_LOG_ERROR("CANN error: %s\n", msg);
+    GGML_CANN_LOG_ERROR("  current device: %d, in function %s at %s:%d\n", id, func,
+            file, line);
+    GGML_CANN_LOG_ERROR("  %s\n", stmt);
+    // abort with GGML_ASSERT to get a stack trace
+    GGML_ABORT("CANN error");
+}
+
+/**
+ * @brief Sets the device to be used by CANN.
+ *
+ * @param device The device ID to set.
+ */
+void ggml_cann_set_device(const int32_t device) {
+    // TODO: uncomment these lines after empty context has fixed.
+    // int current_device;
+    // ACL_CHECK(aclrtGetDevice(&current_device));
+
+    // if (device == current_device) {
+    //   return;
+    // }
+    ACL_CHECK(aclrtSetDevice(device));
+}
+
+/**
+ * @brief Retrieves the current device ID.
+ *
+ * @return The current device ID.
+ */
+int32_t ggml_cann_get_device() {
+    int32_t id;
+    ACL_CHECK(aclrtGetDevice(&id));
+    return id;
+}
+
+/**
+ * @brief Initialize the CANN device information.
+ *
+ * This function initializes the CANN device information by obtaining the
+ * device count and setting the memory allocation granularity for each device.
+ *
+ * @return A structure containing the device information.
+ */
+static ggml_cann_device_info ggml_cann_init() {
+    ggml_cann_device_info info = {};
+
+    aclError err = aclrtGetDeviceCount((uint32_t*)&info.device_count);
+
+    if (err != ACL_SUCCESS) {
+        GGML_CANN_LOG_ERROR("%s: failed to initialize CANN: %s\n",
+                __func__, aclGetRecentErrMsg());
+        return info;
+    }
+
+    GGML_ASSERT(info.device_count <= GGML_CANN_MAX_DEVICES);
+
+    for (int id = 0; id < info.device_count; ++id) {
+        aclrtPhysicalMemProp prop = {};
+        prop.handleType = ACL_MEM_HANDLE_TYPE_NONE;
+        prop.allocationType = ACL_MEM_ALLOCATION_TYPE_PINNED;
+        prop.memAttr = ACL_HBM_MEM_HUGE;
+        prop.location.type = ACL_MEM_LOCATION_TYPE_DEVICE;
+        prop.location.id = id;
+        prop.reserve = 0;
+        ACL_CHECK(aclrtMemGetAllocationGranularity(
+            &prop, ACL_RT_MEM_ALLOC_GRANULARITY_RECOMMENDED,
+            &info.devices[id].vmm_granularity));
+    }
+
+    // TODO: add more device info later.
+    return info;
+}
+
+/**
+ * @brief Retrieve the CANN device information.
+ *
+ * This function returns a reference to a structure containing the CANN device
+ * information. The device information is initialized once and reused on
+ * subsequent calls.
+ *
+ * @return A reference to the structure containing the device information.
+ */
+const ggml_cann_device_info& ggml_cann_info() {
+    static ggml_cann_device_info info = ggml_cann_init();
+    return info;
+}
+
+//#define DEBUG_CANN_MALLOC
+/**
+ * @brief A pool of CANN buffers(legacy).
+ *
+ * This class manages a pool of CANN buffers for a specific device.
+ */
+struct ggml_cann_pool_leg : public ggml_cann_pool {
+    /**
+     * @brief The maximum number of buffers in the pool.
+     */
+    static const int MAX_BUFFERS = 256;
+
+    /**
+     * @brief The device ID associated with this buffer pool.
+     */
+    int device;
+
+    /**
+     * @brief Structure representing a CANN buffer.
+     */
+    struct ggml_cann_buffer {
+        void* ptr = nullptr;  ///< Pointer to the buffer memory.
+        size_t size = 0;      ///< Size of the buffer.
+    };
+
+    /**
+     * @brief Array of CANN buffers in the pool.
+     */
+    ggml_cann_buffer buffer_pool[MAX_BUFFERS] = {};
+
+    /**
+     * @brief Total size of all buffers in the pool.
+     */
+    size_t pool_size = 0;
+
+    /**
+     * @brief Constructor to initialize the buffer pool for a specific device.
+     *
+     * @param device The device ID to associate with this buffer pool.
+     */
+    explicit ggml_cann_pool_leg(int device) : device(device) {}
+
+    /**
+     * @brief Destructor to free all buffers in the pool.
+     */
+    ~ggml_cann_pool_leg() {
+        ggml_cann_set_device(device);
+        for (int i = 0; i < MAX_BUFFERS; ++i) {
+            ggml_cann_buffer& b = buffer_pool[i];
+            if (b.ptr != nullptr) {
+                ACL_CHECK(aclrtFree(b.ptr));
+                pool_size -= b.size;
+            }
+        }
+        GGML_ASSERT(pool_size == 0);
+    }
+
+    /**
+     * @brief Allocate a buffer of the given size.
+     *
+     * @param size The size of the buffer to allocate.
+     * @param actual_size A pointer to a variable to receive the actual size of
+     * the allocated buffer.
+     * @return A pointer to the allocated buffer.
+     */
+    void* alloc(size_t size, size_t* actual_size) override {
+#ifdef DEBUG_CANN_MALLOC
+        int nnz = 0;
+        size_t max_size = 0;
+#endif
+        size_t best_diff = 1ull << 36;
+        int ibest = -1;
+        for (int i = 0; i < MAX_BUFFERS; ++i) {
+            ggml_cann_buffer& b = buffer_pool[i];
+            if (b.ptr != nullptr) {
+#ifdef DEBUG_CANN_MALLOC
+                ++nnz;
+                if (b.size > max_size) max_size = b.size;
+#endif
+                if (b.size >= size) {
+                    size_t diff = b.size - size;
+                    if (diff < best_diff) {
+                        best_diff = diff;
+                        ibest = i;
+                        if (!best_diff) {
+                            void* ptr = b.ptr;
+                            *actual_size = b.size;
+                            b.ptr = nullptr;
+                            b.size = 0;
+                            return ptr;
+                        }
+                    }
+                }
+            }
+        }
+        if (ibest >= 0) {
+            ggml_cann_buffer& b = buffer_pool[ibest];
+            void* ptr = b.ptr;
+            *actual_size = b.size;
+            b.ptr = nullptr;
+            b.size = 0;
+            return ptr;
+        }
+        void* ptr;
+        size_t look_ahead_size = (size_t)(1.05 * size);
+        look_ahead_size = 256 * ((look_ahead_size + 255) / 256);
+        ggml_cann_set_device(device);
+        ACL_CHECK(
+            aclrtMalloc(&ptr, look_ahead_size, ACL_MEM_MALLOC_HUGE_FIRST));
+        *actual_size = look_ahead_size;
+        pool_size += look_ahead_size;
+#ifdef DEBUG_CANN_MALLOC
+        GGML_CANN_LOG_INFO(
+            "%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, "
+            "requested %u MB\n",
+            __func__, device, nnz, (uint32_t)(max_size / 1024 / 1024),
+            (uint32_t)(pool_size / 1024 / 1024),
+            (uint32_t)(size / 1024 / 1024));
+#endif
+        return ptr;
+    }
+
+    /**
+     * @brief Free a buffer and return it to the pool.
+     *
+     * @param ptr Pointer to the buffer to free.
+     * @param size Size of the buffer to free.
+     */
+    void free(void* ptr, size_t size) override {
+        for (int i = 0; i < MAX_BUFFERS; ++i) {
+            ggml_cann_buffer& b = buffer_pool[i];
+            if (b.ptr == nullptr) {
+                b.ptr = ptr;
+                b.size = size;
+                return;
+            }
+        }
+        // memory should always buffered. these memory may still needed by
+        // tasks in stream.
+        // TODO, fix me.
+        GGML_ABORT("Cann buffer pool full, increase MAX_CANN_BUFFERS\n");
+    }
+};
+
+/**
+ * @brief A pool of CANN buffers with virtual memory.
+ *
+ * This class manages a pool of CANN buffers with virtual memory for a specific
+ * device.
+ */
+struct ggml_cann_pool_vmm : public ggml_cann_pool {
+    /**
+     * @brief The maximum size of the virtual memory pool (32 GB).
+     */
+    static const size_t CANN_POOL_VMM_MAX_SIZE = 1ull << 35;  // 32 GB
+
+    /**
+     * @brief The device ID associated with this buffer pool.
+     */
+    int device;
+
+    /**
+     * @brief Pointer to the start of the virtual memory pool.
+     */
+    void* pool_addr = 0;
+
+    /**
+     * @brief Amount of virtual memory used in the pool.
+     */
+    size_t pool_used = 0;
+
+    /**
+     * @brief Total size of the virtual memory pool.
+     */
+    size_t pool_size = 0;
+
+    /**
+     * @brief Allocation granularity for the virtual memory pool.
+     */
+    size_t granularity;
+
+    /**
+     * @brief Handles for the physical memory allocated.
+     */
+    std::vector<aclrtDrvMemHandle> handles;
+
+    /**
+     * @brief Offsets for the mapped memory regions.
+     */
+    std::vector<void*> map_offsets;
+
+    /**
+     * @brief Constructor to initialize the buffer pool with virtual memory for
+     * a specific device.
+     *
+     * @param device The device ID to associate with this buffer pool.
+     */
+    explicit ggml_cann_pool_vmm(int device)
+        : device(device),
+          granularity(ggml_cann_info().devices[device].vmm_granularity) {}
+
+    /**
+     * @brief Destructor to free all buffers in the virtual memory pool.
+     */
+    ~ggml_cann_pool_vmm() {
+        if (pool_addr != 0) {
+            for (auto& offset : map_offsets) {
+                ACL_CHECK(aclrtUnmapMem(offset));
+            }
+            for (auto& handle : handles) {
+                ACL_CHECK(aclrtFreePhysical(handle));
+            }
+            ACL_CHECK(aclrtReleaseMemAddress(pool_addr));
+        }
+    }
+
+    /**
+     * @brief Allocate a buffer of the given size in the virtual memory pool.
+     *
+     * @param size The size of the buffer to allocate.
+     * @param actual_size A pointer to a variable to receive the actual size of
+     * the allocated buffer.
+     * @return A pointer to the allocated buffer.
+     */
+    void* alloc(size_t size, size_t* actual_size) override {
+        // round up the allocation size to the alignment to ensure that all
+        // allocations are aligned for all data types
+        const size_t alignment = 128;
+        size = alignment * ((size + alignment - 1) / alignment);
+
+        size_t avail = pool_size - pool_used;
+
+        if (size > avail) {
+            // round up to the next multiple of the granularity
+            size_t reserve_size = size - avail;
+            reserve_size =
+                granularity * ((reserve_size + granularity - 1) / granularity);
+
+            GGML_ASSERT(pool_size + reserve_size <= CANN_POOL_VMM_MAX_SIZE);
+
+            // allocate more physical memory
+            aclrtPhysicalMemProp prop = {};
+            prop.handleType = ACL_MEM_HANDLE_TYPE_NONE;
+            prop.allocationType = ACL_MEM_ALLOCATION_TYPE_PINNED;
+            prop.memAttr = ACL_HBM_MEM_HUGE;
+            prop.location.type = ACL_MEM_LOCATION_TYPE_DEVICE;
+            prop.location.id = device;
+            prop.reserve = 0;
+            aclrtDrvMemHandle handle;
+            ACL_CHECK(aclrtMallocPhysical(&handle, reserve_size, &prop, 0));
+
+            // reserve virtual address space (if not already reserved)
+            if (pool_addr == 0) {
+                ACL_CHECK(aclrtReserveMemAddress(
+                    &pool_addr, CANN_POOL_VMM_MAX_SIZE, 0, NULL, 1));
+            }
+
+            // map at the end of the pool
+            ACL_CHECK(aclrtMapMem((char*)pool_addr + pool_size, reserve_size, 0,
+                                  handle, 0));
+
+            handles.push_back(handle);
+            map_offsets.push_back((char*)pool_addr + pool_size);
+
+            // add to the pool
+            pool_size += reserve_size;
+
+            // GGML_CANN_LOG_INFO("cann pool[%d]: size increased to %llu MB (
+            // reserved %llu MB)\n",
+            //       device, (unsigned long long) (pool_size/1024/1024),
+            //       (unsigned long long) (reserve_size/1024/1024));
+        }
+
+        GGML_ASSERT(pool_addr != 0);
+
+        void* ptr = (void*)((char*)pool_addr + pool_used);
+        *actual_size = size;
+        pool_used += size;
+
+#ifdef DEBUG_CANN_MALLOC
+        GGML_CANN_LOG_INFO("cann pool[%d]: allocated %llu bytes at %llx\n", device,
+               (unsigned long long)size, (unsigned long long)ptr);
+#endif
+        return ptr;
+    }
+
+    /**
+     * @brief Free a buffer and return it to the virtual memory pool.
+     *
+     * @param ptr Pointer to the buffer to free.
+     * @param size Size of the buffer to free.
+     */
+    void free(void* ptr, size_t size) override {
+#ifdef DEBUG_CANN_MALLOC
+        GGML_CANN_LOG_INFO("cann pool[%d]: freed %llu bytes at %llx\n", device,
+               (unsigned long long)size, (unsigned long long)ptr);
+#endif
+
+        pool_used -= size;
+
+        // all deallocations must be in reverse order of the allocations
+        GGML_ASSERT(ptr == (void*)((char*)pool_addr + pool_used));
+    }
+};
+
+/**
+ * @brief Create a new CANN pool for a specific device.
+ *
+ * Factory method to create a new CANN pool object based on the device type.
+ *
+ * @param device The device ID for which to create the pool.
+ * @return A unique pointer to the created CANN pool.
+ */
+std::unique_ptr<ggml_cann_pool> ggml_backend_cann_context::new_pool_for_device(
+    int device) {
+    // return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_leg(device));
+    return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_vmm(device));
+}
+
+// cann buffer
+/**
+ * @brief Context for managing a CANN buffer associated with a specific device.
+ *
+ * This structure holds information about a CANN buffer, including the device
+ * ID, device pointer, and a name derived from GGML_CANN_NAME and the device ID.
+ */
+struct ggml_backend_cann_buffer_context {
+    int32_t device;  ///< The device ID associated with this buffer context.
+    void* dev_ptr =
+        nullptr;  ///< Pointer to the device memory allocated for the buffer.
+
+    /**
+     * @brief Constructor to initialize the CANN buffer context.
+     *
+     * @param device The device ID associated with this buffer context.
+     * @param dev_ptr Pointer to the device memory allocated for the buffer.
+     */
+    ggml_backend_cann_buffer_context(int32_t device, void* dev_ptr)
+        : device(device),
+          dev_ptr(dev_ptr) {}
+
+    /**
+     * @brief Destructor to free the device memory allocated for the buffer.
+     */
+    ~ggml_backend_cann_buffer_context() { ACL_CHECK(aclrtFree(dev_ptr)); }
+};
+
+/**
+ * @brief Retrieve the name associated with a CANN buffer.
+ *
+ * This function returns the name of a CANN buffer, which is stored in the
+ * context of the buffer.
+ *
+ * @param buffer The CANN buffer whose name is to be retrieved.
+ * @return A pointer to a C-string containing the name of the buffer.
+ */
+
+GGML_CALL static const char* ggml_backend_cann_buffer_get_name(
+    ggml_backend_buffer_t buffer) {
+    return "CANN";
+
+    GGML_UNUSED(buffer);
+}
+
+/**
+ * @brief Check if a buffer is a CANN buffer.
+ *
+ * This function checks if a given buffer is a CANN buffer by comparing its
+ * `get_name` function pointer to `ggml_backend_cann_buffer_get_name`.
+ *
+ * @param buffer The buffer to check.
+ * @return true if the buffer is a CANN buffer, false otherwise.
+ */
+GGML_CALL static bool ggml_backend_buffer_is_cann(
+    ggml_backend_buffer_t buffer) {
+    return buffer->iface.get_name == ggml_backend_cann_buffer_get_name;
+}
+
+/**
+ * @brief Free resources associated with a CANN buffer.
+ *
+ * This function frees the resources associated with a CANN buffer, including
+ * its context.
+ *
+ * @param buffer The CANN buffer to free.
+ */
+GGML_CALL static void ggml_backend_cann_buffer_free_buffer(
+    ggml_backend_buffer_t buffer) {
+    ggml_backend_cann_buffer_context* ctx =
+        (ggml_backend_cann_buffer_context*)buffer->context;
+    delete ctx;
+}
+
+/**
+ * @brief Retrieve the base pointer of a CANN buffer.
+ *
+ * This function returns the base pointer of a CANN buffer, which points to the
+ * device memory allocated for the buffer.
+ *
+ * @param buffer The CANN buffer whose base pointer is to be retrieved.
+ * @return A pointer to the base of the device memory allocated for the buffer.
+ */
+GGML_CALL static void* ggml_backend_cann_buffer_get_base(
+    ggml_backend_buffer_t buffer) {
+    ggml_backend_cann_buffer_context* ctx =
+        (ggml_backend_cann_buffer_context*)buffer->context;
+    return ctx->dev_ptr;
+}
+
+/**
+ * @brief Transform quantized Q4.0 tensor data into a format suitable for CANN
+ * processing.
+ *
+ * This function transforms quantized Q4.0 tensor data into a format suitable
+ * for CANN processing. It extracts quantization values and scales from the
+ * source data and prepares them in a format expected by CANN operations.
+ *
+ * @param tensor Pointer to the tensor information.
+ * @param src Pointer to the source data in Q4.0 format.
+ * @param dst Pointer to the destination buffer where transformed data will be
+ * stored.
+ */
+GGML_CALL static void ggml_backend_cann_transform_q4_0(ggml_tensor* tensor,
+                                                       const void* src,
+                                                       void* dst) {
+
+    int64_t n_elems = ggml_nelements(tensor);
+    int64_t groups = n_elems / QK4_0;
+    size_t quant_bytes = n_elems * sizeof(uint8_t) / 2;
+
+    uint8_t* quant_offset = (uint8_t*)dst;
+    uint16_t* scale_offset = (uint16_t*)((char*)dst + quant_bytes);
+
+    for (int i = 0; i < groups; i++) {
+        const block_q4_0* group =
+            (const block_q4_0*)((const char*)src + i * sizeof(block_q4_0));
+        *scale_offset = group->d;
+        scale_offset++;
+
+        // 0-15
+        for (int j = 0; j < QK4_0 / 2; j += 2) {
+            (*quant_offset) = (group->qs[j] & 0x0F);
+            (*quant_offset) |= ((group->qs[j + 1] << 4));
+            quant_offset++;
+        }
+
+        // 16-31
+        for (int j = 0; j < QK4_0 / 2; j += 2) {
+            (*quant_offset) = (group->qs[j] >> 4);
+            (*quant_offset) |= (group->qs[j + 1] & 0xF0);
+            quant_offset++;
+        }
+    }
+
+    // put (uint4b_t -8) into int4b_t
+    for (quant_offset = (uint8_t*)dst;
+         quant_offset < (uint8_t*)dst + quant_bytes; quant_offset++) {
+        (*quant_offset) ^= 0x88;
+    }
+}
+
+/**
+ * @brief Transform CANN processed data back into quantized Q4.0 format.
+ *
+ * This function transforms CANN processed data back into quantized Q4.0 format.
+ * It reverses the transformation performed by
+ * ggml_backend_cann_transform_q4_0(), converting the data back into its
+ * original quantized form.
+ *
+ * @param tensor Pointer to the tensor information.
+ * @param src Pointer to the source buffer containing transformed data.
+ * @param dst Pointer to the destination buffer where the Q4.0 formatted data
+ * will be stored.
+ */
+GGML_CALL static void ggml_backend_cann_transform_back_q4_0(
+    const ggml_tensor* tensor, void* src, void* dst) {
+
+    int64_t n_elems = ggml_nelements(tensor);
+    int64_t groups = n_elems / QK4_0;
+    size_t quant_bytes = n_elems * sizeof(uint8_t) / 2;
+
+    uint8_t* quant_offset = (uint8_t*)src;
+    uint16_t* scale_offset = (uint16_t*)((char*)src + quant_bytes);
+
+    for (; quant_offset < (uint8_t*)src + quant_bytes; quant_offset++) {
+        (*quant_offset) ^= 0x88;
+    }
+    quant_offset = (uint8_t*)src;
+
+    for (int i = 0; i < groups; i++) {
+        block_q4_0* group = (block_q4_0*)((char*)dst + i * sizeof(block_q4_0));
+        group->d = *scale_offset;
+        scale_offset++;
+
+        // 0-15
+        for (int j = 0; j < QK4_0 / 2; j += 2) {
+            group->qs[j] = ((*quant_offset) & 0x0F);
+            group->qs[j + 1] = ((*quant_offset) >> 4);
+            quant_offset++;
+        }
+
+        // 16-31
+        for (int j = 0; j < QK4_0 / 2; j += 2) {
+            group->qs[j] |= ((*quant_offset) << 4);
+            group->qs[j + 1] |= ((*quant_offset) & 0xF0);
+            quant_offset++;
+        }
+    }
+}
+
+/**
+ * @brief Transform quantized Q8.0 tensor data into a format suitable for CANN
+ * processing.
+ *
+ * This function transforms quantized Q8.0 tensor data into a format suitable
+ * for CANN processing. It extracts quantization values and scales from the
+ * source data and prepares them in a format expected by CANN operations.
+ *
+ * @param tensor Pointer to the tensor information.
+ * @param src Pointer to the source data in Q8.0 format.
+ * @param dst Pointer to the destination buffer where transformed data will be
+ * stored.
+ */
+GGML_CALL static void ggml_backend_cann_transform_q8_0(ggml_tensor* tensor,
+                                                       const void* src,
+                                                       void* dst) {
+    int64_t n_elems = ggml_nelements(tensor);
+    int64_t groups = n_elems / QK8_0;
+    size_t quant_bytes = n_elems * sizeof(uint8_t);
+
+    uint8_t* quant_offset = (uint8_t*)dst;
+    uint16_t* scale_offset = (uint16_t*)((char*)dst + quant_bytes);
+
+    for (int i = 0; i < groups; i++) {
+        const block_q8_0* group =
+            (const block_q8_0*)((const char*)src + i * sizeof(block_q8_0));
+        *scale_offset = group->d;
+        scale_offset++;
+        size_t group_quant_size = QK8_0 * sizeof(uint8_t);
+        memcpy(quant_offset, group->qs, group_quant_size);
+        quant_offset += group_quant_size;
+    }
+}
+
+/**
+ * @brief Transform CANN processed data back into quantized Q8.0 format.
+ *
+ * This function transforms CANN processed data back into quantized Q8.0 format.
+ * It reverses the transformation performed by
+ * ggml_backend_cann_transform_q8_0(), converting the data back into its
+ * original quantized form.
+ *
+ * @param tensor Pointer to the tensor information.
+ * @param src Pointer to the source buffer containing transformed data.
+ * @param dst Pointer to the destination buffer where the Q8.0 formatted data
+ * will be stored.
+ */
+GGML_CALL static void ggml_backend_cann_transform_back_q8_0(
+    const ggml_tensor* tensor, const void* src, void* dst) {
+    int64_t n_elems = ggml_nelements(tensor);
+    int64_t groups = n_elems / QK8_0;
+    size_t quant_bytes = n_elems * sizeof(uint8_t);
+
+    const uint8_t* quant_offset = (const uint8_t*)src;
+    const uint16_t* scale_offset =
+        (const uint16_t*)((const char*)src + quant_bytes);
+
+    for (int i = 0; i < groups; i++) {
+        block_q8_0* group = (block_q8_0*)((char*)dst + i * sizeof(block_q8_0));
+        group->d = *scale_offset;
+        scale_offset++;
+        size_t group_quant_size = QK8_0 * sizeof(uint8_t);
+        memcpy(group->qs, quant_offset, group_quant_size);
+        quant_offset += group_quant_size;
+    }
+}
+
+/**
+ * @brief Transform tensor data based on its type for CANN processing.
+ *
+ * This function transforms tensor data based on its quantization type for CANN
+ * processing. It dispatches the transformation based on the tensor's type to
+ * specialized functions handling Q4.0 and Q8.0 formats.
+ *
+ * @param tensor Pointer to the tensor information.
+ * @param src Pointer to the source data to be transformed.
+ * @param dst Pointer to the destination buffer where transformed data will be
+ * stored.
+ */
+GGML_CALL static void ggml_backend_cann_transform(ggml_tensor* tensor,
+                                                  const void* src, void* dst) {
+    switch (tensor->type) {
+        case GGML_TYPE_Q4_0:
+            ggml_backend_cann_transform_q4_0(tensor, src, dst);
+            break;
+        case GGML_TYPE_Q8_0:
+            ggml_backend_cann_transform_q8_0(tensor, src, dst);
+            break;
+        default:
+            break;
+    }
+}
+
+/**
+ * @brief Transform CANN processed data back into tensor data based on its type.
+ *
+ * This function transforms CANN processed data back into tensor data based on
+ * its quantization type for Q4.0 and Q8.0 formats. It dispatches the
+ * transformation based on the tensor's type to specialized functions.
+ *
+ * @param tensor Pointer to the tensor information.
+ * @param src Pointer to the source data containing CANN processed data.
+ * @param dst Pointer to the destination buffer where transformed tensor data
+ * will be stored.
+ */
+GGML_CALL static void ggml_backend_cann_transform_back(
+    const ggml_tensor* tensor, void* src, void* dst) {
+    switch (tensor->type) {
+        case GGML_TYPE_Q4_0:
+            ggml_backend_cann_transform_back_q4_0(tensor, src, dst);
+            break;
+        case GGML_TYPE_Q8_0:
+            ggml_backend_cann_transform_back_q8_0(tensor, src, dst);
+            break;
+        default:
+            break;
+    }
+}
+
+/**
+ * @brief Check if transformation is needed for a given tensor type.
+ *
+ * This function checks if transformation is needed for a given tensor type
+ * to prepare data for CANN processing.
+ *
+ * @param type The tensor type to check.
+ * @return true if transformation is needed, false otherwise.
+ */
+GGML_CALL static bool need_transform(ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q8_0:
+            return true;
+        default:
+            return false;
+    }
+}
+
+/**
+ * @brief Initialize a tensor using data from a CANN buffer.
+ *
+ * This function initializes a tensor using data from a CANN buffer.
+ * It handles special cases such as views and quantization.
+ *
+ * @param buffer The CANN buffer from which to initialize the tensor.
+ * @param tensor Pointer to the tensor to be initialized.
+ */
+GGML_CALL static void ggml_backend_cann_buffer_init_tensor(
+    ggml_backend_buffer_t buffer, ggml_tensor* tensor) {
+    if (tensor->view_src != NULL && tensor->view_offs == 0) {
+        GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft);
+        return;
+    }
+
+    // TODO: can backend doesn't support quantized yet. Just leave the code
+    // here.
+    if (ggml_is_quantized(tensor->type)) {
+        // Initialize padding to 0 to avoid possible NaN values
+        size_t original_size = ggml_nbytes(tensor);
+        size_t padded_size =
+            ggml_backend_buft_get_alloc_size(buffer->buft, tensor);
+
+        if (padded_size > original_size && tensor->view_src == nullptr) {
+            size_t memset_size = padded_size - original_size;
+            ACL_CHECK(aclrtMemset((char*)tensor->data + original_size,
+                                  memset_size, 0, memset_size));
+        }
+    }
+}
+
+// TODO: need handle tensor which has paddings.
+/**
+ * @brief Set tensor data in a CANN buffer.
+ *
+ * This function sets tensor data in a CANN buffer, handling transformations
+ * if needed based on the tensor's type.
+ *
+ * @param buffer The CANN buffer where the tensor data will be set.
+ * @param tensor Pointer to the tensor whose data will be set.
+ * @param data Pointer to the source data to be copied into the tensor.
+ * @param offset Offset in the source data from where to start copying.
+ * @param size Size of the data to be copied, in bytes.
+ */
+GGML_CALL static void ggml_backend_cann_buffer_set_tensor(
+    ggml_backend_buffer_t buffer, ggml_tensor *tensor, const void *data,
+    size_t offset, size_t size) {
+    ggml_backend_cann_buffer_context *ctx =
+        (ggml_backend_cann_buffer_context *)buffer->context;
+
+    ggml_cann_set_device(ctx->device);
+    // TODO: refer to cann(#6017), it use thread's default stream.
+    // For acl, synchronous functions use this default stream.
+    // Why aclrtSynchronizeDevice?
+
+    if (!need_transform(tensor->type)) {
+        ACL_CHECK(aclrtMemcpy((char *)tensor->data + offset, size, data, size,
+                              ACL_MEMCPY_HOST_TO_DEVICE));
+    } else {
+        void *transform_buffer = malloc(size);
+        ggml_backend_cann_transform(tensor, data, transform_buffer);
+
+#ifndef NDEBUG
+        void *check_buffer = malloc(size);
+        ggml_backend_cann_transform_back(tensor, transform_buffer,
+                                         check_buffer);
+        GGML_ASSERT(memcmp(data, check_buffer, size) == 0);
+        free(check_buffer);
+#endif
+        ACL_CHECK(aclrtMemcpy((char *)tensor->data + offset, size,
+                              transform_buffer, size,
+                              ACL_MEMCPY_HOST_TO_DEVICE));
+        free(transform_buffer);
+    }
+}
+
+/**
+ * @brief Get tensor data from a CANN buffer.
+ *
+ * This function retrieves tensor data from a CANN buffer, handling
+ * transformations if needed based on the tensor's type.
+ *
+ * @param buffer The CANN buffer from which to retrieve tensor data.
+ * @param tensor Pointer to the tensor whose data will be retrieved.
+ * @param data Pointer to the destination buffer where the tensor data will be
+ * copied.
+ * @param offset Offset in the destination buffer where to start copying.
+ * @param size Size of the data to be copied, in bytes.
+ */
+GGML_CALL static void ggml_backend_cann_buffer_get_tensor(
+    ggml_backend_buffer_t buffer, const ggml_tensor* tensor, void* data,
+    size_t offset, size_t size) {
+    ggml_backend_cann_buffer_context* ctx =
+        (ggml_backend_cann_buffer_context*)buffer->context;
+
+    ggml_cann_set_device(ctx->device);
+
+    if (!need_transform(tensor->type)) {
+        ACL_CHECK(aclrtMemcpy(data, size, (char*)tensor->data + offset, size,
+                              ACL_MEMCPY_DEVICE_TO_HOST));
+    } else {
+        void* transform_buffer = malloc(size);
+        ACL_CHECK(aclrtMemcpy(transform_buffer, size,
+                              (char*)tensor->data + offset, size,
+                              ACL_MEMCPY_DEVICE_TO_HOST));
+        ggml_backend_cann_transform_back(tensor, transform_buffer, data);
+        free(transform_buffer);
+    }
+}
+
+/**
+ * @brief Copy tensor data between CANN buffers if possible.
+ *
+ * This function copies tensor data between CANN buffers if the source and
+ * destination buffers are CANN buffers and they meet the necessary conditions
+ * (same device or devices can access each other).
+ *
+ * @param buffer The destination CANN buffer where the tensor data will be
+ * copied.
+ * @param src Pointer to the source tensor whose data will be copied.
+ * @param dst Pointer to the destination tensor where the data will be copied.
+ * @return true if the copy operation succeeded, false otherwise.
+ */
+GGML_CALL static bool ggml_backend_cann_buffer_cpy_tensor(
+    ggml_backend_buffer_t buffer, const ggml_tensor* src, ggml_tensor* dst) {
+    if (ggml_backend_buffer_is_cann(src->buffer)) {
+        ggml_backend_cann_buffer_context* src_ctx =
+            (ggml_backend_cann_buffer_context*)src->buffer->context;
+        ggml_backend_cann_buffer_context* dst_ctx =
+            (ggml_backend_cann_buffer_context*)buffer->context;
+
+        size_t memcpy_size = ggml_nbytes(src);
+        // Same device.
+        if (src_ctx->device == dst_ctx->device) {
+            ACL_CHECK(aclrtMemcpy((char*)dst->data, memcpy_size,
+                                  (const char*)src->data, memcpy_size,
+                                  ACL_MEMCPY_DEVICE_TO_DEVICE));
+            return true;
+        } else {
+            // Different device but can access by peer.
+            int32_t canAccessPeer = 0;
+            ACL_CHECK(aclrtDeviceCanAccessPeer(&canAccessPeer, src_ctx->device,
+                                               dst_ctx->device));
+            if (canAccessPeer) {
+                ggml_cann_set_device(src_ctx->device);
+                ACL_CHECK(aclrtDeviceEnablePeerAccess(dst_ctx->device, 0));
+                ACL_CHECK(aclrtMemcpy((char*)dst->data, memcpy_size,
+                                      (const char*)src->data, memcpy_size,
+                                      ACL_MEMCPY_DEVICE_TO_DEVICE));
+                return true;
+            }
+        }
+    }
+    return false;
+}
+
+/**
+ * @brief Clear a CANN buffer by setting all its memory to a specified value.
+ *
+ * This function clears a CANN buffer by setting all its memory to a specified
+ * value.
+ *
+ * @param buffer The CANN buffer to be cleared.
+ * @param value The value to which each byte in the buffer will be set.
+ */
+GGML_CALL static void ggml_backend_cann_buffer_clear(
+    ggml_backend_buffer_t buffer, uint8_t value) {
+    ggml_backend_cann_buffer_context* ctx =
+        (ggml_backend_cann_buffer_context*)buffer->context;
+
+    ggml_cann_set_device(ctx->device);
+    ACL_CHECK(aclrtMemset(ctx->dev_ptr, buffer->size, value, buffer->size));
+}
+
+/**
+ * @brief Interface for a CANN buffer in the backend.
+ *
+ * This structure defines function pointers to operations that can be performed
+ * on a CANN buffer within the backend.
+ */
+static ggml_backend_buffer_i ggml_backend_cann_buffer_interface = {
+    /* .get_name        = */ ggml_backend_cann_buffer_get_name,
+    /* .free_buffer     = */ ggml_backend_cann_buffer_free_buffer,
+    /* .get_base        = */ ggml_backend_cann_buffer_get_base,
+    /* .init_tensor     = */ ggml_backend_cann_buffer_init_tensor,
+    /* .set_tensor      = */ ggml_backend_cann_buffer_set_tensor,
+    /* .get_tensor      = */ ggml_backend_cann_buffer_get_tensor,
+    /* .cpy_tensor      = */ ggml_backend_cann_buffer_cpy_tensor,
+    /* .clear           = */ ggml_backend_cann_buffer_clear,
+    /* .reset           = */ NULL,
+};
+
+// cann buffer type
+/**
+ * @brief Structure representing context information for a specific backend
+ * buffer type.
+ */
+struct ggml_backend_cann_buffer_type_context {
+    int32_t
+        device; /**< Device identifier associated with the buffer context. */
+    std::string name; /**< Name associated with the buffer context. */
+};
+
+/**
+ * @brief Retrieves the name associated with a CANN buffer type.
+ *
+ * This function returns the descriptive name associated with the specified
+ * CANN buffer type context.
+ *
+ * @param buft Pointer to the buffer type context.
+ * @return Const pointer to the C-style string containing the name.
+ */
+GGML_CALL static const char* ggml_backend_cann_buffer_type_name(
+    ggml_backend_buffer_type_t buft) {
+    return "CANN";
+
+    GGML_UNUSED(buft);
+}
+
+/**
+ * @brief Allocates a new CANN buffer of the specified type and size.
+ *
+ * This function allocates a new CANN buffer on the specified device with the
+ * given size.
+ *
+ * @param buft Pointer to the buffer type context.
+ * @param size Size in bytes of the buffer to allocate.
+ * @return Pointer to the allocated buffer, or nullptr if allocation fails.
+ */
+GGML_CALL static ggml_backend_buffer_t
+ggml_backend_cann_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
+                                           size_t size) {
+    ggml_backend_cann_buffer_type_context* buft_ctx =
+        (ggml_backend_cann_buffer_type_context*)buft->context;
+
+    ggml_cann_set_device(buft_ctx->device);
+
+    size = std::max(size, (size_t)1);
+
+    void* dev_ptr;
+    aclError err = aclrtMalloc(&dev_ptr, size, ACL_MEM_MALLOC_HUGE_FIRST);
+    if (err != ACL_SUCCESS) {
+        GGML_CANN_LOG_ERROR(
+            "%s: allocating %.2f MiB on device %d: aclrtMalloc failed: %s\n",
+            __func__, size / 1024.0 / 1024.0, buft_ctx->device,
+            aclGetRecentErrMsg());
+        return nullptr;
+    }
+
+    ggml_backend_cann_buffer_context* ctx =
+        new ggml_backend_cann_buffer_context(buft_ctx->device, dev_ptr);
+
+    return ggml_backend_buffer_init(buft, ggml_backend_cann_buffer_interface,
+                                    ctx, size);
+}
+
+/**
+ * @brief Retrieves the memory alignment requirement for CANN buffers of this
+ * type.
+ *
+ * This function returns the alignment requirement in bytes for memory allocated
+ * by the CANN buffer type.
+ *
+ * @param buft Pointer to the buffer type context (unused in this
+ * implementation).
+ * @return The alignment requirement in bytes (fixed at 128 bytes for CANN
+ * buffers).
+ */
+GGML_CALL static size_t ggml_backend_cann_buffer_type_get_alignment(
+    ggml_backend_buffer_type_t buft) {
+    return 128;
+
+    GGML_UNUSED(buft);
+}
+
+/**
+ * @brief Calculates the allocation size required for a tensor in a CANN buffer.
+ *
+ * Computes the total allocation size needed for storing the tensor's data in a
+ * CANN buffer, considering any necessary padding or adjustments for quantized
+ * types.
+ *
+ * @param buft Pointer to the buffer type context (unused in this
+ * implementation).
+ * @param tensor Pointer to the tensor for which the allocation size is
+ * calculated.
+ * @return The total allocation size in bytes required for the tensor in the
+ * CANN buffer.
+ */
+GGML_CALL static size_t ggml_backend_cann_buffer_type_get_alloc_size(
+    ggml_backend_buffer_type_t buft, const ggml_tensor* tensor) {
+    size_t size = ggml_nbytes(tensor);
+    int64_t ne0 = tensor->ne[0];
+
+    // last line must bigger than 32, because every single op deal at
+    // least 32 bytes.
+    // TODO: quantized type?
+    // int64_t line_size = ne0 * ggml_element_size(tensor);
+    // int64_t line_size_align_32 = (line_size + 31) & ~31;
+    // size += (line_size_align_32 - line_size);
+
+    // TODO: not support quantized yet.
+    // TODO: consider un-continue tensor.
+    if (ggml_is_quantized(tensor->type)) {
+        if (ne0 % MATRIX_ROW_PADDING != 0) {
+            size += ggml_row_size(
+                tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
+        }
+    }
+
+    return size;
+
+    GGML_UNUSED(buft);
+}
+
+/**
+ * @brief Interface for managing CANN buffer types in the GGML backend.
+ *
+ * Provides function pointers for allocating, querying properties, and managing
+ * memory for CANN buffer types in the GGML backend.
+ */
+static ggml_backend_buffer_type_i ggml_backend_cann_buffer_type_interface = {
+    /* .get_name         = */ ggml_backend_cann_buffer_type_name,
+    /* .alloc_buffer     = */ ggml_backend_cann_buffer_type_alloc_buffer,
+    /* .get_alignment    = */ ggml_backend_cann_buffer_type_get_alignment,
+    /* .get_max_size     = */ NULL,  // defaults to SIZE_MAX
+    /* .get_alloc_size   = */ ggml_backend_cann_buffer_type_get_alloc_size,
+    /* .is_host          = */ NULL,
+};
+
+/**
+ * @brief Retrieves the CANN buffer type for a specified device.
+ *
+ * This function initializes and returns the buffer type interface associated
+ * with the given device. It ensures thread-safe access using a mutex.
+ *
+ * @param device The device index for which to retrieve the buffer type.
+ * @return A pointer to the buffer type interface for the specified device, or
+ * nullptr if the device index is out of range.
+ */
+GGML_CALL ggml_backend_buffer_type_t
+ggml_backend_cann_buffer_type(int32_t device) {
+    static std::mutex mutex;
+    std::lock_guard<std::mutex> lock(mutex);
+
+    if (device >= ggml_backend_cann_get_device_count()) {
+        return nullptr;
+    }
+
+    static ggml_backend_buffer_type
+        ggml_backend_cann_buffer_types[GGML_CANN_MAX_DEVICES];
+
+    static bool ggml_backend_cann_buffer_type_initialized = false;
+
+    if (!ggml_backend_cann_buffer_type_initialized) {
+        for (int32_t i = 0; i < GGML_CANN_MAX_DEVICES; i++) {
+            ggml_backend_cann_buffer_types[i] = {
+                /* .iface    = */ ggml_backend_cann_buffer_type_interface,
+                /* .context  = */
+                 new ggml_backend_cann_buffer_type_context{
+                    i, "CANN" + std::to_string(i)},
+            };
+        }
+        ggml_backend_cann_buffer_type_initialized = true;
+    }
+
+    return &ggml_backend_cann_buffer_types[device];
+}
+
+/**
+ * @brief Computes the forward operation for a given tensor using CANN
+ * operations.
+ *
+ * This function selects the appropriate CANN operation based on the type of
+ * operation specified in the tensor and performs the computation.
+ *
+ * @param ctx The CANN context containing necessary resources and
+ * configurations.
+ * @param dst The destination tensor where the result of the computation will be
+ * stored.
+ * @return true if the computation was successful; false otherwise.
+ */
+static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
+                                      struct ggml_tensor* dst) {
+    switch (dst->op) {
+        case GGML_OP_REPEAT:
+            ggml_cann_repeat(ctx, dst);
+            break;
+        case GGML_OP_GET_ROWS:
+            ggml_cann_get_rows(ctx, dst);
+            break;
+        case GGML_OP_DUP:
+            ggml_cann_dup(ctx, dst);
+            break;
+        case GGML_OP_ADD:
+            ggml_cann_add(ctx, dst);
+            break;
+        case GGML_OP_ACC:
+            ggml_cann_acc(ctx, dst);
+            break;
+        case GGML_OP_MUL:
+            ggml_cann_mul_div<aclnnMulGetWorkspaceSize, aclnnMul>(ctx, dst);
+            break;
+        case GGML_OP_DIV:
+            ggml_cann_mul_div<aclnnDivGetWorkspaceSize, aclnnDiv>(ctx, dst);
+            break;
+        case GGML_OP_UNARY:
+            switch (ggml_get_unary_op(dst)) {
+                case GGML_UNARY_OP_GELU:
+                    ggml_cann_activation<aclnnGeluGetWorkspaceSize, aclnnGelu>(
+                        ctx, dst);
+                    break;
+                case GGML_UNARY_OP_SILU:
+                    ggml_cann_activation<aclnnSiluGetWorkspaceSize, aclnnSilu>(
+                        ctx, dst);
+                    break;
+                // TODO: Use faster gelu??
+                case GGML_UNARY_OP_GELU_QUICK:
+                    ggml_cann_activation<aclnnGeluGetWorkspaceSize, aclnnGelu>(
+                        ctx, dst);
+                    break;
+                case GGML_UNARY_OP_TANH:
+                    ggml_cann_activation<aclnnTanhGetWorkspaceSize, aclnnTanh>(
+                        ctx, dst);
+                    break;
+                case GGML_UNARY_OP_RELU:
+                    ggml_cann_activation<aclnnReluGetWorkspaceSize, aclnnRelu>(
+                        ctx, dst);
+                    break;
+                case GGML_UNARY_OP_HARDSIGMOID:
+                    ggml_cann_activation<aclnnHardsigmoidGetWorkspaceSize,
+                                         aclnnHardsigmoid>(ctx, dst);
+                    break;
+                case GGML_UNARY_OP_HARDSWISH:
+                    ggml_cann_activation<aclnnHardswishGetWorkspaceSize,
+                                         aclnnHardswish>(ctx, dst);
+                    break;
+                default:
+                    return false;
+            }
+            break;
+        case GGML_OP_NORM:
+            ggml_cann_norm(ctx, dst);
+            break;
+        case GGML_OP_GROUP_NORM:
+            ggml_cann_group_norm(ctx, dst);
+            break;
+        case GGML_OP_CONCAT:
+            ggml_cann_concat(ctx, dst);
+            break;
+        case GGML_OP_UPSCALE:
+            ggml_cann_upsample_nearest2d(ctx, dst);
+            break;
+        case GGML_OP_PAD:
+            ggml_cann_pad(ctx, dst);
+            break;
+        case GGML_OP_ARANGE:
+            ggml_cann_arange(ctx, dst);
+            break;
+        case GGML_OP_TIMESTEP_EMBEDDING:
+            ggml_cann_timestep_embedding(ctx, dst);
+            break;
+        case GGML_OP_LEAKY_RELU:
+            ggml_cann_leaky_relu(ctx, dst);
+            break;
+        case GGML_OP_RMS_NORM:
+            ggml_cann_rms_norm(ctx, dst);
+            break;
+        case GGML_OP_MUL_MAT:
+            ggml_cann_mul_mat(ctx, dst);
+            break;
+        case GGML_OP_MUL_MAT_ID:
+            return false;
+        case GGML_OP_SCALE:
+            ggml_cann_scale(ctx, dst);
+            break;
+        case GGML_OP_SQR:
+            ggml_cann_sqr(ctx, dst);
+            break;
+        case GGML_OP_CLAMP:
+            ggml_cann_clamp(ctx, dst);
+            break;
+        case GGML_OP_CPY:
+            ggml_cann_cpy(ctx, dst);
+            break;
+        case GGML_OP_CONT:
+            ggml_cann_dup(ctx, dst);
+            break;
+        case GGML_OP_NONE:
+        case GGML_OP_RESHAPE:
+        case GGML_OP_VIEW:
+        case GGML_OP_PERMUTE:
+        case GGML_OP_TRANSPOSE:
+            break;
+        case GGML_OP_DIAG_MASK_INF:
+            ggml_cann_diag_mask(ctx, dst, -INFINITY);
+            break;
+        case GGML_OP_SOFT_MAX:
+            ggml_cann_softmax(ctx, dst);
+            break;
+        case GGML_OP_ROPE:
+            ggml_cann_rope(ctx, dst);
+            break;
+        case GGML_OP_IM2COL:
+            ggml_cann_im2col(ctx, dst);
+            break;
+        case GGML_OP_POOL_2D:
+            ggml_cann_pool2d(ctx, dst);
+            break;
+        case GGML_OP_SUM_ROWS:
+            ggml_cann_sum_rows(ctx, dst);
+            break;
+        case GGML_OP_ARGSORT:
+            ggml_cann_argsort(ctx, dst);
+            break;
+        default:
+            return false;
+    }
+
+    return true;
+}
+
+// backend
+/**
+ * @brief Retrieves the name associated with the CANN backend.
+ *
+ * This function returns the name assigned to the CANN backend, which is stored
+ * in the context of the provided backend structure.
+ *
+ * @param backend Pointer to the CANN backend structure.
+ * @return A pointer to a constant string representing the backend name.
+ */
+GGML_CALL static const char* ggml_backend_cann_name(ggml_backend_t backend) {
+    ggml_backend_cann_context* cann_ctx =
+        (ggml_backend_cann_context*)backend->context;
+
+    return cann_ctx->name.c_str();
+}
+
+/**
+ * @brief Frees resources associated with the CANN backend.
+ *
+ * This function releases resources associated with the CANN backend context
+ * and resets the device associated with the backend to its initial state.
+ *
+ * @param backend Pointer to the CANN backend structure to be freed.
+ */
+GGML_CALL static void ggml_backend_cann_free(ggml_backend_t backend) {
+    ggml_backend_cann_context* cann_ctx =
+        (ggml_backend_cann_context*)backend->context;
+    ACL_CHECK(aclrtSynchronizeDevice());
+    ACL_CHECK(aclrtResetDevice(cann_ctx->device));
+
+    // finalize when last backend freed.
+    if (cann_ctx->device == ggml_backend_cann_get_device_count() - 1) {
+        ACL_CHECK(aclFinalize());
+    }
+
+    delete cann_ctx;
+    delete backend;
+}
+
+/**
+ * @brief Retrieves the default buffer type associated with the CANN backend.
+ *
+ * This function returns the buffer type specific to the device associated
+ * with the CANN backend. It is used to allocate buffers for computations
+ * performed by the backend.
+ *
+ * @param backend Pointer to the CANN backend structure.
+ * @return Pointer to the buffer type structure for the CANN backend.
+ */
+GGML_CALL static ggml_backend_buffer_type_t
+ggml_backend_cann_get_default_buffer_type(ggml_backend_t backend) {
+    ggml_backend_cann_context* cann_ctx =
+        (ggml_backend_cann_context*)backend->context;
+
+    return ggml_backend_cann_buffer_type(cann_ctx->device);
+}
+
+/**
+ * @brief Sets tensor data asynchronously in the CANN backend.
+ *
+ * This function asynchronously sets tensor data in the CANN backend. Depending
+ * on the tensor type, it may perform data transformations before copying data
+ * to the device.
+ *
+ * @param backend Pointer to the CANN backend structure.
+ * @param tensor Pointer to the tensor structure to set data for.
+ * @param data Pointer to the host data to copy to the tensor.
+ * @param offset Offset in bytes within the host data.
+ * @param size Size of the data to copy in bytes.
+ */
+GGML_CALL static void ggml_backend_cann_set_tensor_async(ggml_backend_t backend,
+                                                         ggml_tensor *tensor,
+                                                         const void *data,
+                                                         size_t offset,
+                                                         size_t size) {
+    ggml_backend_cann_context *cann_ctx =
+        (ggml_backend_cann_context *)backend->context;
+
+    if (!need_transform(tensor->type)) {
+        ACL_CHECK(aclrtMemcpyAsync((char *)tensor->data + offset, size, data,
+                                   size, ACL_MEMCPY_HOST_TO_DEVICE,
+                                   cann_ctx->stream()));
+    } else {
+        void *transform_buffer = malloc(size);
+        ggml_backend_cann_transform(tensor, data, transform_buffer);
+
+#ifndef NDEBUG
+        void *check_buffer = malloc(size);
+        ggml_backend_cann_transform_back(tensor, transform_buffer,
+                                         check_buffer);
+        GGML_ASSERT(memcmp(data, check_buffer, size));
+        free(check_buffer);
+#endif
+        ACL_CHECK(aclrtMemcpyAsync(
+            (char *)tensor->data + offset, size, transform_buffer, size,
+            ACL_MEMCPY_HOST_TO_DEVICE, cann_ctx->stream()));
+        ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
+        free(transform_buffer);
+    }
+}
+
+GGML_CALL static void ggml_backend_cann_get_tensor_async(
+    ggml_backend_t backend, const ggml_tensor *tensor, void *data,
+    size_t offset, size_t size) {
+    ggml_backend_cann_context *cann_ctx =
+        (ggml_backend_cann_context *)backend->context;
+    ggml_backend_buffer_t buf =
+        tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
+
+    GGML_ASSERT(buf->buft == ggml_backend_cann_buffer_type(cann_ctx->device) &&
+                "unsupported buffer type");
+
+    if (!need_transform(tensor->type)) {
+        ACL_CHECK(aclrtMemcpyAsync(data, size, (char *)tensor->data + offset,
+                                   size, ACL_MEMCPY_DEVICE_TO_HOST,
+                                   cann_ctx->stream()));
+    } else {
+        void *transform_buffer = malloc(size);
+        ACL_CHECK(aclrtMemcpyAsync(
+            transform_buffer, size, (char *)tensor->data + offset, size,
+            ACL_MEMCPY_DEVICE_TO_HOST, cann_ctx->stream()));
+        ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
+        ggml_backend_cann_transform_back(tensor, transform_buffer, data);
+        free(transform_buffer);
+    }
+}
+
+/**
+ * @brief Asynchronously copies tensor data between CANN backends.
+ *
+ * This function copies tensor data asynchronously between two CANN backends. It
+ * checks if both tensors reside in CANN buffers and whether the devices support
+ * peer-to-peer access for direct copying. If not, it returns false.
+ *
+ * @param backend_src Pointer to the source CANN backend structure.
+ * @param backend_dst Pointer to the destination CANN backend structure.
+ * @param src Pointer to the source tensor to copy data from.
+ * @param dst Pointer to the destination tensor to copy data to.
+ * @return true if the copy operation succeeds, false otherwise.
+ */
+GGML_CALL static bool ggml_backend_cann_cpy_tensor_async(
+    ggml_backend_t backend_src, ggml_backend_t backend_dst,
+    const ggml_tensor* src, ggml_tensor* dst) {
+    GGML_ASSERT(ggml_backend_is_cann(backend_src) ||
+                ggml_backend_is_cann(backend_dst));
+
+    if (!ggml_backend_buffer_is_cann(src->buffer) ||
+        !ggml_backend_buffer_is_cann(dst->buffer)) {
+        return false;
+    }
+
+    ggml_backend_buffer_t buf_src =
+        src->view_src ? src->view_src->buffer : src->buffer;
+    ggml_backend_buffer_t buf_dst =
+        dst->view_src ? dst->view_src->buffer : dst->buffer;
+
+    ggml_backend_cann_context* cann_ctx_src =
+        (ggml_backend_cann_context*)backend_src->context;
+    ggml_backend_cann_context* cann_ctx_dst =
+        (ggml_backend_cann_context*)backend_dst->context;
+
+    size_t copy_size = ggml_nbytes(dst);
+    if (backend_src != backend_dst) {
+        ggml_backend_cann_buffer_context* buf_ctx_src =
+            (ggml_backend_cann_buffer_context*)buf_src->context;
+        ggml_backend_cann_buffer_context* buf_ctx_dst =
+            (ggml_backend_cann_buffer_context*)buf_dst->context;
+
+        GGML_ASSERT(cann_ctx_src->device == buf_ctx_src->device);
+        GGML_ASSERT(cann_ctx_dst->device == buf_ctx_dst->device);
+
+        int32_t canAccessPeer = 0;
+        ACL_CHECK(aclrtDeviceCanAccessPeer(&canAccessPeer, cann_ctx_src->device,
+                                           cann_ctx_dst->device));
+        if (!canAccessPeer) {
+            return false;
+        }
+
+        // need open both directions for memcpyasync between devices.
+        ggml_cann_set_device(cann_ctx_dst->device);
+        ACL_CHECK(aclrtDeviceEnablePeerAccess(cann_ctx_src->device, 0));
+        ggml_cann_set_device(cann_ctx_src->device);
+        ACL_CHECK(aclrtDeviceEnablePeerAccess(cann_ctx_dst->device, 0));
+
+        ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size,
+                                   ACL_MEMCPY_DEVICE_TO_DEVICE,
+                                   cann_ctx_src->stream()));
+
+        //TODO: workaround for Event didn`t work here.
+        aclrtSynchronizeStream(cann_ctx_src->stream());
+    } else {
+        // src and dst are on the same backend
+        ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size,
+                                   ACL_MEMCPY_DEVICE_TO_DEVICE,
+                                   cann_ctx_dst->stream()));
+    }
+
+    return true;
+}
+
+/**
+ * @brief Synchronizes a CANN backend.
+ *
+ * This function synchronizes the specified CANN backend by waiting for all
+ * operations in its associated stream to complete.
+ *
+ * @param backend Pointer to the CANN backend structure to synchronize.
+ */
+GGML_CALL static void ggml_backend_cann_synchronize(ggml_backend_t backend) {
+    ggml_backend_cann_context* cann_ctx =
+        (ggml_backend_cann_context*)backend->context;
+
+    ggml_cann_set_device(cann_ctx->device);
+
+    ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
+}
+
+/**
+ * @brief Computes a computational graph using a CANN backend.
+ *
+ * This function computes the operations defined in the computational graph
+ * using the specified CANN backend.
+ *
+ * @param backend Pointer to the CANN backend structure to use for computation.
+ * @param cgraph Pointer to the computational graph structure containing nodes
+ *               representing operations to be computed.
+ * @return enum ggml_status Returns GGML_STATUS_SUCCESS if computation
+ *         completes successfully, otherwise an appropriate error status.
+ */
+GGML_CALL static enum ggml_status ggml_backend_cann_graph_compute(
+    ggml_backend_t backend, ggml_cgraph* cgraph) {
+    ggml_backend_cann_context* cann_ctx =
+        (ggml_backend_cann_context*)backend->context;
+
+    ggml_cann_set_device(cann_ctx->device);
+
+    for (int i = 0; i < cgraph->n_nodes; i++) {
+        ggml_tensor* node = cgraph->nodes[i];
+
+        if (ggml_is_empty(node) || node->op == GGML_OP_NONE) {
+            continue;
+        }
+
+        bool ok = ggml_cann_compute_forward(*cann_ctx, node);
+
+        if (!ok) {
+            GGML_CANN_LOG_ERROR("%s: error: op not supported %s (%s)\n", __func__,
+                    node->name, ggml_op_name(node->op));
+        }
+        GGML_ASSERT(ok);
+    }
+
+    return GGML_STATUS_SUCCESS;
+}
+
+/**
+ * @brief Checks if the CANN backend supports a specific operation.
+ *
+ * This function checks whether the specified operation is supported by the
+ * CANN backend.
+ *
+ * @param backend Pointer to the CANN backend structure to check support for
+ *                the operation.
+ * @param op Pointer to the tensor representing the operation to check.
+ * @return bool Returns true if the operation is supported by the backend,
+ *              otherwise false.
+ */
+GGML_CALL static bool ggml_backend_cann_supports_op(ggml_backend_t backend,
+                                                    const ggml_tensor* op) {
+    switch (op->op) {
+        case GGML_OP_UNARY:
+            switch (ggml_get_unary_op(op)) {
+                case GGML_UNARY_OP_GELU:
+                case GGML_UNARY_OP_SILU:
+                case GGML_UNARY_OP_RELU:
+                case GGML_UNARY_OP_HARDSIGMOID:
+                case GGML_UNARY_OP_HARDSWISH:
+                case GGML_UNARY_OP_GELU_QUICK:
+                case GGML_UNARY_OP_TANH:
+                    return true;
+                default:
+                    return false;
+            }
+        case GGML_OP_MUL_MAT: {
+            switch (op->src[0]->type) {
+                case GGML_TYPE_F16:
+                case GGML_TYPE_F32:
+                case GGML_TYPE_Q8_0:
+                    // TODO: fix me
+                    // Current groupsize should not be greater than k-1 in
+                    // aclnnWeightQuantBatchMatmulV2GetWorkspaceSize().
+                case GGML_TYPE_Q4_0:
+                    return true;
+                default:
+                    return false;
+            }
+        }
+        case GGML_OP_MUL_MAT_ID:
+            return false;
+        // embedding
+        case GGML_OP_GET_ROWS: {
+            switch (op->src[0]->type) {
+                case GGML_TYPE_F32:
+                case GGML_TYPE_F16:
+                case GGML_TYPE_Q4_0:
+                case GGML_TYPE_Q8_0:
+                    return true;
+                default:
+                    return false;
+            }
+        } break;
+        case GGML_OP_CPY: {
+            switch (op->type) {
+                case GGML_TYPE_F32:
+                case GGML_TYPE_F16:
+                case GGML_TYPE_Q8_0:
+                case GGML_TYPE_Q4_0:
+                    return true;
+                default:
+                    return false;
+            }
+        }
+        case GGML_OP_DUP:
+        case GGML_OP_REPEAT:
+        case GGML_OP_CONCAT:
+        case GGML_OP_NONE:
+        case GGML_OP_RESHAPE:
+        case GGML_OP_VIEW:
+        case GGML_OP_PERMUTE:
+        case GGML_OP_TRANSPOSE:
+        case GGML_OP_NORM:
+        case GGML_OP_ADD:
+        case GGML_OP_MUL:
+        case GGML_OP_DIV:
+        case GGML_OP_RMS_NORM:
+        case GGML_OP_SCALE:
+        case GGML_OP_SQR:
+        case GGML_OP_CLAMP:
+        case GGML_OP_CONT:
+        case GGML_OP_DIAG_MASK_INF:
+        case GGML_OP_SOFT_MAX:
+        case GGML_OP_ROPE:
+        case GGML_OP_IM2COL:
+        case GGML_OP_POOL_2D:
+        case GGML_OP_SUM_ROWS:
+        case GGML_OP_ARGSORT:
+        case GGML_OP_ACC:
+        case GGML_OP_GROUP_NORM:
+        case GGML_OP_UPSCALE:
+        case GGML_OP_PAD:
+        case GGML_OP_ARANGE:
+        case GGML_OP_TIMESTEP_EMBEDDING:
+        case GGML_OP_LEAKY_RELU:
+            return true;
+        default:
+            return false;
+    }
+
+    GGML_UNUSED(backend);
+}
+
+/**
+ * @brief Checks if the backend buffer type is associated with the CANN backend.
+ *
+ * This function checks whether the provided backend buffer type is associated
+ * with the CANN backend based on the comparison of its name retrieval function
+ * pointer.
+ *
+ * @param buft Pointer to the backend buffer type to check.
+ * @return bool Returns true if the buffer type is associated with the CANN
+ * backend, otherwise false.
+ */
+static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft) {
+    return buft->iface.get_name == ggml_backend_cann_buffer_type_name;
+}
+
+/**
+ * @brief Checks if the CANN backend supports a specific backend buffer type.
+ *
+ * This function determines whether the CANN backend supports the given backend
+ * buffer type by comparing the device context of the backend and buffer type.
+ * It returns true if the devices are same between the backend context and
+ * buffer type context.
+ *
+ * @param backend Pointer to the CANN backend.
+ * @param buft Pointer to the backend buffer type to check.
+ * @return bool Returns true if the CANN backend supports the buffer type,
+ *              otherwise false.
+ */
+GGML_CALL static bool ggml_backend_cann_supports_buft(
+    ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
+    if (ggml_backend_buft_is_cann(buft)) {
+        ggml_backend_cann_context * cann_ctx =
+                        (ggml_backend_cann_context *)backend->context;
+        ggml_backend_cann_buffer_type_context * buft_ctx =
+                        (ggml_backend_cann_buffer_type_context *)buft->context;
+        return buft_ctx->device == cann_ctx->device;
+    }
+    return false;
+}
+
+/**
+ * @brief Determines if a tensor operation should be offloaded to the CANN
+ * backend.
+ *
+ * This function checks if a given tensor operation should be offloaded to the
+ * CANN backend based on the operation type and the size of the tensor. It
+ * returns true if the second dimension (ne[1]) of the tensor is greater than or
+ * equal to the minimum batch size and the operation is not GGML_OP_GET_ROWS.
+ *
+ * @param backend Pointer to the CANN backend.
+ * @param op Pointer to the tensor operation to check.
+ * @return bool Returns true if the operation should be offloaded, otherwise
+ * false.
+ */
+GGML_CALL static bool ggml_backend_cann_offload_op(ggml_backend_t backend,
+                                                   const ggml_tensor* op) {
+    const int min_batch_size = 32;
+    GGML_UNUSED(backend);
+
+    return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS;
+}
+
+/**
+ * @brief Creates a new event for the CANN backend.
+ *
+ * This function initializes a new event for the CANN backend by setting the
+ * device and creating an ACL runtime event. The created event is then wrapped
+ * in a ggml_backend_event structure and returned.
+ *
+ * @param backend Pointer to the CANN backend.
+ * @return ggml_backend_event_t Returns a pointer to the new event structure.
+ */
+static ggml_backend_event_t ggml_backend_cann_event_new(
+    ggml_backend_t backend) {
+    ggml_backend_cann_context* cann_ctx =
+        (ggml_backend_cann_context*)backend->context;
+
+    ggml_cann_set_device(cann_ctx->device);
+
+    aclrtEvent event;
+    ACL_CHECK(aclrtCreateEvent(&event));
+
+    return new ggml_backend_event{
+        /* .backend = */ backend,
+        /* .context = */ event,
+    };
+}
+
+/**
+ * @brief Frees a CANN backend event.
+ *
+ * This function destroys the ACL runtime event associated with the given CANN
+ * backend event and then deletes the event structure itself.
+ *
+ * @param event Pointer to the event structure to be freed.
+ */
+static void ggml_backend_cann_event_free(ggml_backend_event_t event) {
+    ACL_CHECK(aclrtDestroyEvent((aclrtEvent)event->context));
+
+    delete event;
+}
+
+/**
+ * @brief Records an event on the CANN backend stream.
+ *
+ * This function records the given event on the ACL runtime stream associated
+ * with the backend context.
+ *
+ * @param event Pointer to the event structure to be recorded.
+ */
+static void ggml_backend_cann_event_record(ggml_backend_event_t event) {
+    ggml_backend_cann_context* cann_ctx =
+        (ggml_backend_cann_context*)event->backend->context;
+
+    ACL_CHECK(aclrtRecordEvent((aclrtEvent)event->context, cann_ctx->stream()));
+}
+
+/**
+ * @brief Waits for a recorded event to complete on the CANN backend stream.
+ *
+ * This function makes the given backend wait for the event to complete on its
+ * ACL runtime stream.
+ *
+ * @param backend Pointer to the backend structure.
+ * @param event Pointer to the event structure that the backend needs to wait
+ * for.
+ */
+static void ggml_backend_cann_event_wait(ggml_backend_t backend,
+                                         ggml_backend_event_t event) {
+    ggml_backend_cann_context* cann_ctx =
+        (ggml_backend_cann_context*)backend->context;
+
+    if (ggml_backend_is_cann(event->backend)) {
+        ACL_CHECK(aclrtStreamWaitEvent(cann_ctx->stream(),
+                                       (aclrtEvent)event->context));
+    } else {
+        GGML_ABORT("fatal error");
+    }
+}
+
+/**
+ * @brief Synchronizes the given event on the CANN backend.
+ *
+ * This function waits for the specified event to complete on the ACL runtime.
+ *
+ * @param event Pointer to the event structure to be synchronized.
+ */
+static void ggml_backend_cann_event_synchronize(ggml_backend_event_t event) {
+    ACL_CHECK(aclrtSynchronizeEvent((aclrtEvent)event->context));
+}
+
+/**
+ * @brief Structure defining the interface for the CANN backend.
+ *
+ * This structure contains function pointers for various operations
+ * supported by the CANN backend, including name retrieval, memory
+ * management, tensor operations, synchronization, and event handling.
+ */
+static ggml_backend_i ggml_backend_cann_interface = {
+    /* .get_name                = */ ggml_backend_cann_name,
+    /* .free                    = */ ggml_backend_cann_free,
+    /* .get_default_buffer_type = */ ggml_backend_cann_get_default_buffer_type,
+    /* .set_tensor_async        = */ ggml_backend_cann_set_tensor_async,
+    /* .get_tensor_async        = */ ggml_backend_cann_get_tensor_async,
+    /* .cpy_tensor_async        = */ ggml_backend_cann_cpy_tensor_async,
+    /* .synchronize             = */ ggml_backend_cann_synchronize,
+    /* .graph_plan_create       = */ NULL,
+    /* .graph_plan_free         = */ NULL,
+    /* .graph_plan_update       = */ NULL,
+    /* .graph_plan_compute      = */ NULL,
+    /* .graph_compute           = */ ggml_backend_cann_graph_compute,
+    /* .supports_op             = */ ggml_backend_cann_supports_op,
+    /* .supports_buft           = */ ggml_backend_cann_supports_buft,
+    /* .offload_op              = */ ggml_backend_cann_offload_op,
+    /* .event_new               = */ ggml_backend_cann_event_new,
+    /* .event_free              = */ ggml_backend_cann_event_free,
+    /* .event_record            = */ ggml_backend_cann_event_record,
+    /* .event_wait              = */ ggml_backend_cann_event_wait,
+    /* .event_synchronize       = */ ggml_backend_cann_event_synchronize,
+};
+
+/**
+ * @brief Return the hardcoded GUID for the CANN backend.
+ *
+ * This function returns a static GUID which uniquely identifies the CANN
+ * backend.
+ *
+ * @return A pointer to the static GUID.
+ */
+static ggml_guid_t ggml_backend_cann_guid() {
+    static ggml_guid guid = {0xa1, 0x94, 0xaf, 0xac, 0xbd, 0x4f, 0x47, 0x34,
+                             0xbe, 0x1a, 0x9e, 0x71, 0x1f, 0x9e, 0xed, 0x64};
+    return &guid;
+}
+
+GGML_CALL ggml_backend_t ggml_backend_cann_init(int32_t device) {
+    aclInit(nullptr);
+    if (device < 0 || device >= ggml_backend_cann_get_device_count()) {
+        GGML_CANN_LOG_ERROR("%s: error: invalid device %d\n", __func__, device);
+        return nullptr;
+    }
+
+    ggml_backend_cann_context* ctx = new ggml_backend_cann_context(device);
+    if (ctx == nullptr) {
+        GGML_CANN_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
+        return nullptr;
+    }
+
+    ggml_backend_t cann_backend =
+        new ggml_backend{/* .guid      = */ ggml_backend_cann_guid(),
+                         /* .interface = */ ggml_backend_cann_interface,
+                         /* .context   = */ ctx};
+
+    return cann_backend;
+}
+
+GGML_CALL bool ggml_backend_is_cann(ggml_backend_t backend) {
+    return backend != NULL &&
+           ggml_guid_matches(backend->guid, ggml_backend_cann_guid());
+}
+
+GGML_CALL int32_t ggml_backend_cann_get_device_count() {
+    return ggml_cann_info().device_count;
+}
+
+GGML_CALL void ggml_backend_cann_get_device_description(
+    int32_t device, char* description, size_t description_size) {
+    ggml_cann_set_device(device);
+    const char* soc_name = aclrtGetSocName();
+    snprintf(description, description_size, "%s", soc_name);
+}
+
+GGML_CALL void ggml_backend_cann_get_device_memory(int32_t device, size_t* free,
+                                                   size_t* total) {
+    ggml_cann_set_device(device);
+    ACL_CHECK(aclrtGetMemInfo(ACL_HBM_MEM, free, total));
+}
+
+// backend registry
+/**
+ * @brief Initializes a CANN backend based on the provided parameters.
+ *
+ * This function initializes a CANN backend using the device index and then
+ * initializes the backend using `ggml_backend_cann_init`.
+ *
+ * @param params Parameters for initialization (unused in this implementation).
+ * @param user_data User data containing the device index to initialize the
+ * backend.
+ * @return ggml_backend_t The initialized CANN backend.
+ */
+GGML_CALL static ggml_backend_t ggml_backend_reg_cann_init(const char* params,
+                                                           void* user_data) {
+    ggml_backend_t cann_backend =
+        ggml_backend_cann_init((int)(intptr_t)user_data);
+    return cann_backend;
+
+    GGML_UNUSED(params);
+}
+
+extern "C" GGML_CALL int ggml_backend_cann_reg_devices();
+
+/**
+ * @brief Registers CANN (Ascend) devices as backend options.
+ *
+ * This function initializes ACL, retrieves the number of available CANN
+ * devices, and registers each device as a backend option using
+ * `ggml_backend_register`. Each device is given a unique name based on
+ * `GGML_CANN_NAME` followed by its index.
+ *
+ * @return int The number of CANN devices registered.
+ */
+GGML_CALL int ggml_backend_cann_reg_devices() {
+    uint32_t device_count = ggml_backend_cann_get_device_count();
+    // initialization
+    for (uint32_t i = 0; i < device_count; i++) {
+        char name[128];
+        snprintf(name, sizeof(name), "CANN%d", i);
+        ggml_backend_register(name, ggml_backend_reg_cann_init,
+                              ggml_backend_cann_buffer_type(i),
+                              (void*)(intptr_t)i);
+    }
+    return device_count;
+}
diff --git a/src/ggml-cann/Doxyfile b/src/ggml-cann/Doxyfile
new file mode 100644 (file)
index 0000000..2b009e8
--- /dev/null
@@ -0,0 +1,2579 @@
+# Doxyfile 1.8.17
+
+# This file describes the settings to be used by the documentation system
+# doxygen (www.doxygen.org) for a project.
+#
+# All text after a double hash (##) is considered a comment and is placed in
+# front of the TAG it is preceding.
+#
+# All text after a single hash (#) is considered a comment and will be ignored.
+# The format is:
+# TAG = value [value, ...]
+# For lists, items can also be appended using:
+# TAG += value [value, ...]
+# Values that contain spaces should be placed between quotes (\" \").
+
+#---------------------------------------------------------------------------
+# Project related configuration options
+#---------------------------------------------------------------------------
+
+# This tag specifies the encoding used for all characters in the configuration
+# file that follow. The default is UTF-8 which is also the encoding used for all
+# text before the first occurrence of this tag. Doxygen uses libiconv (or the
+# iconv built into libc) for the transcoding. See
+# https://www.gnu.org/software/libiconv/ for the list of possible encodings.
+# The default value is: UTF-8.
+
+DOXYFILE_ENCODING      = UTF-8
+
+# The PROJECT_NAME tag is a single word (or a sequence of words surrounded by
+# double-quotes, unless you are using Doxywizard) that should identify the
+# project for which the documentation is generated. This name is used in the
+# title of most generated pages and in a few other places.
+# The default value is: My Project.
+
+PROJECT_NAME           = "llama.cpp"
+
+# The PROJECT_NUMBER tag can be used to enter a project or revision number. This
+# could be handy for archiving the generated documentation or if some version
+# control system is used.
+
+PROJECT_NUMBER         =
+
+# Using the PROJECT_BRIEF tag one can provide an optional one line description
+# for a project that appears at the top of each page and should give viewer a
+# quick idea about the purpose of the project. Keep the description short.
+
+PROJECT_BRIEF          = "llama inference engine"
+
+# With the PROJECT_LOGO tag one can specify a logo or an icon that is included
+# in the documentation. The maximum height of the logo should not exceed 55
+# pixels and the maximum width should not exceed 200 pixels. Doxygen will copy
+# the logo to the output directory.
+
+PROJECT_LOGO           =
+
+# The OUTPUT_DIRECTORY tag is used to specify the (relative or absolute) path
+# into which the generated documentation will be written. If a relative path is
+# entered, it will be relative to the location where doxygen was started. If
+# left blank the current directory will be used.
+
+OUTPUT_DIRECTORY       = docs
+
+# If the CREATE_SUBDIRS tag is set to YES then doxygen will create 4096 sub-
+# directories (in 2 levels) under the output directory of each output format and
+# will distribute the generated files over these directories. Enabling this
+# option can be useful when feeding doxygen a huge amount of source files, where
+# putting all generated files in the same directory would otherwise causes
+# performance problems for the file system.
+# The default value is: NO.
+
+CREATE_SUBDIRS         = NO
+
+# If the ALLOW_UNICODE_NAMES tag is set to YES, doxygen will allow non-ASCII
+# characters to appear in the names of generated files. If set to NO, non-ASCII
+# characters will be escaped, for example _xE3_x81_x84 will be used for Unicode
+# U+3044.
+# The default value is: NO.
+
+ALLOW_UNICODE_NAMES    = NO
+
+# The OUTPUT_LANGUAGE tag is used to specify the language in which all
+# documentation generated by doxygen is written. Doxygen will use this
+# information to generate all constant output in the proper language.
+# Possible values are: Afrikaans, Arabic, Armenian, Brazilian, Catalan, Chinese,
+# Chinese-Traditional, Croatian, Czech, Danish, Dutch, English (United States),
+# Esperanto, Farsi (Persian), Finnish, French, German, Greek, Hungarian,
+# Indonesian, Italian, Japanese, Japanese-en (Japanese with English messages),
+# Korean, Korean-en (Korean with English messages), Latvian, Lithuanian,
+# Macedonian, Norwegian, Persian (Farsi), Polish, Portuguese, Romanian, Russian,
+# Serbian, Serbian-Cyrillic, Slovak, Slovene, Spanish, Swedish, Turkish,
+# Ukrainian and Vietnamese.
+# The default value is: English.
+
+OUTPUT_LANGUAGE        = English
+
+# The OUTPUT_TEXT_DIRECTION tag is used to specify the direction in which all
+# documentation generated by doxygen is written. Doxygen will use this
+# information to generate all generated output in the proper direction.
+# Possible values are: None, LTR, RTL and Context.
+# The default value is: None.
+
+OUTPUT_TEXT_DIRECTION  = None
+
+# If the BRIEF_MEMBER_DESC tag is set to YES, doxygen will include brief member
+# descriptions after the members that are listed in the file and class
+# documentation (similar to Javadoc). Set to NO to disable this.
+# The default value is: YES.
+
+BRIEF_MEMBER_DESC      = YES
+
+# If the REPEAT_BRIEF tag is set to YES, doxygen will prepend the brief
+# description of a member or function before the detailed description
+#
+# Note: If both HIDE_UNDOC_MEMBERS and BRIEF_MEMBER_DESC are set to NO, the
+# brief descriptions will be completely suppressed.
+# The default value is: YES.
+
+REPEAT_BRIEF           = YES
+
+# This tag implements a quasi-intelligent brief description abbreviator that is
+# used to form the text in various listings. Each string in this list, if found
+# as the leading text of the brief description, will be stripped from the text
+# and the result, after processing the whole list, is used as the annotated
+# text. Otherwise, the brief description is used as-is. If left blank, the
+# following values are used ($name is automatically replaced with the name of
+# the entity):The $name class, The $name widget, The $name file, is, provides,
+# specifies, contains, represents, a, an and the.
+
+ABBREVIATE_BRIEF       = "The $name class" \
+                         "The $name widget" \
+                         "The $name file" \
+                         is \
+                         provides \
+                         specifies \
+                         contains \
+                         represents \
+                         a \
+                         an \
+                         the
+
+# If the ALWAYS_DETAILED_SEC and REPEAT_BRIEF tags are both set to YES then
+# doxygen will generate a detailed section even if there is only a brief
+# description.
+# The default value is: NO.
+
+ALWAYS_DETAILED_SEC    = NO
+
+# If the INLINE_INHERITED_MEMB tag is set to YES, doxygen will show all
+# inherited members of a class in the documentation of that class as if those
+# members were ordinary class members. Constructors, destructors and assignment
+# operators of the base classes will not be shown.
+# The default value is: NO.
+
+INLINE_INHERITED_MEMB  = NO
+
+# If the FULL_PATH_NAMES tag is set to YES, doxygen will prepend the full path
+# before files name in the file list and in the header files. If set to NO the
+# shortest path that makes the file name unique will be used
+# The default value is: YES.
+
+FULL_PATH_NAMES        = YES
+
+# The STRIP_FROM_PATH tag can be used to strip a user-defined part of the path.
+# Stripping is only done if one of the specified strings matches the left-hand
+# part of the path. The tag can be used to show relative paths in the file list.
+# If left blank the directory from which doxygen is run is used as the path to
+# strip.
+#
+# Note that you can specify absolute paths here, but also relative paths, which
+# will be relative from the directory where doxygen is started.
+# This tag requires that the tag FULL_PATH_NAMES is set to YES.
+
+STRIP_FROM_PATH        =
+
+# The STRIP_FROM_INC_PATH tag can be used to strip a user-defined part of the
+# path mentioned in the documentation of a class, which tells the reader which
+# header file to include in order to use a class. If left blank only the name of
+# the header file containing the class definition is used. Otherwise one should
+# specify the list of include paths that are normally passed to the compiler
+# using the -I flag.
+
+STRIP_FROM_INC_PATH    =
+
+# If the SHORT_NAMES tag is set to YES, doxygen will generate much shorter (but
+# less readable) file names. This can be useful is your file systems doesn't
+# support long names like on DOS, Mac, or CD-ROM.
+# The default value is: NO.
+
+SHORT_NAMES            = NO
+
+# If the JAVADOC_AUTOBRIEF tag is set to YES then doxygen will interpret the
+# first line (until the first dot) of a Javadoc-style comment as the brief
+# description. If set to NO, the Javadoc-style will behave just like regular Qt-
+# style comments (thus requiring an explicit @brief command for a brief
+# description.)
+# The default value is: NO.
+
+JAVADOC_AUTOBRIEF      = NO
+
+# If the JAVADOC_BANNER tag is set to YES then doxygen will interpret a line
+# such as
+# /***************
+# as being the beginning of a Javadoc-style comment "banner". If set to NO, the
+# Javadoc-style will behave just like regular comments and it will not be
+# interpreted by doxygen.
+# The default value is: NO.
+
+JAVADOC_BANNER         = NO
+
+# If the QT_AUTOBRIEF tag is set to YES then doxygen will interpret the first
+# line (until the first dot) of a Qt-style comment as the brief description. If
+# set to NO, the Qt-style will behave just like regular Qt-style comments (thus
+# requiring an explicit \brief command for a brief description.)
+# The default value is: NO.
+
+QT_AUTOBRIEF           = NO
+
+# The MULTILINE_CPP_IS_BRIEF tag can be set to YES to make doxygen treat a
+# multi-line C++ special comment block (i.e. a block of //! or /// comments) as
+# a brief description. This used to be the default behavior. The new default is
+# to treat a multi-line C++ comment block as a detailed description. Set this
+# tag to YES if you prefer the old behavior instead.
+#
+# Note that setting this tag to YES also means that rational rose comments are
+# not recognized any more.
+# The default value is: NO.
+
+MULTILINE_CPP_IS_BRIEF = NO
+
+# If the INHERIT_DOCS tag is set to YES then an undocumented member inherits the
+# documentation from any documented member that it re-implements.
+# The default value is: YES.
+
+INHERIT_DOCS           = YES
+
+# If the SEPARATE_MEMBER_PAGES tag is set to YES then doxygen will produce a new
+# page for each member. If set to NO, the documentation of a member will be part
+# of the file/class/namespace that contains it.
+# The default value is: NO.
+
+SEPARATE_MEMBER_PAGES  = NO
+
+# The TAB_SIZE tag can be used to set the number of spaces in a tab. Doxygen
+# uses this value to replace tabs by spaces in code fragments.
+# Minimum value: 1, maximum value: 16, default value: 4.
+
+TAB_SIZE               = 4
+
+# This tag can be used to specify a number of aliases that act as commands in
+# the documentation. An alias has the form:
+# name=value
+# For example adding
+# "sideeffect=@par Side Effects:\n"
+# will allow you to put the command \sideeffect (or @sideeffect) in the
+# documentation, which will result in a user-defined paragraph with heading
+# "Side Effects:". You can put \n's in the value part of an alias to insert
+# newlines (in the resulting output). You can put ^^ in the value part of an
+# alias to insert a newline as if a physical newline was in the original file.
+# When you need a literal { or } or , in the value part of an alias you have to
+# escape them by means of a backslash (\), this can lead to conflicts with the
+# commands \{ and \} for these it is advised to use the version @{ and @} or use
+# a double escape (\\{ and \\})
+
+ALIASES                =
+
+# This tag can be used to specify a number of word-keyword mappings (TCL only).
+# A mapping has the form "name=value". For example adding "class=itcl::class"
+# will allow you to use the command class in the itcl::class meaning.
+
+TCL_SUBST              =
+
+# Set the OPTIMIZE_OUTPUT_FOR_C tag to YES if your project consists of C sources
+# only. Doxygen will then generate output that is more tailored for C. For
+# instance, some of the names that are used will be different. The list of all
+# members will be omitted, etc.
+# The default value is: NO.
+
+OPTIMIZE_OUTPUT_FOR_C  = NO
+
+# Set the OPTIMIZE_OUTPUT_JAVA tag to YES if your project consists of Java or
+# Python sources only. Doxygen will then generate output that is more tailored
+# for that language. For instance, namespaces will be presented as packages,
+# qualified scopes will look different, etc.
+# The default value is: NO.
+
+OPTIMIZE_OUTPUT_JAVA   = NO
+
+# Set the OPTIMIZE_FOR_FORTRAN tag to YES if your project consists of Fortran
+# sources. Doxygen will then generate output that is tailored for Fortran.
+# The default value is: NO.
+
+OPTIMIZE_FOR_FORTRAN   = NO
+
+# Set the OPTIMIZE_OUTPUT_VHDL tag to YES if your project consists of VHDL
+# sources. Doxygen will then generate output that is tailored for VHDL.
+# The default value is: NO.
+
+OPTIMIZE_OUTPUT_VHDL   = NO
+
+# Set the OPTIMIZE_OUTPUT_SLICE tag to YES if your project consists of Slice
+# sources only. Doxygen will then generate output that is more tailored for that
+# language. For instance, namespaces will be presented as modules, types will be
+# separated into more groups, etc.
+# The default value is: NO.
+
+OPTIMIZE_OUTPUT_SLICE  = NO
+
+# Doxygen selects the parser to use depending on the extension of the files it
+# parses. With this tag you can assign which parser to use for a given
+# extension. Doxygen has a built-in mapping, but you can override or extend it
+# using this tag. The format is ext=language, where ext is a file extension, and
+# language is one of the parsers supported by doxygen: IDL, Java, JavaScript,
+# Csharp (C#), C, C++, D, PHP, md (Markdown), Objective-C, Python, Slice,
+# Fortran (fixed format Fortran: FortranFixed, free formatted Fortran:
+# FortranFree, unknown formatted Fortran: Fortran. In the later case the parser
+# tries to guess whether the code is fixed or free formatted code, this is the
+# default for Fortran type files), VHDL, tcl. For instance to make doxygen treat
+# .inc files as Fortran files (default is PHP), and .f files as C (default is
+# Fortran), use: inc=Fortran f=C.
+#
+# Note: For files without extension you can use no_extension as a placeholder.
+#
+# Note that for custom extensions you also need to set FILE_PATTERNS otherwise
+# the files are not read by doxygen.
+
+EXTENSION_MAPPING      =
+
+# If the MARKDOWN_SUPPORT tag is enabled then doxygen pre-processes all comments
+# according to the Markdown format, which allows for more readable
+# documentation. See https://daringfireball.net/projects/markdown/ for details.
+# The output of markdown processing is further processed by doxygen, so you can
+# mix doxygen, HTML, and XML commands with Markdown formatting. Disable only in
+# case of backward compatibilities issues.
+# The default value is: YES.
+
+MARKDOWN_SUPPORT       = YES
+
+# When the TOC_INCLUDE_HEADINGS tag is set to a non-zero value, all headings up
+# to that level are automatically included in the table of contents, even if
+# they do not have an id attribute.
+# Note: This feature currently applies only to Markdown headings.
+# Minimum value: 0, maximum value: 99, default value: 5.
+# This tag requires that the tag MARKDOWN_SUPPORT is set to YES.
+
+TOC_INCLUDE_HEADINGS   = 5
+
+# When enabled doxygen tries to link words that correspond to documented
+# classes, or namespaces to their corresponding documentation. Such a link can
+# be prevented in individual cases by putting a % sign in front of the word or
+# globally by setting AUTOLINK_SUPPORT to NO.
+# The default value is: YES.
+
+AUTOLINK_SUPPORT       = YES
+
+# If you use STL classes (i.e. std::string, std::vector, etc.) but do not want
+# to include (a tag file for) the STL sources as input, then you should set this
+# tag to YES in order to let doxygen match functions declarations and
+# definitions whose arguments contain STL classes (e.g. func(std::string);
+# versus func(std::string) {}). This also make the inheritance and collaboration
+# diagrams that involve STL classes more complete and accurate.
+# The default value is: NO.
+
+BUILTIN_STL_SUPPORT    = NO
+
+# If you use Microsoft's C++/CLI language, you should set this option to YES to
+# enable parsing support.
+# The default value is: NO.
+
+CPP_CLI_SUPPORT        = NO
+
+# Set the SIP_SUPPORT tag to YES if your project consists of sip (see:
+# https://www.riverbankcomputing.com/software/sip/intro) sources only. Doxygen
+# will parse them like normal C++ but will assume all classes use public instead
+# of private inheritance when no explicit protection keyword is present.
+# The default value is: NO.
+
+SIP_SUPPORT            = NO
+
+# For Microsoft's IDL there are propget and propput attributes to indicate
+# getter and setter methods for a property. Setting this option to YES will make
+# doxygen to replace the get and set methods by a property in the documentation.
+# This will only work if the methods are indeed getting or setting a simple
+# type. If this is not the case, or you want to show the methods anyway, you
+# should set this option to NO.
+# The default value is: YES.
+
+IDL_PROPERTY_SUPPORT   = YES
+
+# If member grouping is used in the documentation and the DISTRIBUTE_GROUP_DOC
+# tag is set to YES then doxygen will reuse the documentation of the first
+# member in the group (if any) for the other members of the group. By default
+# all members of a group must be documented explicitly.
+# The default value is: NO.
+
+DISTRIBUTE_GROUP_DOC   = NO
+
+# If one adds a struct or class to a group and this option is enabled, then also
+# any nested class or struct is added to the same group. By default this option
+# is disabled and one has to add nested compounds explicitly via \ingroup.
+# The default value is: NO.
+
+GROUP_NESTED_COMPOUNDS = NO
+
+# Set the SUBGROUPING tag to YES to allow class member groups of the same type
+# (for instance a group of public functions) to be put as a subgroup of that
+# type (e.g. under the Public Functions section). Set it to NO to prevent
+# subgrouping. Alternatively, this can be done per class using the
+# \nosubgrouping command.
+# The default value is: YES.
+
+SUBGROUPING            = YES
+
+# When the INLINE_GROUPED_CLASSES tag is set to YES, classes, structs and unions
+# are shown inside the group in which they are included (e.g. using \ingroup)
+# instead of on a separate page (for HTML and Man pages) or section (for LaTeX
+# and RTF).
+#
+# Note that this feature does not work in combination with
+# SEPARATE_MEMBER_PAGES.
+# The default value is: NO.
+
+INLINE_GROUPED_CLASSES = NO
+
+# When the INLINE_SIMPLE_STRUCTS tag is set to YES, structs, classes, and unions
+# with only public data fields or simple typedef fields will be shown inline in
+# the documentation of the scope in which they are defined (i.e. file,
+# namespace, or group documentation), provided this scope is documented. If set
+# to NO, structs, classes, and unions are shown on a separate page (for HTML and
+# Man pages) or section (for LaTeX and RTF).
+# The default value is: NO.
+
+INLINE_SIMPLE_STRUCTS  = NO
+
+# When TYPEDEF_HIDES_STRUCT tag is enabled, a typedef of a struct, union, or
+# enum is documented as struct, union, or enum with the name of the typedef. So
+# typedef struct TypeS {} TypeT, will appear in the documentation as a struct
+# with name TypeT. When disabled the typedef will appear as a member of a file,
+# namespace, or class. And the struct will be named TypeS. This can typically be
+# useful for C code in case the coding convention dictates that all compound
+# types are typedef'ed and only the typedef is referenced, never the tag name.
+# The default value is: NO.
+
+TYPEDEF_HIDES_STRUCT   = NO
+
+# The size of the symbol lookup cache can be set using LOOKUP_CACHE_SIZE. This
+# cache is used to resolve symbols given their name and scope. Since this can be
+# an expensive process and often the same symbol appears multiple times in the
+# code, doxygen keeps a cache of pre-resolved symbols. If the cache is too small
+# doxygen will become slower. If the cache is too large, memory is wasted. The
+# cache size is given by this formula: 2^(16+LOOKUP_CACHE_SIZE). The valid range
+# is 0..9, the default is 0, corresponding to a cache size of 2^16=65536
+# symbols. At the end of a run doxygen will report the cache usage and suggest
+# the optimal cache size from a speed point of view.
+# Minimum value: 0, maximum value: 9, default value: 0.
+
+LOOKUP_CACHE_SIZE      = 0
+
+#---------------------------------------------------------------------------
+# Build related configuration options
+#---------------------------------------------------------------------------
+
+# If the EXTRACT_ALL tag is set to YES, doxygen will assume all entities in
+# documentation are documented, even if no documentation was available. Private
+# class members and static file members will be hidden unless the
+# EXTRACT_PRIVATE respectively EXTRACT_STATIC tags are set to YES.
+# Note: This will also disable the warnings about undocumented members that are
+# normally produced when WARNINGS is set to YES.
+# The default value is: NO.
+
+EXTRACT_ALL            = YES
+
+# If the EXTRACT_PRIVATE tag is set to YES, all private members of a class will
+# be included in the documentation.
+# The default value is: NO.
+
+EXTRACT_PRIVATE        = YES
+
+# If the EXTRACT_PRIV_VIRTUAL tag is set to YES, documented private virtual
+# methods of a class will be included in the documentation.
+# The default value is: NO.
+
+EXTRACT_PRIV_VIRTUAL   = YES
+
+# If the EXTRACT_PACKAGE tag is set to YES, all members with package or internal
+# scope will be included in the documentation.
+# The default value is: NO.
+
+EXTRACT_PACKAGE        = YES
+
+# If the EXTRACT_STATIC tag is set to YES, all static members of a file will be
+# included in the documentation.
+# The default value is: NO.
+
+EXTRACT_STATIC         = YES
+
+# If the EXTRACT_LOCAL_CLASSES tag is set to YES, classes (and structs) defined
+# locally in source files will be included in the documentation. If set to NO,
+# only classes defined in header files are included. Does not have any effect
+# for Java sources.
+# The default value is: YES.
+
+EXTRACT_LOCAL_CLASSES  = YES
+
+# This flag is only useful for Objective-C code. If set to YES, local methods,
+# which are defined in the implementation section but not in the interface are
+# included in the documentation. If set to NO, only methods in the interface are
+# included.
+# The default value is: NO.
+
+EXTRACT_LOCAL_METHODS  = YES
+
+# If this flag is set to YES, the members of anonymous namespaces will be
+# extracted and appear in the documentation as a namespace called
+# 'anonymous_namespace{file}', where file will be replaced with the base name of
+# the file that contains the anonymous namespace. By default anonymous namespace
+# are hidden.
+# The default value is: NO.
+
+EXTRACT_ANON_NSPACES   = NO
+
+# If the HIDE_UNDOC_MEMBERS tag is set to YES, doxygen will hide all
+# undocumented members inside documented classes or files. If set to NO these
+# members will be included in the various overviews, but no documentation
+# section is generated. This option has no effect if EXTRACT_ALL is enabled.
+# The default value is: NO.
+
+HIDE_UNDOC_MEMBERS     = NO
+
+# If the HIDE_UNDOC_CLASSES tag is set to YES, doxygen will hide all
+# undocumented classes that are normally visible in the class hierarchy. If set
+# to NO, these classes will be included in the various overviews. This option
+# has no effect if EXTRACT_ALL is enabled.
+# The default value is: NO.
+
+HIDE_UNDOC_CLASSES     = NO
+
+# If the HIDE_FRIEND_COMPOUNDS tag is set to YES, doxygen will hide all friend
+# declarations. If set to NO, these declarations will be included in the
+# documentation.
+# The default value is: NO.
+
+HIDE_FRIEND_COMPOUNDS  = NO
+
+# If the HIDE_IN_BODY_DOCS tag is set to YES, doxygen will hide any
+# documentation blocks found inside the body of a function. If set to NO, these
+# blocks will be appended to the function's detailed documentation block.
+# The default value is: NO.
+
+HIDE_IN_BODY_DOCS      = NO
+
+# The INTERNAL_DOCS tag determines if documentation that is typed after a
+# \internal command is included. If the tag is set to NO then the documentation
+# will be excluded. Set it to YES to include the internal documentation.
+# The default value is: NO.
+
+INTERNAL_DOCS          = NO
+
+# If the CASE_SENSE_NAMES tag is set to NO then doxygen will only generate file
+# names in lower-case letters. If set to YES, upper-case letters are also
+# allowed. This is useful if you have classes or files whose names only differ
+# in case and if your file system supports case sensitive file names. Windows
+# (including Cygwin) ands Mac users are advised to set this option to NO.
+# The default value is: system dependent.
+
+CASE_SENSE_NAMES       = YES
+
+# If the HIDE_SCOPE_NAMES tag is set to NO then doxygen will show members with
+# their full class and namespace scopes in the documentation. If set to YES, the
+# scope will be hidden.
+# The default value is: NO.
+
+HIDE_SCOPE_NAMES       = NO
+
+# If the HIDE_COMPOUND_REFERENCE tag is set to NO (default) then doxygen will
+# append additional text to a page's title, such as Class Reference. If set to
+# YES the compound reference will be hidden.
+# The default value is: NO.
+
+HIDE_COMPOUND_REFERENCE= NO
+
+# If the SHOW_INCLUDE_FILES tag is set to YES then doxygen will put a list of
+# the files that are included by a file in the documentation of that file.
+# The default value is: YES.
+
+SHOW_INCLUDE_FILES     = YES
+
+# If the SHOW_GROUPED_MEMB_INC tag is set to YES then Doxygen will add for each
+# grouped member an include statement to the documentation, telling the reader
+# which file to include in order to use the member.
+# The default value is: NO.
+
+SHOW_GROUPED_MEMB_INC  = NO
+
+# If the FORCE_LOCAL_INCLUDES tag is set to YES then doxygen will list include
+# files with double quotes in the documentation rather than with sharp brackets.
+# The default value is: NO.
+
+FORCE_LOCAL_INCLUDES   = NO
+
+# If the INLINE_INFO tag is set to YES then a tag [inline] is inserted in the
+# documentation for inline members.
+# The default value is: YES.
+
+INLINE_INFO            = YES
+
+# If the SORT_MEMBER_DOCS tag is set to YES then doxygen will sort the
+# (detailed) documentation of file and class members alphabetically by member
+# name. If set to NO, the members will appear in declaration order.
+# The default value is: YES.
+
+SORT_MEMBER_DOCS       = YES
+
+# If the SORT_BRIEF_DOCS tag is set to YES then doxygen will sort the brief
+# descriptions of file, namespace and class members alphabetically by member
+# name. If set to NO, the members will appear in declaration order. Note that
+# this will also influence the order of the classes in the class list.
+# The default value is: NO.
+
+SORT_BRIEF_DOCS        = NO
+
+# If the SORT_MEMBERS_CTORS_1ST tag is set to YES then doxygen will sort the
+# (brief and detailed) documentation of class members so that constructors and
+# destructors are listed first. If set to NO the constructors will appear in the
+# respective orders defined by SORT_BRIEF_DOCS and SORT_MEMBER_DOCS.
+# Note: If SORT_BRIEF_DOCS is set to NO this option is ignored for sorting brief
+# member documentation.
+# Note: If SORT_MEMBER_DOCS is set to NO this option is ignored for sorting
+# detailed member documentation.
+# The default value is: NO.
+
+SORT_MEMBERS_CTORS_1ST = NO
+
+# If the SORT_GROUP_NAMES tag is set to YES then doxygen will sort the hierarchy
+# of group names into alphabetical order. If set to NO the group names will
+# appear in their defined order.
+# The default value is: NO.
+
+SORT_GROUP_NAMES       = NO
+
+# If the SORT_BY_SCOPE_NAME tag is set to YES, the class list will be sorted by
+# fully-qualified names, including namespaces. If set to NO, the class list will
+# be sorted only by class name, not including the namespace part.
+# Note: This option is not very useful if HIDE_SCOPE_NAMES is set to YES.
+# Note: This option applies only to the class list, not to the alphabetical
+# list.
+# The default value is: NO.
+
+SORT_BY_SCOPE_NAME     = NO
+
+# If the STRICT_PROTO_MATCHING option is enabled and doxygen fails to do proper
+# type resolution of all parameters of a function it will reject a match between
+# the prototype and the implementation of a member function even if there is
+# only one candidate or it is obvious which candidate to choose by doing a
+# simple string match. By disabling STRICT_PROTO_MATCHING doxygen will still
+# accept a match between prototype and implementation in such cases.
+# The default value is: NO.
+
+STRICT_PROTO_MATCHING  = NO
+
+# The GENERATE_TODOLIST tag can be used to enable (YES) or disable (NO) the todo
+# list. This list is created by putting \todo commands in the documentation.
+# The default value is: YES.
+
+GENERATE_TODOLIST      = YES
+
+# The GENERATE_TESTLIST tag can be used to enable (YES) or disable (NO) the test
+# list. This list is created by putting \test commands in the documentation.
+# The default value is: YES.
+
+GENERATE_TESTLIST      = YES
+
+# The GENERATE_BUGLIST tag can be used to enable (YES) or disable (NO) the bug
+# list. This list is created by putting \bug commands in the documentation.
+# The default value is: YES.
+
+GENERATE_BUGLIST       = YES
+
+# The GENERATE_DEPRECATEDLIST tag can be used to enable (YES) or disable (NO)
+# the deprecated list. This list is created by putting \deprecated commands in
+# the documentation.
+# The default value is: YES.
+
+GENERATE_DEPRECATEDLIST= YES
+
+# The ENABLED_SECTIONS tag can be used to enable conditional documentation
+# sections, marked by \if <section_label> ... \endif and \cond <section_label>
+# ... \endcond blocks.
+
+ENABLED_SECTIONS       =
+
+# The MAX_INITIALIZER_LINES tag determines the maximum number of lines that the
+# initial value of a variable or macro / define can have for it to appear in the
+# documentation. If the initializer consists of more lines than specified here
+# it will be hidden. Use a value of 0 to hide initializers completely. The
+# appearance of the value of individual variables and macros / defines can be
+# controlled using \showinitializer or \hideinitializer command in the
+# documentation regardless of this setting.
+# Minimum value: 0, maximum value: 10000, default value: 30.
+
+MAX_INITIALIZER_LINES  = 30
+
+# Set the SHOW_USED_FILES tag to NO to disable the list of files generated at
+# the bottom of the documentation of classes and structs. If set to YES, the
+# list will mention the files that were used to generate the documentation.
+# The default value is: YES.
+
+SHOW_USED_FILES        = YES
+
+# Set the SHOW_FILES tag to NO to disable the generation of the Files page. This
+# will remove the Files entry from the Quick Index and from the Folder Tree View
+# (if specified).
+# The default value is: YES.
+
+SHOW_FILES             = YES
+
+# Set the SHOW_NAMESPACES tag to NO to disable the generation of the Namespaces
+# page. This will remove the Namespaces entry from the Quick Index and from the
+# Folder Tree View (if specified).
+# The default value is: YES.
+
+SHOW_NAMESPACES        = YES
+
+# The FILE_VERSION_FILTER tag can be used to specify a program or script that
+# doxygen should invoke to get the current version for each file (typically from
+# the version control system). Doxygen will invoke the program by executing (via
+# popen()) the command command input-file, where command is the value of the
+# FILE_VERSION_FILTER tag, and input-file is the name of an input file provided
+# by doxygen. Whatever the program writes to standard output is used as the file
+# version. For an example see the documentation.
+
+FILE_VERSION_FILTER    =
+
+# The LAYOUT_FILE tag can be used to specify a layout file which will be parsed
+# by doxygen. The layout file controls the global structure of the generated
+# output files in an output format independent way. To create the layout file
+# that represents doxygen's defaults, run doxygen with the -l option. You can
+# optionally specify a file name after the option, if omitted DoxygenLayout.xml
+# will be used as the name of the layout file.
+#
+# Note that if you run doxygen from a directory containing a file called
+# DoxygenLayout.xml, doxygen will parse it automatically even if the LAYOUT_FILE
+# tag is left empty.
+
+LAYOUT_FILE            =
+
+# The CITE_BIB_FILES tag can be used to specify one or more bib files containing
+# the reference definitions. This must be a list of .bib files. The .bib
+# extension is automatically appended if omitted. This requires the bibtex tool
+# to be installed. See also https://en.wikipedia.org/wiki/BibTeX for more info.
+# For LaTeX the style of the bibliography can be controlled using
+# LATEX_BIB_STYLE. To use this feature you need bibtex and perl available in the
+# search path. See also \cite for info how to create references.
+
+CITE_BIB_FILES         =
+
+#---------------------------------------------------------------------------
+# Configuration options related to warning and progress messages
+#---------------------------------------------------------------------------
+
+# The QUIET tag can be used to turn on/off the messages that are generated to
+# standard output by doxygen. If QUIET is set to YES this implies that the
+# messages are off.
+# The default value is: NO.
+
+QUIET                  = NO
+
+# The WARNINGS tag can be used to turn on/off the warning messages that are
+# generated to standard error (stderr) by doxygen. If WARNINGS is set to YES
+# this implies that the warnings are on.
+#
+# Tip: Turn warnings on while writing the documentation.
+# The default value is: YES.
+
+WARNINGS               = YES
+
+# If the WARN_IF_UNDOCUMENTED tag is set to YES then doxygen will generate
+# warnings for undocumented members. If EXTRACT_ALL is set to YES then this flag
+# will automatically be disabled.
+# The default value is: YES.
+
+WARN_IF_UNDOCUMENTED   = YES
+
+# If the WARN_IF_DOC_ERROR tag is set to YES, doxygen will generate warnings for
+# potential errors in the documentation, such as not documenting some parameters
+# in a documented function, or documenting parameters that don't exist or using
+# markup commands wrongly.
+# The default value is: YES.
+
+WARN_IF_DOC_ERROR      = YES
+
+# This WARN_NO_PARAMDOC option can be enabled to get warnings for functions that
+# are documented, but have no documentation for their parameters or return
+# value. If set to NO, doxygen will only warn about wrong or incomplete
+# parameter documentation, but not about the absence of documentation. If
+# EXTRACT_ALL is set to YES then this flag will automatically be disabled.
+# The default value is: NO.
+
+WARN_NO_PARAMDOC       = NO
+
+# If the WARN_AS_ERROR tag is set to YES then doxygen will immediately stop when
+# a warning is encountered.
+# The default value is: NO.
+
+WARN_AS_ERROR          = NO
+
+# The WARN_FORMAT tag determines the format of the warning messages that doxygen
+# can produce. The string should contain the $file, $line, and $text tags, which
+# will be replaced by the file and line number from which the warning originated
+# and the warning text. Optionally the format may contain $version, which will
+# be replaced by the version of the file (if it could be obtained via
+# FILE_VERSION_FILTER)
+# The default value is: $file:$line: $text.
+
+WARN_FORMAT            = "$file:$line: $text"
+
+# The WARN_LOGFILE tag can be used to specify a file to which warning and error
+# messages should be written. If left blank the output is written to standard
+# error (stderr).
+
+WARN_LOGFILE           =
+
+#---------------------------------------------------------------------------
+# Configuration options related to the input files
+#---------------------------------------------------------------------------
+
+# The INPUT tag is used to specify the files and/or directories that contain
+# documented source files. You may enter file names like myfile.cpp or
+# directories like /usr/src/myproject. Separate the files or directories with
+# spaces. See also FILE_PATTERNS and EXTENSION_MAPPING
+# Note: If this tag is empty the current directory is searched.
+
+INPUT                  =
+
+# This tag can be used to specify the character encoding of the source files
+# that doxygen parses. Internally doxygen uses the UTF-8 encoding. Doxygen uses
+# libiconv (or the iconv built into libc) for the transcoding. See the libiconv
+# documentation (see: https://www.gnu.org/software/libiconv/) for the list of
+# possible encodings.
+# The default value is: UTF-8.
+
+INPUT_ENCODING         = UTF-8
+
+# If the value of the INPUT tag contains directories, you can use the
+# FILE_PATTERNS tag to specify one or more wildcard patterns (like *.cpp and
+# *.h) to filter out the source-files in the directories.
+#
+# Note that for custom extensions or not directly supported extensions you also
+# need to set EXTENSION_MAPPING for the extension otherwise the files are not
+# read by doxygen.
+#
+# If left blank the following patterns are tested:*.c, *.cc, *.cxx, *.cpp,
+# *.c++, *.java, *.ii, *.ixx, *.ipp, *.i++, *.inl, *.idl, *.ddl, *.odl, *.h,
+# *.hh, *.hxx, *.hpp, *.h++, *.cs, *.d, *.php, *.php4, *.php5, *.phtml, *.inc,
+# *.m, *.markdown, *.md, *.mm, *.dox (to be provided as doxygen C comment),
+# *.doc (to be provided as doxygen C comment), *.txt (to be provided as doxygen
+# C comment), *.py, *.pyw, *.f90, *.f95, *.f03, *.f08, *.f, *.for, *.tcl, *.vhd,
+# *.vhdl, *.ucf, *.qsf and *.ice.
+
+FILE_PATTERNS          = *.c \
+                         *.cc \
+                         *.cxx \
+                         *.cpp \
+                         *.c++ \
+                         *.java \
+                         *.ii \
+                         *.ixx \
+                         *.ipp \
+                         *.i++ \
+                         *.inl \
+                         *.idl \
+                         *.ddl \
+                         *.odl \
+                         *.h \
+                         *.hh \
+                         *.hxx \
+                         *.hpp \
+                         *.h++ \
+                         *.cs \
+                         *.d \
+                         *.php \
+                         *.php4 \
+                         *.php5 \
+                         *.phtml \
+                         *.inc \
+                         *.m \
+                         *.markdown \
+                         *.md \
+                         *.mm \
+                         *.dox \
+                         *.doc \
+                         *.txt \
+                         *.py \
+                         *.pyw \
+                         *.f90 \
+                         *.f95 \
+                         *.f03 \
+                         *.f08 \
+                         *.f \
+                         *.for \
+                         *.tcl \
+                         *.vhd \
+                         *.vhdl \
+                         *.ucf \
+                         *.qsf \
+                         *.ice
+
+# The RECURSIVE tag can be used to specify whether or not subdirectories should
+# be searched for input files as well.
+# The default value is: NO.
+
+RECURSIVE              = YES
+
+# The EXCLUDE tag can be used to specify files and/or directories that should be
+# excluded from the INPUT source files. This way you can easily exclude a
+# subdirectory from a directory tree whose root is specified with the INPUT tag.
+#
+# Note that relative paths are relative to the directory from which doxygen is
+# run.
+
+EXCLUDE                =
+
+# The EXCLUDE_SYMLINKS tag can be used to select whether or not files or
+# directories that are symbolic links (a Unix file system feature) are excluded
+# from the input.
+# The default value is: NO.
+
+EXCLUDE_SYMLINKS       = NO
+
+# If the value of the INPUT tag contains directories, you can use the
+# EXCLUDE_PATTERNS tag to specify one or more wildcard patterns to exclude
+# certain files from those directories.
+#
+# Note that the wildcards are matched against the file with absolute path, so to
+# exclude all test directories for example use the pattern */test/*
+
+EXCLUDE_PATTERNS       =
+
+# The EXCLUDE_SYMBOLS tag can be used to specify one or more symbol names
+# (namespaces, classes, functions, etc.) that should be excluded from the
+# output. The symbol name can be a fully qualified name, a word, or if the
+# wildcard * is used, a substring. Examples: ANamespace, AClass,
+# AClass::ANamespace, ANamespace::*Test
+#
+# Note that the wildcards are matched against the file with absolute path, so to
+# exclude all test directories use the pattern */test/*
+
+EXCLUDE_SYMBOLS        =
+
+# The EXAMPLE_PATH tag can be used to specify one or more files or directories
+# that contain example code fragments that are included (see the \include
+# command).
+
+EXAMPLE_PATH           =
+
+# If the value of the EXAMPLE_PATH tag contains directories, you can use the
+# EXAMPLE_PATTERNS tag to specify one or more wildcard pattern (like *.cpp and
+# *.h) to filter out the source-files in the directories. If left blank all
+# files are included.
+
+EXAMPLE_PATTERNS       = *
+
+# If the EXAMPLE_RECURSIVE tag is set to YES then subdirectories will be
+# searched for input files to be used with the \include or \dontinclude commands
+# irrespective of the value of the RECURSIVE tag.
+# The default value is: NO.
+
+EXAMPLE_RECURSIVE      = NO
+
+# The IMAGE_PATH tag can be used to specify one or more files or directories
+# that contain images that are to be included in the documentation (see the
+# \image command).
+
+IMAGE_PATH             =
+
+# The INPUT_FILTER tag can be used to specify a program that doxygen should
+# invoke to filter for each input file. Doxygen will invoke the filter program
+# by executing (via popen()) the command:
+#
+# <filter> <input-file>
+#
+# where <filter> is the value of the INPUT_FILTER tag, and <input-file> is the
+# name of an input file. Doxygen will then use the output that the filter
+# program writes to standard output. If FILTER_PATTERNS is specified, this tag
+# will be ignored.
+#
+# Note that the filter must not add or remove lines; it is applied before the
+# code is scanned, but not when the output code is generated. If lines are added
+# or removed, the anchors will not be placed correctly.
+#
+# Note that for custom extensions or not directly supported extensions you also
+# need to set EXTENSION_MAPPING for the extension otherwise the files are not
+# properly processed by doxygen.
+
+INPUT_FILTER           =
+
+# The FILTER_PATTERNS tag can be used to specify filters on a per file pattern
+# basis. Doxygen will compare the file name with each pattern and apply the
+# filter if there is a match. The filters are a list of the form: pattern=filter
+# (like *.cpp=my_cpp_filter). See INPUT_FILTER for further information on how
+# filters are used. If the FILTER_PATTERNS tag is empty or if none of the
+# patterns match the file name, INPUT_FILTER is applied.
+#
+# Note that for custom extensions or not directly supported extensions you also
+# need to set EXTENSION_MAPPING for the extension otherwise the files are not
+# properly processed by doxygen.
+
+FILTER_PATTERNS        =
+
+# If the FILTER_SOURCE_FILES tag is set to YES, the input filter (if set using
+# INPUT_FILTER) will also be used to filter the input files that are used for
+# producing the source files to browse (i.e. when SOURCE_BROWSER is set to YES).
+# The default value is: NO.
+
+FILTER_SOURCE_FILES    = NO
+
+# The FILTER_SOURCE_PATTERNS tag can be used to specify source filters per file
+# pattern. A pattern will override the setting for FILTER_PATTERN (if any) and
+# it is also possible to disable source filtering for a specific pattern using
+# *.ext= (so without naming a filter).
+# This tag requires that the tag FILTER_SOURCE_FILES is set to YES.
+
+FILTER_SOURCE_PATTERNS =
+
+# If the USE_MDFILE_AS_MAINPAGE tag refers to the name of a markdown file that
+# is part of the input, its contents will be placed on the main page
+# (index.html). This can be useful if you have a project on for instance GitHub
+# and want to reuse the introduction page also for the doxygen output.
+
+USE_MDFILE_AS_MAINPAGE =
+
+#---------------------------------------------------------------------------
+# Configuration options related to source browsing
+#---------------------------------------------------------------------------
+
+# If the SOURCE_BROWSER tag is set to YES then a list of source files will be
+# generated. Documented entities will be cross-referenced with these sources.
+#
+# Note: To get rid of all source code in the generated output, make sure that
+# also VERBATIM_HEADERS is set to NO.
+# The default value is: NO.
+
+SOURCE_BROWSER         = NO
+
+# Setting the INLINE_SOURCES tag to YES will include the body of functions,
+# classes and enums directly into the documentation.
+# The default value is: NO.
+
+INLINE_SOURCES         = NO
+
+# Setting the STRIP_CODE_COMMENTS tag to YES will instruct doxygen to hide any
+# special comment blocks from generated source code fragments. Normal C, C++ and
+# Fortran comments will always remain visible.
+# The default value is: YES.
+
+STRIP_CODE_COMMENTS    = YES
+
+# If the REFERENCED_BY_RELATION tag is set to YES then for each documented
+# entity all documented functions referencing it will be listed.
+# The default value is: NO.
+
+REFERENCED_BY_RELATION = NO
+
+# If the REFERENCES_RELATION tag is set to YES then for each documented function
+# all documented entities called/used by that function will be listed.
+# The default value is: NO.
+
+REFERENCES_RELATION    = NO
+
+# If the REFERENCES_LINK_SOURCE tag is set to YES and SOURCE_BROWSER tag is set
+# to YES then the hyperlinks from functions in REFERENCES_RELATION and
+# REFERENCED_BY_RELATION lists will link to the source code. Otherwise they will
+# link to the documentation.
+# The default value is: YES.
+
+REFERENCES_LINK_SOURCE = YES
+
+# If SOURCE_TOOLTIPS is enabled (the default) then hovering a hyperlink in the
+# source code will show a tooltip with additional information such as prototype,
+# brief description and links to the definition and documentation. Since this
+# will make the HTML file larger and loading of large files a bit slower, you
+# can opt to disable this feature.
+# The default value is: YES.
+# This tag requires that the tag SOURCE_BROWSER is set to YES.
+
+SOURCE_TOOLTIPS        = YES
+
+# If the USE_HTAGS tag is set to YES then the references to source code will
+# point to the HTML generated by the htags(1) tool instead of doxygen built-in
+# source browser. The htags tool is part of GNU's global source tagging system
+# (see https://www.gnu.org/software/global/global.html). You will need version
+# 4.8.6 or higher.
+#
+# To use it do the following:
+# - Install the latest version of global
+# - Enable SOURCE_BROWSER and USE_HTAGS in the configuration file
+# - Make sure the INPUT points to the root of the source tree
+# - Run doxygen as normal
+#
+# Doxygen will invoke htags (and that will in turn invoke gtags), so these
+# tools must be available from the command line (i.e. in the search path).
+#
+# The result: instead of the source browser generated by doxygen, the links to
+# source code will now point to the output of htags.
+# The default value is: NO.
+# This tag requires that the tag SOURCE_BROWSER is set to YES.
+
+USE_HTAGS              = NO
+
+# If the VERBATIM_HEADERS tag is set the YES then doxygen will generate a
+# verbatim copy of the header file for each class for which an include is
+# specified. Set to NO to disable this.
+# See also: Section \class.
+# The default value is: YES.
+
+VERBATIM_HEADERS       = YES
+
+# If the CLANG_ASSISTED_PARSING tag is set to YES then doxygen will use the
+# clang parser (see: http://clang.llvm.org/) for more accurate parsing at the
+# cost of reduced performance. This can be particularly helpful with template
+# rich C++ code for which doxygen's built-in parser lacks the necessary type
+# information.
+# Note: The availability of this option depends on whether or not doxygen was
+# generated with the -Duse_libclang=ON option for CMake.
+# The default value is: NO.
+
+CLANG_ASSISTED_PARSING = NO
+
+# If clang assisted parsing is enabled you can provide the compiler with command
+# line options that you would normally use when invoking the compiler. Note that
+# the include paths will already be set by doxygen for the files and directories
+# specified with INPUT and INCLUDE_PATH.
+# This tag requires that the tag CLANG_ASSISTED_PARSING is set to YES.
+
+CLANG_OPTIONS          =
+
+# If clang assisted parsing is enabled you can provide the clang parser with the
+# path to the compilation database (see:
+# http://clang.llvm.org/docs/HowToSetupToolingForLLVM.html) used when the files
+# were built. This is equivalent to specifying the "-p" option to a clang tool,
+# such as clang-check. These options will then be passed to the parser.
+# Note: The availability of this option depends on whether or not doxygen was
+# generated with the -Duse_libclang=ON option for CMake.
+
+CLANG_DATABASE_PATH    =
+
+#---------------------------------------------------------------------------
+# Configuration options related to the alphabetical class index
+#---------------------------------------------------------------------------
+
+# If the ALPHABETICAL_INDEX tag is set to YES, an alphabetical index of all
+# compounds will be generated. Enable this if the project contains a lot of
+# classes, structs, unions or interfaces.
+# The default value is: YES.
+
+ALPHABETICAL_INDEX     = YES
+
+# The COLS_IN_ALPHA_INDEX tag can be used to specify the number of columns in
+# which the alphabetical index list will be split.
+# Minimum value: 1, maximum value: 20, default value: 5.
+# This tag requires that the tag ALPHABETICAL_INDEX is set to YES.
+
+COLS_IN_ALPHA_INDEX    = 5
+
+# In case all classes in a project start with a common prefix, all classes will
+# be put under the same header in the alphabetical index. The IGNORE_PREFIX tag
+# can be used to specify a prefix (or a list of prefixes) that should be ignored
+# while generating the index headers.
+# This tag requires that the tag ALPHABETICAL_INDEX is set to YES.
+
+IGNORE_PREFIX          =
+
+#---------------------------------------------------------------------------
+# Configuration options related to the HTML output
+#---------------------------------------------------------------------------
+
+# If the GENERATE_HTML tag is set to YES, doxygen will generate HTML output
+# The default value is: YES.
+
+GENERATE_HTML          = YES
+
+# The HTML_OUTPUT tag is used to specify where the HTML docs will be put. If a
+# relative path is entered the value of OUTPUT_DIRECTORY will be put in front of
+# it.
+# The default directory is: html.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+HTML_OUTPUT            = html
+
+# The HTML_FILE_EXTENSION tag can be used to specify the file extension for each
+# generated HTML page (for example: .htm, .php, .asp).
+# The default value is: .html.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+HTML_FILE_EXTENSION    = .html
+
+# The HTML_HEADER tag can be used to specify a user-defined HTML header file for
+# each generated HTML page. If the tag is left blank doxygen will generate a
+# standard header.
+#
+# To get valid HTML the header file that includes any scripts and style sheets
+# that doxygen needs, which is dependent on the configuration options used (e.g.
+# the setting GENERATE_TREEVIEW). It is highly recommended to start with a
+# default header using
+# doxygen -w html new_header.html new_footer.html new_stylesheet.css
+# YourConfigFile
+# and then modify the file new_header.html. See also section "Doxygen usage"
+# for information on how to generate the default header that doxygen normally
+# uses.
+# Note: The header is subject to change so you typically have to regenerate the
+# default header when upgrading to a newer version of doxygen. For a description
+# of the possible markers and block names see the documentation.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+HTML_HEADER            =
+
+# The HTML_FOOTER tag can be used to specify a user-defined HTML footer for each
+# generated HTML page. If the tag is left blank doxygen will generate a standard
+# footer. See HTML_HEADER for more information on how to generate a default
+# footer and what special commands can be used inside the footer. See also
+# section "Doxygen usage" for information on how to generate the default footer
+# that doxygen normally uses.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+HTML_FOOTER            =
+
+# The HTML_STYLESHEET tag can be used to specify a user-defined cascading style
+# sheet that is used by each HTML page. It can be used to fine-tune the look of
+# the HTML output. If left blank doxygen will generate a default style sheet.
+# See also section "Doxygen usage" for information on how to generate the style
+# sheet that doxygen normally uses.
+# Note: It is recommended to use HTML_EXTRA_STYLESHEET instead of this tag, as
+# it is more robust and this tag (HTML_STYLESHEET) will in the future become
+# obsolete.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+HTML_STYLESHEET        =
+
+# The HTML_EXTRA_STYLESHEET tag can be used to specify additional user-defined
+# cascading style sheets that are included after the standard style sheets
+# created by doxygen. Using this option one can overrule certain style aspects.
+# This is preferred over using HTML_STYLESHEET since it does not replace the
+# standard style sheet and is therefore more robust against future updates.
+# Doxygen will copy the style sheet files to the output directory.
+# Note: The order of the extra style sheet files is of importance (e.g. the last
+# style sheet in the list overrules the setting of the previous ones in the
+# list). For an example see the documentation.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+HTML_EXTRA_STYLESHEET  =
+
+# The HTML_EXTRA_FILES tag can be used to specify one or more extra images or
+# other source files which should be copied to the HTML output directory. Note
+# that these files will be copied to the base HTML output directory. Use the
+# $relpath^ marker in the HTML_HEADER and/or HTML_FOOTER files to load these
+# files. In the HTML_STYLESHEET file, use the file name only. Also note that the
+# files will be copied as-is; there are no commands or markers available.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+HTML_EXTRA_FILES       =
+
+# The HTML_COLORSTYLE_HUE tag controls the color of the HTML output. Doxygen
+# will adjust the colors in the style sheet and background images according to
+# this color. Hue is specified as an angle on a colorwheel, see
+# https://en.wikipedia.org/wiki/Hue for more information. For instance the value
+# 0 represents red, 60 is yellow, 120 is green, 180 is cyan, 240 is blue, 300
+# purple, and 360 is red again.
+# Minimum value: 0, maximum value: 359, default value: 220.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+HTML_COLORSTYLE_HUE    = 220
+
+# The HTML_COLORSTYLE_SAT tag controls the purity (or saturation) of the colors
+# in the HTML output. For a value of 0 the output will use grayscales only. A
+# value of 255 will produce the most vivid colors.
+# Minimum value: 0, maximum value: 255, default value: 100.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+HTML_COLORSTYLE_SAT    = 100
+
+# The HTML_COLORSTYLE_GAMMA tag controls the gamma correction applied to the
+# luminance component of the colors in the HTML output. Values below 100
+# gradually make the output lighter, whereas values above 100 make the output
+# darker. The value divided by 100 is the actual gamma applied, so 80 represents
+# a gamma of 0.8, The value 220 represents a gamma of 2.2, and 100 does not
+# change the gamma.
+# Minimum value: 40, maximum value: 240, default value: 80.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+HTML_COLORSTYLE_GAMMA  = 80
+
+# If the HTML_TIMESTAMP tag is set to YES then the footer of each generated HTML
+# page will contain the date and time when the page was generated. Setting this
+# to YES can help to show when doxygen was last run and thus if the
+# documentation is up to date.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+HTML_TIMESTAMP         = NO
+
+# If the HTML_DYNAMIC_MENUS tag is set to YES then the generated HTML
+# documentation will contain a main index with vertical navigation menus that
+# are dynamically created via JavaScript. If disabled, the navigation index will
+# consists of multiple levels of tabs that are statically embedded in every HTML
+# page. Disable this option to support browsers that do not have JavaScript,
+# like the Qt help browser.
+# The default value is: YES.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+HTML_DYNAMIC_MENUS     = YES
+
+# If the HTML_DYNAMIC_SECTIONS tag is set to YES then the generated HTML
+# documentation will contain sections that can be hidden and shown after the
+# page has loaded.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+HTML_DYNAMIC_SECTIONS  = NO
+
+# With HTML_INDEX_NUM_ENTRIES one can control the preferred number of entries
+# shown in the various tree structured indices initially; the user can expand
+# and collapse entries dynamically later on. Doxygen will expand the tree to
+# such a level that at most the specified number of entries are visible (unless
+# a fully collapsed tree already exceeds this amount). So setting the number of
+# entries 1 will produce a full collapsed tree by default. 0 is a special value
+# representing an infinite number of entries and will result in a full expanded
+# tree by default.
+# Minimum value: 0, maximum value: 9999, default value: 100.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+HTML_INDEX_NUM_ENTRIES = 100
+
+# If the GENERATE_DOCSET tag is set to YES, additional index files will be
+# generated that can be used as input for Apple's Xcode 3 integrated development
+# environment (see: https://developer.apple.com/xcode/), introduced with OSX
+# 10.5 (Leopard). To create a documentation set, doxygen will generate a
+# Makefile in the HTML output directory. Running make will produce the docset in
+# that directory and running make install will install the docset in
+# ~/Library/Developer/Shared/Documentation/DocSets so that Xcode will find it at
+# startup. See https://developer.apple.com/library/archive/featuredarticles/Doxy
+# genXcode/_index.html for more information.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+GENERATE_DOCSET        = NO
+
+# This tag determines the name of the docset feed. A documentation feed provides
+# an umbrella under which multiple documentation sets from a single provider
+# (such as a company or product suite) can be grouped.
+# The default value is: Doxygen generated docs.
+# This tag requires that the tag GENERATE_DOCSET is set to YES.
+
+DOCSET_FEEDNAME        = "Doxygen generated docs"
+
+# This tag specifies a string that should uniquely identify the documentation
+# set bundle. This should be a reverse domain-name style string, e.g.
+# com.mycompany.MyDocSet. Doxygen will append .docset to the name.
+# The default value is: org.doxygen.Project.
+# This tag requires that the tag GENERATE_DOCSET is set to YES.
+
+DOCSET_BUNDLE_ID       = org.doxygen.Project
+
+# The DOCSET_PUBLISHER_ID tag specifies a string that should uniquely identify
+# the documentation publisher. This should be a reverse domain-name style
+# string, e.g. com.mycompany.MyDocSet.documentation.
+# The default value is: org.doxygen.Publisher.
+# This tag requires that the tag GENERATE_DOCSET is set to YES.
+
+DOCSET_PUBLISHER_ID    = org.doxygen.Publisher
+
+# The DOCSET_PUBLISHER_NAME tag identifies the documentation publisher.
+# The default value is: Publisher.
+# This tag requires that the tag GENERATE_DOCSET is set to YES.
+
+DOCSET_PUBLISHER_NAME  = Publisher
+
+# If the GENERATE_HTMLHELP tag is set to YES then doxygen generates three
+# additional HTML index files: index.hhp, index.hhc, and index.hhk. The
+# index.hhp is a project file that can be read by Microsoft's HTML Help Workshop
+# (see: https://www.microsoft.com/en-us/download/details.aspx?id=21138) on
+# Windows.
+#
+# The HTML Help Workshop contains a compiler that can convert all HTML output
+# generated by doxygen into a single compiled HTML file (.chm). Compiled HTML
+# files are now used as the Windows 98 help format, and will replace the old
+# Windows help format (.hlp) on all Windows platforms in the future. Compressed
+# HTML files also contain an index, a table of contents, and you can search for
+# words in the documentation. The HTML workshop also contains a viewer for
+# compressed HTML files.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+GENERATE_HTMLHELP      = NO
+
+# The CHM_FILE tag can be used to specify the file name of the resulting .chm
+# file. You can add a path in front of the file if the result should not be
+# written to the html output directory.
+# This tag requires that the tag GENERATE_HTMLHELP is set to YES.
+
+CHM_FILE               =
+
+# The HHC_LOCATION tag can be used to specify the location (absolute path
+# including file name) of the HTML help compiler (hhc.exe). If non-empty,
+# doxygen will try to run the HTML help compiler on the generated index.hhp.
+# The file has to be specified with full path.
+# This tag requires that the tag GENERATE_HTMLHELP is set to YES.
+
+HHC_LOCATION           =
+
+# The GENERATE_CHI flag controls if a separate .chi index file is generated
+# (YES) or that it should be included in the master .chm file (NO).
+# The default value is: NO.
+# This tag requires that the tag GENERATE_HTMLHELP is set to YES.
+
+GENERATE_CHI           = NO
+
+# The CHM_INDEX_ENCODING is used to encode HtmlHelp index (hhk), content (hhc)
+# and project file content.
+# This tag requires that the tag GENERATE_HTMLHELP is set to YES.
+
+CHM_INDEX_ENCODING     =
+
+# The BINARY_TOC flag controls whether a binary table of contents is generated
+# (YES) or a normal table of contents (NO) in the .chm file. Furthermore it
+# enables the Previous and Next buttons.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_HTMLHELP is set to YES.
+
+BINARY_TOC             = NO
+
+# The TOC_EXPAND flag can be set to YES to add extra items for group members to
+# the table of contents of the HTML help documentation and to the tree view.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_HTMLHELP is set to YES.
+
+TOC_EXPAND             = NO
+
+# If the GENERATE_QHP tag is set to YES and both QHP_NAMESPACE and
+# QHP_VIRTUAL_FOLDER are set, an additional index file will be generated that
+# can be used as input for Qt's qhelpgenerator to generate a Qt Compressed Help
+# (.qch) of the generated HTML documentation.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+GENERATE_QHP           = NO
+
+# If the QHG_LOCATION tag is specified, the QCH_FILE tag can be used to specify
+# the file name of the resulting .qch file. The path specified is relative to
+# the HTML output folder.
+# This tag requires that the tag GENERATE_QHP is set to YES.
+
+QCH_FILE               =
+
+# The QHP_NAMESPACE tag specifies the namespace to use when generating Qt Help
+# Project output. For more information please see Qt Help Project / Namespace
+# (see: https://doc.qt.io/archives/qt-4.8/qthelpproject.html#namespace).
+# The default value is: org.doxygen.Project.
+# This tag requires that the tag GENERATE_QHP is set to YES.
+
+QHP_NAMESPACE          = org.doxygen.Project
+
+# The QHP_VIRTUAL_FOLDER tag specifies the namespace to use when generating Qt
+# Help Project output. For more information please see Qt Help Project / Virtual
+# Folders (see: https://doc.qt.io/archives/qt-4.8/qthelpproject.html#virtual-
+# folders).
+# The default value is: doc.
+# This tag requires that the tag GENERATE_QHP is set to YES.
+
+QHP_VIRTUAL_FOLDER     = doc
+
+# If the QHP_CUST_FILTER_NAME tag is set, it specifies the name of a custom
+# filter to add. For more information please see Qt Help Project / Custom
+# Filters (see: https://doc.qt.io/archives/qt-4.8/qthelpproject.html#custom-
+# filters).
+# This tag requires that the tag GENERATE_QHP is set to YES.
+
+QHP_CUST_FILTER_NAME   =
+
+# The QHP_CUST_FILTER_ATTRS tag specifies the list of the attributes of the
+# custom filter to add. For more information please see Qt Help Project / Custom
+# Filters (see: https://doc.qt.io/archives/qt-4.8/qthelpproject.html#custom-
+# filters).
+# This tag requires that the tag GENERATE_QHP is set to YES.
+
+QHP_CUST_FILTER_ATTRS  =
+
+# The QHP_SECT_FILTER_ATTRS tag specifies the list of the attributes this
+# project's filter section matches. Qt Help Project / Filter Attributes (see:
+# https://doc.qt.io/archives/qt-4.8/qthelpproject.html#filter-attributes).
+# This tag requires that the tag GENERATE_QHP is set to YES.
+
+QHP_SECT_FILTER_ATTRS  =
+
+# The QHG_LOCATION tag can be used to specify the location of Qt's
+# qhelpgenerator. If non-empty doxygen will try to run qhelpgenerator on the
+# generated .qhp file.
+# This tag requires that the tag GENERATE_QHP is set to YES.
+
+QHG_LOCATION           =
+
+# If the GENERATE_ECLIPSEHELP tag is set to YES, additional index files will be
+# generated, together with the HTML files, they form an Eclipse help plugin. To
+# install this plugin and make it available under the help contents menu in
+# Eclipse, the contents of the directory containing the HTML and XML files needs
+# to be copied into the plugins directory of eclipse. The name of the directory
+# within the plugins directory should be the same as the ECLIPSE_DOC_ID value.
+# After copying Eclipse needs to be restarted before the help appears.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+GENERATE_ECLIPSEHELP   = NO
+
+# A unique identifier for the Eclipse help plugin. When installing the plugin
+# the directory name containing the HTML and XML files should also have this
+# name. Each documentation set should have its own identifier.
+# The default value is: org.doxygen.Project.
+# This tag requires that the tag GENERATE_ECLIPSEHELP is set to YES.
+
+ECLIPSE_DOC_ID         = org.doxygen.Project
+
+# If you want full control over the layout of the generated HTML pages it might
+# be necessary to disable the index and replace it with your own. The
+# DISABLE_INDEX tag can be used to turn on/off the condensed index (tabs) at top
+# of each HTML page. A value of NO enables the index and the value YES disables
+# it. Since the tabs in the index contain the same information as the navigation
+# tree, you can set this option to YES if you also set GENERATE_TREEVIEW to YES.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+DISABLE_INDEX          = NO
+
+# The GENERATE_TREEVIEW tag is used to specify whether a tree-like index
+# structure should be generated to display hierarchical information. If the tag
+# value is set to YES, a side panel will be generated containing a tree-like
+# index structure (just like the one that is generated for HTML Help). For this
+# to work a browser that supports JavaScript, DHTML, CSS and frames is required
+# (i.e. any modern browser). Windows users are probably better off using the
+# HTML help feature. Via custom style sheets (see HTML_EXTRA_STYLESHEET) one can
+# further fine-tune the look of the index. As an example, the default style
+# sheet generated by doxygen has an example that shows how to put an image at
+# the root of the tree instead of the PROJECT_NAME. Since the tree basically has
+# the same information as the tab index, you could consider setting
+# DISABLE_INDEX to YES when enabling this option.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+GENERATE_TREEVIEW      = NO
+
+# The ENUM_VALUES_PER_LINE tag can be used to set the number of enum values that
+# doxygen will group on one line in the generated HTML documentation.
+#
+# Note that a value of 0 will completely suppress the enum values from appearing
+# in the overview section.
+# Minimum value: 0, maximum value: 20, default value: 4.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+ENUM_VALUES_PER_LINE   = 4
+
+# If the treeview is enabled (see GENERATE_TREEVIEW) then this tag can be used
+# to set the initial width (in pixels) of the frame in which the tree is shown.
+# Minimum value: 0, maximum value: 1500, default value: 250.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+TREEVIEW_WIDTH         = 250
+
+# If the EXT_LINKS_IN_WINDOW option is set to YES, doxygen will open links to
+# external symbols imported via tag files in a separate window.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+EXT_LINKS_IN_WINDOW    = NO
+
+# Use this tag to change the font size of LaTeX formulas included as images in
+# the HTML documentation. When you change the font size after a successful
+# doxygen run you need to manually remove any form_*.png images from the HTML
+# output directory to force them to be regenerated.
+# Minimum value: 8, maximum value: 50, default value: 10.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+FORMULA_FONTSIZE       = 10
+
+# Use the FORMULA_TRANSPARENT tag to determine whether or not the images
+# generated for formulas are transparent PNGs. Transparent PNGs are not
+# supported properly for IE 6.0, but are supported on all modern browsers.
+#
+# Note that when changing this option you need to delete any form_*.png files in
+# the HTML output directory before the changes have effect.
+# The default value is: YES.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+FORMULA_TRANSPARENT    = YES
+
+# The FORMULA_MACROFILE can contain LaTeX \newcommand and \renewcommand commands
+# to create new LaTeX commands to be used in formulas as building blocks. See
+# the section "Including formulas" for details.
+
+FORMULA_MACROFILE      =
+
+# Enable the USE_MATHJAX option to render LaTeX formulas using MathJax (see
+# https://www.mathjax.org) which uses client side JavaScript for the rendering
+# instead of using pre-rendered bitmaps. Use this if you do not have LaTeX
+# installed or if you want to formulas look prettier in the HTML output. When
+# enabled you may also need to install MathJax separately and configure the path
+# to it using the MATHJAX_RELPATH option.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+USE_MATHJAX            = YES
+
+# When MathJax is enabled you can set the default output format to be used for
+# the MathJax output. See the MathJax site (see:
+# http://docs.mathjax.org/en/latest/output.html) for more details.
+# Possible values are: HTML-CSS (which is slower, but has the best
+# compatibility), NativeMML (i.e. MathML) and SVG.
+# The default value is: HTML-CSS.
+# This tag requires that the tag USE_MATHJAX is set to YES.
+
+MATHJAX_FORMAT         = HTML-CSS
+
+# When MathJax is enabled you need to specify the location relative to the HTML
+# output directory using the MATHJAX_RELPATH option. The destination directory
+# should contain the MathJax.js script. For instance, if the mathjax directory
+# is located at the same level as the HTML output directory, then
+# MATHJAX_RELPATH should be ../mathjax. The default value points to the MathJax
+# Content Delivery Network so you can quickly see the result without installing
+# MathJax. However, it is strongly recommended to install a local copy of
+# MathJax from https://www.mathjax.org before deployment.
+# The default value is: https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/.
+# This tag requires that the tag USE_MATHJAX is set to YES.
+
+MATHJAX_RELPATH        = https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/
+
+# The MATHJAX_EXTENSIONS tag can be used to specify one or more MathJax
+# extension names that should be enabled during MathJax rendering. For example
+# MATHJAX_EXTENSIONS = TeX/AMSmath TeX/AMSsymbols
+# This tag requires that the tag USE_MATHJAX is set to YES.
+
+MATHJAX_EXTENSIONS     =
+
+# The MATHJAX_CODEFILE tag can be used to specify a file with javascript pieces
+# of code that will be used on startup of the MathJax code. See the MathJax site
+# (see: http://docs.mathjax.org/en/latest/output.html) for more details. For an
+# example see the documentation.
+# This tag requires that the tag USE_MATHJAX is set to YES.
+
+MATHJAX_CODEFILE       =
+
+# When the SEARCHENGINE tag is enabled doxygen will generate a search box for
+# the HTML output. The underlying search engine uses javascript and DHTML and
+# should work on any modern browser. Note that when using HTML help
+# (GENERATE_HTMLHELP), Qt help (GENERATE_QHP), or docsets (GENERATE_DOCSET)
+# there is already a search function so this one should typically be disabled.
+# For large projects the javascript based search engine can be slow, then
+# enabling SERVER_BASED_SEARCH may provide a better solution. It is possible to
+# search using the keyboard; to jump to the search box use <access key> + S
+# (what the <access key> is depends on the OS and browser, but it is typically
+# <CTRL>, <ALT>/<option>, or both). Inside the search box use the <cursor down
+# key> to jump into the search results window, the results can be navigated
+# using the <cursor keys>. Press <Enter> to select an item or <escape> to cancel
+# the search. The filter options can be selected when the cursor is inside the
+# search box by pressing <Shift>+<cursor down>. Also here use the <cursor keys>
+# to select a filter and <Enter> or <escape> to activate or cancel the filter
+# option.
+# The default value is: YES.
+# This tag requires that the tag GENERATE_HTML is set to YES.
+
+SEARCHENGINE           = YES
+
+# When the SERVER_BASED_SEARCH tag is enabled the search engine will be
+# implemented using a web server instead of a web client using JavaScript. There
+# are two flavors of web server based searching depending on the EXTERNAL_SEARCH
+# setting. When disabled, doxygen will generate a PHP script for searching and
+# an index file used by the script. When EXTERNAL_SEARCH is enabled the indexing
+# and searching needs to be provided by external tools. See the section
+# "External Indexing and Searching" for details.
+# The default value is: NO.
+# This tag requires that the tag SEARCHENGINE is set to YES.
+
+SERVER_BASED_SEARCH    = NO
+
+# When EXTERNAL_SEARCH tag is enabled doxygen will no longer generate the PHP
+# script for searching. Instead the search results are written to an XML file
+# which needs to be processed by an external indexer. Doxygen will invoke an
+# external search engine pointed to by the SEARCHENGINE_URL option to obtain the
+# search results.
+#
+# Doxygen ships with an example indexer (doxyindexer) and search engine
+# (doxysearch.cgi) which are based on the open source search engine library
+# Xapian (see: https://xapian.org/).
+#
+# See the section "External Indexing and Searching" for details.
+# The default value is: NO.
+# This tag requires that the tag SEARCHENGINE is set to YES.
+
+EXTERNAL_SEARCH        = NO
+
+# The SEARCHENGINE_URL should point to a search engine hosted by a web server
+# which will return the search results when EXTERNAL_SEARCH is enabled.
+#
+# Doxygen ships with an example indexer (doxyindexer) and search engine
+# (doxysearch.cgi) which are based on the open source search engine library
+# Xapian (see: https://xapian.org/). See the section "External Indexing and
+# Searching" for details.
+# This tag requires that the tag SEARCHENGINE is set to YES.
+
+SEARCHENGINE_URL       =
+
+# When SERVER_BASED_SEARCH and EXTERNAL_SEARCH are both enabled the unindexed
+# search data is written to a file for indexing by an external tool. With the
+# SEARCHDATA_FILE tag the name of this file can be specified.
+# The default file is: searchdata.xml.
+# This tag requires that the tag SEARCHENGINE is set to YES.
+
+SEARCHDATA_FILE        = searchdata.xml
+
+# When SERVER_BASED_SEARCH and EXTERNAL_SEARCH are both enabled the
+# EXTERNAL_SEARCH_ID tag can be used as an identifier for the project. This is
+# useful in combination with EXTRA_SEARCH_MAPPINGS to search through multiple
+# projects and redirect the results back to the right project.
+# This tag requires that the tag SEARCHENGINE is set to YES.
+
+EXTERNAL_SEARCH_ID     =
+
+# The EXTRA_SEARCH_MAPPINGS tag can be used to enable searching through doxygen
+# projects other than the one defined by this configuration file, but that are
+# all added to the same external search index. Each project needs to have a
+# unique id set via EXTERNAL_SEARCH_ID. The search mapping then maps the id of
+# to a relative location where the documentation can be found. The format is:
+# EXTRA_SEARCH_MAPPINGS = tagname1=loc1 tagname2=loc2 ...
+# This tag requires that the tag SEARCHENGINE is set to YES.
+
+EXTRA_SEARCH_MAPPINGS  =
+
+#---------------------------------------------------------------------------
+# Configuration options related to the LaTeX output
+#---------------------------------------------------------------------------
+
+# If the GENERATE_LATEX tag is set to YES, doxygen will generate LaTeX output.
+# The default value is: YES.
+
+GENERATE_LATEX         = YES
+
+# The LATEX_OUTPUT tag is used to specify where the LaTeX docs will be put. If a
+# relative path is entered the value of OUTPUT_DIRECTORY will be put in front of
+# it.
+# The default directory is: latex.
+# This tag requires that the tag GENERATE_LATEX is set to YES.
+
+LATEX_OUTPUT           = latex
+
+# The LATEX_CMD_NAME tag can be used to specify the LaTeX command name to be
+# invoked.
+#
+# Note that when not enabling USE_PDFLATEX the default is latex when enabling
+# USE_PDFLATEX the default is pdflatex and when in the later case latex is
+# chosen this is overwritten by pdflatex. For specific output languages the
+# default can have been set differently, this depends on the implementation of
+# the output language.
+# This tag requires that the tag GENERATE_LATEX is set to YES.
+
+LATEX_CMD_NAME         =
+
+# The MAKEINDEX_CMD_NAME tag can be used to specify the command name to generate
+# index for LaTeX.
+# Note: This tag is used in the Makefile / make.bat.
+# See also: LATEX_MAKEINDEX_CMD for the part in the generated output file
+# (.tex).
+# The default file is: makeindex.
+# This tag requires that the tag GENERATE_LATEX is set to YES.
+
+MAKEINDEX_CMD_NAME     = makeindex
+
+# The LATEX_MAKEINDEX_CMD tag can be used to specify the command name to
+# generate index for LaTeX. In case there is no backslash (\) as first character
+# it will be automatically added in the LaTeX code.
+# Note: This tag is used in the generated output file (.tex).
+# See also: MAKEINDEX_CMD_NAME for the part in the Makefile / make.bat.
+# The default value is: makeindex.
+# This tag requires that the tag GENERATE_LATEX is set to YES.
+
+LATEX_MAKEINDEX_CMD    = makeindex
+
+# If the COMPACT_LATEX tag is set to YES, doxygen generates more compact LaTeX
+# documents. This may be useful for small projects and may help to save some
+# trees in general.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_LATEX is set to YES.
+
+COMPACT_LATEX          = NO
+
+# The PAPER_TYPE tag can be used to set the paper type that is used by the
+# printer.
+# Possible values are: a4 (210 x 297 mm), letter (8.5 x 11 inches), legal (8.5 x
+# 14 inches) and executive (7.25 x 10.5 inches).
+# The default value is: a4.
+# This tag requires that the tag GENERATE_LATEX is set to YES.
+
+PAPER_TYPE             = a4
+
+# The EXTRA_PACKAGES tag can be used to specify one or more LaTeX package names
+# that should be included in the LaTeX output. The package can be specified just
+# by its name or with the correct syntax as to be used with the LaTeX
+# \usepackage command. To get the times font for instance you can specify :
+# EXTRA_PACKAGES=times or EXTRA_PACKAGES={times}
+# To use the option intlimits with the amsmath package you can specify:
+# EXTRA_PACKAGES=[intlimits]{amsmath}
+# If left blank no extra packages will be included.
+# This tag requires that the tag GENERATE_LATEX is set to YES.
+
+EXTRA_PACKAGES         =
+
+# The LATEX_HEADER tag can be used to specify a personal LaTeX header for the
+# generated LaTeX document. The header should contain everything until the first
+# chapter. If it is left blank doxygen will generate a standard header. See
+# section "Doxygen usage" for information on how to let doxygen write the
+# default header to a separate file.
+#
+# Note: Only use a user-defined header if you know what you are doing! The
+# following commands have a special meaning inside the header: $title,
+# $datetime, $date, $doxygenversion, $projectname, $projectnumber,
+# $projectbrief, $projectlogo. Doxygen will replace $title with the empty
+# string, for the replacement values of the other commands the user is referred
+# to HTML_HEADER.
+# This tag requires that the tag GENERATE_LATEX is set to YES.
+
+LATEX_HEADER           =
+
+# The LATEX_FOOTER tag can be used to specify a personal LaTeX footer for the
+# generated LaTeX document. The footer should contain everything after the last
+# chapter. If it is left blank doxygen will generate a standard footer. See
+# LATEX_HEADER for more information on how to generate a default footer and what
+# special commands can be used inside the footer.
+#
+# Note: Only use a user-defined footer if you know what you are doing!
+# This tag requires that the tag GENERATE_LATEX is set to YES.
+
+LATEX_FOOTER           =
+
+# The LATEX_EXTRA_STYLESHEET tag can be used to specify additional user-defined
+# LaTeX style sheets that are included after the standard style sheets created
+# by doxygen. Using this option one can overrule certain style aspects. Doxygen
+# will copy the style sheet files to the output directory.
+# Note: The order of the extra style sheet files is of importance (e.g. the last
+# style sheet in the list overrules the setting of the previous ones in the
+# list).
+# This tag requires that the tag GENERATE_LATEX is set to YES.
+
+LATEX_EXTRA_STYLESHEET =
+
+# The LATEX_EXTRA_FILES tag can be used to specify one or more extra images or
+# other source files which should be copied to the LATEX_OUTPUT output
+# directory. Note that the files will be copied as-is; there are no commands or
+# markers available.
+# This tag requires that the tag GENERATE_LATEX is set to YES.
+
+LATEX_EXTRA_FILES      =
+
+# If the PDF_HYPERLINKS tag is set to YES, the LaTeX that is generated is
+# prepared for conversion to PDF (using ps2pdf or pdflatex). The PDF file will
+# contain links (just like the HTML output) instead of page references. This
+# makes the output suitable for online browsing using a PDF viewer.
+# The default value is: YES.
+# This tag requires that the tag GENERATE_LATEX is set to YES.
+
+PDF_HYPERLINKS         = YES
+
+# If the USE_PDFLATEX tag is set to YES, doxygen will use pdflatex to generate
+# the PDF file directly from the LaTeX files. Set this option to YES, to get a
+# higher quality PDF documentation.
+# The default value is: YES.
+# This tag requires that the tag GENERATE_LATEX is set to YES.
+
+USE_PDFLATEX           = YES
+
+# If the LATEX_BATCHMODE tag is set to YES, doxygen will add the \batchmode
+# command to the generated LaTeX files. This will instruct LaTeX to keep running
+# if errors occur, instead of asking the user for help. This option is also used
+# when generating formulas in HTML.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_LATEX is set to YES.
+
+LATEX_BATCHMODE        = NO
+
+# If the LATEX_HIDE_INDICES tag is set to YES then doxygen will not include the
+# index chapters (such as File Index, Compound Index, etc.) in the output.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_LATEX is set to YES.
+
+LATEX_HIDE_INDICES     = NO
+
+# If the LATEX_SOURCE_CODE tag is set to YES then doxygen will include source
+# code with syntax highlighting in the LaTeX output.
+#
+# Note that which sources are shown also depends on other settings such as
+# SOURCE_BROWSER.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_LATEX is set to YES.
+
+LATEX_SOURCE_CODE      = NO
+
+# The LATEX_BIB_STYLE tag can be used to specify the style to use for the
+# bibliography, e.g. plainnat, or ieeetr. See
+# https://en.wikipedia.org/wiki/BibTeX and \cite for more info.
+# The default value is: plain.
+# This tag requires that the tag GENERATE_LATEX is set to YES.
+
+LATEX_BIB_STYLE        = plain
+
+# If the LATEX_TIMESTAMP tag is set to YES then the footer of each generated
+# page will contain the date and time when the page was generated. Setting this
+# to NO can help when comparing the output of multiple runs.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_LATEX is set to YES.
+
+LATEX_TIMESTAMP        = NO
+
+# The LATEX_EMOJI_DIRECTORY tag is used to specify the (relative or absolute)
+# path from which the emoji images will be read. If a relative path is entered,
+# it will be relative to the LATEX_OUTPUT directory. If left blank the
+# LATEX_OUTPUT directory will be used.
+# This tag requires that the tag GENERATE_LATEX is set to YES.
+
+LATEX_EMOJI_DIRECTORY  =
+
+#---------------------------------------------------------------------------
+# Configuration options related to the RTF output
+#---------------------------------------------------------------------------
+
+# If the GENERATE_RTF tag is set to YES, doxygen will generate RTF output. The
+# RTF output is optimized for Word 97 and may not look too pretty with other RTF
+# readers/editors.
+# The default value is: NO.
+
+GENERATE_RTF           = NO
+
+# The RTF_OUTPUT tag is used to specify where the RTF docs will be put. If a
+# relative path is entered the value of OUTPUT_DIRECTORY will be put in front of
+# it.
+# The default directory is: rtf.
+# This tag requires that the tag GENERATE_RTF is set to YES.
+
+RTF_OUTPUT             = rtf
+
+# If the COMPACT_RTF tag is set to YES, doxygen generates more compact RTF
+# documents. This may be useful for small projects and may help to save some
+# trees in general.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_RTF is set to YES.
+
+COMPACT_RTF            = NO
+
+# If the RTF_HYPERLINKS tag is set to YES, the RTF that is generated will
+# contain hyperlink fields. The RTF file will contain links (just like the HTML
+# output) instead of page references. This makes the output suitable for online
+# browsing using Word or some other Word compatible readers that support those
+# fields.
+#
+# Note: WordPad (write) and others do not support links.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_RTF is set to YES.
+
+RTF_HYPERLINKS         = NO
+
+# Load stylesheet definitions from file. Syntax is similar to doxygen's
+# configuration file, i.e. a series of assignments. You only have to provide
+# replacements, missing definitions are set to their default value.
+#
+# See also section "Doxygen usage" for information on how to generate the
+# default style sheet that doxygen normally uses.
+# This tag requires that the tag GENERATE_RTF is set to YES.
+
+RTF_STYLESHEET_FILE    =
+
+# Set optional variables used in the generation of an RTF document. Syntax is
+# similar to doxygen's configuration file. A template extensions file can be
+# generated using doxygen -e rtf extensionFile.
+# This tag requires that the tag GENERATE_RTF is set to YES.
+
+RTF_EXTENSIONS_FILE    =
+
+# If the RTF_SOURCE_CODE tag is set to YES then doxygen will include source code
+# with syntax highlighting in the RTF output.
+#
+# Note that which sources are shown also depends on other settings such as
+# SOURCE_BROWSER.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_RTF is set to YES.
+
+RTF_SOURCE_CODE        = NO
+
+#---------------------------------------------------------------------------
+# Configuration options related to the man page output
+#---------------------------------------------------------------------------
+
+# If the GENERATE_MAN tag is set to YES, doxygen will generate man pages for
+# classes and files.
+# The default value is: NO.
+
+GENERATE_MAN           = NO
+
+# The MAN_OUTPUT tag is used to specify where the man pages will be put. If a
+# relative path is entered the value of OUTPUT_DIRECTORY will be put in front of
+# it. A directory man3 will be created inside the directory specified by
+# MAN_OUTPUT.
+# The default directory is: man.
+# This tag requires that the tag GENERATE_MAN is set to YES.
+
+MAN_OUTPUT             = man
+
+# The MAN_EXTENSION tag determines the extension that is added to the generated
+# man pages. In case the manual section does not start with a number, the number
+# 3 is prepended. The dot (.) at the beginning of the MAN_EXTENSION tag is
+# optional.
+# The default value is: .3.
+# This tag requires that the tag GENERATE_MAN is set to YES.
+
+MAN_EXTENSION          = .3
+
+# The MAN_SUBDIR tag determines the name of the directory created within
+# MAN_OUTPUT in which the man pages are placed. If defaults to man followed by
+# MAN_EXTENSION with the initial . removed.
+# This tag requires that the tag GENERATE_MAN is set to YES.
+
+MAN_SUBDIR             =
+
+# If the MAN_LINKS tag is set to YES and doxygen generates man output, then it
+# will generate one additional man file for each entity documented in the real
+# man page(s). These additional files only source the real man page, but without
+# them the man command would be unable to find the correct page.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_MAN is set to YES.
+
+MAN_LINKS              = NO
+
+#---------------------------------------------------------------------------
+# Configuration options related to the XML output
+#---------------------------------------------------------------------------
+
+# If the GENERATE_XML tag is set to YES, doxygen will generate an XML file that
+# captures the structure of the code including all documentation.
+# The default value is: NO.
+
+GENERATE_XML           = NO
+
+# The XML_OUTPUT tag is used to specify where the XML pages will be put. If a
+# relative path is entered the value of OUTPUT_DIRECTORY will be put in front of
+# it.
+# The default directory is: xml.
+# This tag requires that the tag GENERATE_XML is set to YES.
+
+XML_OUTPUT             = xml
+
+# If the XML_PROGRAMLISTING tag is set to YES, doxygen will dump the program
+# listings (including syntax highlighting and cross-referencing information) to
+# the XML output. Note that enabling this will significantly increase the size
+# of the XML output.
+# The default value is: YES.
+# This tag requires that the tag GENERATE_XML is set to YES.
+
+XML_PROGRAMLISTING     = YES
+
+# If the XML_NS_MEMB_FILE_SCOPE tag is set to YES, doxygen will include
+# namespace members in file scope as well, matching the HTML output.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_XML is set to YES.
+
+XML_NS_MEMB_FILE_SCOPE = NO
+
+#---------------------------------------------------------------------------
+# Configuration options related to the DOCBOOK output
+#---------------------------------------------------------------------------
+
+# If the GENERATE_DOCBOOK tag is set to YES, doxygen will generate Docbook files
+# that can be used to generate PDF.
+# The default value is: NO.
+
+GENERATE_DOCBOOK       = NO
+
+# The DOCBOOK_OUTPUT tag is used to specify where the Docbook pages will be put.
+# If a relative path is entered the value of OUTPUT_DIRECTORY will be put in
+# front of it.
+# The default directory is: docbook.
+# This tag requires that the tag GENERATE_DOCBOOK is set to YES.
+
+DOCBOOK_OUTPUT         = docbook
+
+# If the DOCBOOK_PROGRAMLISTING tag is set to YES, doxygen will include the
+# program listings (including syntax highlighting and cross-referencing
+# information) to the DOCBOOK output. Note that enabling this will significantly
+# increase the size of the DOCBOOK output.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_DOCBOOK is set to YES.
+
+DOCBOOK_PROGRAMLISTING = NO
+
+#---------------------------------------------------------------------------
+# Configuration options for the AutoGen Definitions output
+#---------------------------------------------------------------------------
+
+# If the GENERATE_AUTOGEN_DEF tag is set to YES, doxygen will generate an
+# AutoGen Definitions (see http://autogen.sourceforge.net/) file that captures
+# the structure of the code including all documentation. Note that this feature
+# is still experimental and incomplete at the moment.
+# The default value is: NO.
+
+GENERATE_AUTOGEN_DEF   = NO
+
+#---------------------------------------------------------------------------
+# Configuration options related to the Perl module output
+#---------------------------------------------------------------------------
+
+# If the GENERATE_PERLMOD tag is set to YES, doxygen will generate a Perl module
+# file that captures the structure of the code including all documentation.
+#
+# Note that this feature is still experimental and incomplete at the moment.
+# The default value is: NO.
+
+GENERATE_PERLMOD       = NO
+
+# If the PERLMOD_LATEX tag is set to YES, doxygen will generate the necessary
+# Makefile rules, Perl scripts and LaTeX code to be able to generate PDF and DVI
+# output from the Perl module output.
+# The default value is: NO.
+# This tag requires that the tag GENERATE_PERLMOD is set to YES.
+
+PERLMOD_LATEX          = NO
+
+# If the PERLMOD_PRETTY tag is set to YES, the Perl module output will be nicely
+# formatted so it can be parsed by a human reader. This is useful if you want to
+# understand what is going on. On the other hand, if this tag is set to NO, the
+# size of the Perl module output will be much smaller and Perl will parse it
+# just the same.
+# The default value is: YES.
+# This tag requires that the tag GENERATE_PERLMOD is set to YES.
+
+PERLMOD_PRETTY         = YES
+
+# The names of the make variables in the generated doxyrules.make file are
+# prefixed with the string contained in PERLMOD_MAKEVAR_PREFIX. This is useful
+# so different doxyrules.make files included by the same Makefile don't
+# overwrite each other's variables.
+# This tag requires that the tag GENERATE_PERLMOD is set to YES.
+
+PERLMOD_MAKEVAR_PREFIX =
+
+#---------------------------------------------------------------------------
+# Configuration options related to the preprocessor
+#---------------------------------------------------------------------------
+
+# If the ENABLE_PREPROCESSING tag is set to YES, doxygen will evaluate all
+# C-preprocessor directives found in the sources and include files.
+# The default value is: YES.
+
+ENABLE_PREPROCESSING   = YES
+
+# If the MACRO_EXPANSION tag is set to YES, doxygen will expand all macro names
+# in the source code. If set to NO, only conditional compilation will be
+# performed. Macro expansion can be done in a controlled way by setting
+# EXPAND_ONLY_PREDEF to YES.
+# The default value is: NO.
+# This tag requires that the tag ENABLE_PREPROCESSING is set to YES.
+
+MACRO_EXPANSION        = NO
+
+# If the EXPAND_ONLY_PREDEF and MACRO_EXPANSION tags are both set to YES then
+# the macro expansion is limited to the macros specified with the PREDEFINED and
+# EXPAND_AS_DEFINED tags.
+# The default value is: NO.
+# This tag requires that the tag ENABLE_PREPROCESSING is set to YES.
+
+EXPAND_ONLY_PREDEF     = NO
+
+# If the SEARCH_INCLUDES tag is set to YES, the include files in the
+# INCLUDE_PATH will be searched if a #include is found.
+# The default value is: YES.
+# This tag requires that the tag ENABLE_PREPROCESSING is set to YES.
+
+SEARCH_INCLUDES        = YES
+
+# The INCLUDE_PATH tag can be used to specify one or more directories that
+# contain include files that are not input files but should be processed by the
+# preprocessor.
+# This tag requires that the tag SEARCH_INCLUDES is set to YES.
+
+INCLUDE_PATH           =
+
+# You can use the INCLUDE_FILE_PATTERNS tag to specify one or more wildcard
+# patterns (like *.h and *.hpp) to filter out the header-files in the
+# directories. If left blank, the patterns specified with FILE_PATTERNS will be
+# used.
+# This tag requires that the tag ENABLE_PREPROCESSING is set to YES.
+
+INCLUDE_FILE_PATTERNS  =
+
+# The PREDEFINED tag can be used to specify one or more macro names that are
+# defined before the preprocessor is started (similar to the -D option of e.g.
+# gcc). The argument of the tag is a list of macros of the form: name or
+# name=definition (no spaces). If the definition and the "=" are omitted, "=1"
+# is assumed. To prevent a macro definition from being undefined via #undef or
+# recursively expanded use the := operator instead of the = operator.
+# This tag requires that the tag ENABLE_PREPROCESSING is set to YES.
+
+PREDEFINED             =
+
+# If the MACRO_EXPANSION and EXPAND_ONLY_PREDEF tags are set to YES then this
+# tag can be used to specify a list of macro names that should be expanded. The
+# macro definition that is found in the sources will be used. Use the PREDEFINED
+# tag if you want to use a different macro definition that overrules the
+# definition found in the source code.
+# This tag requires that the tag ENABLE_PREPROCESSING is set to YES.
+
+EXPAND_AS_DEFINED      =
+
+# If the SKIP_FUNCTION_MACROS tag is set to YES then doxygen's preprocessor will
+# remove all references to function-like macros that are alone on a line, have
+# an all uppercase name, and do not end with a semicolon. Such function macros
+# are typically used for boiler-plate code, and will confuse the parser if not
+# removed.
+# The default value is: YES.
+# This tag requires that the tag ENABLE_PREPROCESSING is set to YES.
+
+SKIP_FUNCTION_MACROS   = YES
+
+#---------------------------------------------------------------------------
+# Configuration options related to external references
+#---------------------------------------------------------------------------
+
+# The TAGFILES tag can be used to specify one or more tag files. For each tag
+# file the location of the external documentation should be added. The format of
+# a tag file without this location is as follows:
+# TAGFILES = file1 file2 ...
+# Adding location for the tag files is done as follows:
+# TAGFILES = file1=loc1 "file2 = loc2" ...
+# where loc1 and loc2 can be relative or absolute paths or URLs. See the
+# section "Linking to external documentation" for more information about the use
+# of tag files.
+# Note: Each tag file must have a unique name (where the name does NOT include
+# the path). If a tag file is not located in the directory in which doxygen is
+# run, you must also specify the path to the tagfile here.
+
+TAGFILES               =
+
+# When a file name is specified after GENERATE_TAGFILE, doxygen will create a
+# tag file that is based on the input files it reads. See section "Linking to
+# external documentation" for more information about the usage of tag files.
+
+GENERATE_TAGFILE       =
+
+# If the ALLEXTERNALS tag is set to YES, all external class will be listed in
+# the class index. If set to NO, only the inherited external classes will be
+# listed.
+# The default value is: NO.
+
+ALLEXTERNALS           = NO
+
+# If the EXTERNAL_GROUPS tag is set to YES, all external groups will be listed
+# in the modules index. If set to NO, only the current project's groups will be
+# listed.
+# The default value is: YES.
+
+EXTERNAL_GROUPS        = YES
+
+# If the EXTERNAL_PAGES tag is set to YES, all external pages will be listed in
+# the related pages index. If set to NO, only the current project's pages will
+# be listed.
+# The default value is: YES.
+
+EXTERNAL_PAGES         = YES
+
+#---------------------------------------------------------------------------
+# Configuration options related to the dot tool
+#---------------------------------------------------------------------------
+
+# If the CLASS_DIAGRAMS tag is set to YES, doxygen will generate a class diagram
+# (in HTML and LaTeX) for classes with base or super classes. Setting the tag to
+# NO turns the diagrams off. Note that this option also works with HAVE_DOT
+# disabled, but it is recommended to install and use dot, since it yields more
+# powerful graphs.
+# The default value is: YES.
+
+CLASS_DIAGRAMS         = YES
+
+# You can include diagrams made with dia in doxygen documentation. Doxygen will
+# then run dia to produce the diagram and insert it in the documentation. The
+# DIA_PATH tag allows you to specify the directory where the dia binary resides.
+# If left empty dia is assumed to be found in the default search path.
+
+DIA_PATH               =
+
+# If set to YES the inheritance and collaboration graphs will hide inheritance
+# and usage relations if the target is undocumented or is not a class.
+# The default value is: YES.
+
+HIDE_UNDOC_RELATIONS   = YES
+
+# If you set the HAVE_DOT tag to YES then doxygen will assume the dot tool is
+# available from the path. This tool is part of Graphviz (see:
+# http://www.graphviz.org/), a graph visualization toolkit from AT&T and Lucent
+# Bell Labs. The other options in this section have no effect if this option is
+# set to NO
+# The default value is: YES.
+
+HAVE_DOT               = YES
+
+# The DOT_NUM_THREADS specifies the number of dot invocations doxygen is allowed
+# to run in parallel. When set to 0 doxygen will base this on the number of
+# processors available in the system. You can set it explicitly to a value
+# larger than 0 to get control over the balance between CPU load and processing
+# speed.
+# Minimum value: 0, maximum value: 32, default value: 0.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+DOT_NUM_THREADS        = 0
+
+# When you want a differently looking font in the dot files that doxygen
+# generates you can specify the font name using DOT_FONTNAME. You need to make
+# sure dot is able to find the font, which can be done by putting it in a
+# standard location or by setting the DOTFONTPATH environment variable or by
+# setting DOT_FONTPATH to the directory containing the font.
+# The default value is: Helvetica.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+DOT_FONTNAME           = Helvetica
+
+# The DOT_FONTSIZE tag can be used to set the size (in points) of the font of
+# dot graphs.
+# Minimum value: 4, maximum value: 24, default value: 10.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+DOT_FONTSIZE           = 10
+
+# By default doxygen will tell dot to use the default font as specified with
+# DOT_FONTNAME. If you specify a different font using DOT_FONTNAME you can set
+# the path where dot can find it using this tag.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+DOT_FONTPATH           =
+
+# If the CLASS_GRAPH tag is set to YES then doxygen will generate a graph for
+# each documented class showing the direct and indirect inheritance relations.
+# Setting this tag to YES will force the CLASS_DIAGRAMS tag to NO.
+# The default value is: YES.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+CLASS_GRAPH            = YES
+
+# If the COLLABORATION_GRAPH tag is set to YES then doxygen will generate a
+# graph for each documented class showing the direct and indirect implementation
+# dependencies (inheritance, containment, and class references variables) of the
+# class with other documented classes.
+# The default value is: YES.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+COLLABORATION_GRAPH    = YES
+
+# If the GROUP_GRAPHS tag is set to YES then doxygen will generate a graph for
+# groups, showing the direct groups dependencies.
+# The default value is: YES.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+GROUP_GRAPHS           = YES
+
+# If the UML_LOOK tag is set to YES, doxygen will generate inheritance and
+# collaboration diagrams in a style similar to the OMG's Unified Modeling
+# Language.
+# The default value is: NO.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+UML_LOOK               = NO
+
+# If the UML_LOOK tag is enabled, the fields and methods are shown inside the
+# class node. If there are many fields or methods and many nodes the graph may
+# become too big to be useful. The UML_LIMIT_NUM_FIELDS threshold limits the
+# number of items for each type to make the size more manageable. Set this to 0
+# for no limit. Note that the threshold may be exceeded by 50% before the limit
+# is enforced. So when you set the threshold to 10, up to 15 fields may appear,
+# but if the number exceeds 15, the total amount of fields shown is limited to
+# 10.
+# Minimum value: 0, maximum value: 100, default value: 10.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+UML_LIMIT_NUM_FIELDS   = 10
+
+# If the TEMPLATE_RELATIONS tag is set to YES then the inheritance and
+# collaboration graphs will show the relations between templates and their
+# instances.
+# The default value is: NO.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+TEMPLATE_RELATIONS     = NO
+
+# If the INCLUDE_GRAPH, ENABLE_PREPROCESSING and SEARCH_INCLUDES tags are set to
+# YES then doxygen will generate a graph for each documented file showing the
+# direct and indirect include dependencies of the file with other documented
+# files.
+# The default value is: YES.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+INCLUDE_GRAPH          = YES
+
+# If the INCLUDED_BY_GRAPH, ENABLE_PREPROCESSING and SEARCH_INCLUDES tags are
+# set to YES then doxygen will generate a graph for each documented file showing
+# the direct and indirect include dependencies of the file with other documented
+# files.
+# The default value is: YES.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+INCLUDED_BY_GRAPH      = YES
+
+# If the CALL_GRAPH tag is set to YES then doxygen will generate a call
+# dependency graph for every global function or class method.
+#
+# Note that enabling this option will significantly increase the time of a run.
+# So in most cases it will be better to enable call graphs for selected
+# functions only using the \callgraph command. Disabling a call graph can be
+# accomplished by means of the command \hidecallgraph.
+# The default value is: NO.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+CALL_GRAPH             = NO
+
+# If the CALLER_GRAPH tag is set to YES then doxygen will generate a caller
+# dependency graph for every global function or class method.
+#
+# Note that enabling this option will significantly increase the time of a run.
+# So in most cases it will be better to enable caller graphs for selected
+# functions only using the \callergraph command. Disabling a caller graph can be
+# accomplished by means of the command \hidecallergraph.
+# The default value is: NO.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+CALLER_GRAPH           = NO
+
+# If the GRAPHICAL_HIERARCHY tag is set to YES then doxygen will graphical
+# hierarchy of all classes instead of a textual one.
+# The default value is: YES.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+GRAPHICAL_HIERARCHY    = YES
+
+# If the DIRECTORY_GRAPH tag is set to YES then doxygen will show the
+# dependencies a directory has on other directories in a graphical way. The
+# dependency relations are determined by the #include relations between the
+# files in the directories.
+# The default value is: YES.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+DIRECTORY_GRAPH        = YES
+
+# The DOT_IMAGE_FORMAT tag can be used to set the image format of the images
+# generated by dot. For an explanation of the image formats see the section
+# output formats in the documentation of the dot tool (Graphviz (see:
+# http://www.graphviz.org/)).
+# Note: If you choose svg you need to set HTML_FILE_EXTENSION to xhtml in order
+# to make the SVG files visible in IE 9+ (other browsers do not have this
+# requirement).
+# Possible values are: png, png:cairo, png:cairo:cairo, png:cairo:gd, png:gd,
+# png:gd:gd, jpg, jpg:cairo, jpg:cairo:gd, jpg:gd, jpg:gd:gd, gif, gif:cairo,
+# gif:cairo:gd, gif:gd, gif:gd:gd, svg, png:gd, png:gd:gd, png:cairo,
+# png:cairo:gd, png:cairo:cairo, png:cairo:gdiplus, png:gdiplus and
+# png:gdiplus:gdiplus.
+# The default value is: png.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+DOT_IMAGE_FORMAT       = png
+
+# If DOT_IMAGE_FORMAT is set to svg, then this option can be set to YES to
+# enable generation of interactive SVG images that allow zooming and panning.
+#
+# Note that this requires a modern browser other than Internet Explorer. Tested
+# and working are Firefox, Chrome, Safari, and Opera.
+# Note: For IE 9+ you need to set HTML_FILE_EXTENSION to xhtml in order to make
+# the SVG files visible. Older versions of IE do not have SVG support.
+# The default value is: NO.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+INTERACTIVE_SVG        = NO
+
+# The DOT_PATH tag can be used to specify the path where the dot tool can be
+# found. If left blank, it is assumed the dot tool can be found in the path.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+DOT_PATH               =
+
+# The DOTFILE_DIRS tag can be used to specify one or more directories that
+# contain dot files that are included in the documentation (see the \dotfile
+# command).
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+DOTFILE_DIRS           =
+
+# The MSCFILE_DIRS tag can be used to specify one or more directories that
+# contain msc files that are included in the documentation (see the \mscfile
+# command).
+
+MSCFILE_DIRS           =
+
+# The DIAFILE_DIRS tag can be used to specify one or more directories that
+# contain dia files that are included in the documentation (see the \diafile
+# command).
+
+DIAFILE_DIRS           =
+
+# When using plantuml, the PLANTUML_JAR_PATH tag should be used to specify the
+# path where java can find the plantuml.jar file. If left blank, it is assumed
+# PlantUML is not used or called during a preprocessing step. Doxygen will
+# generate a warning when it encounters a \startuml command in this case and
+# will not generate output for the diagram.
+
+PLANTUML_JAR_PATH      =
+
+# When using plantuml, the PLANTUML_CFG_FILE tag can be used to specify a
+# configuration file for plantuml.
+
+PLANTUML_CFG_FILE      =
+
+# When using plantuml, the specified paths are searched for files specified by
+# the !include statement in a plantuml block.
+
+PLANTUML_INCLUDE_PATH  =
+
+# The DOT_GRAPH_MAX_NODES tag can be used to set the maximum number of nodes
+# that will be shown in the graph. If the number of nodes in a graph becomes
+# larger than this value, doxygen will truncate the graph, which is visualized
+# by representing a node as a red box. Note that doxygen if the number of direct
+# children of the root node in a graph is already larger than
+# DOT_GRAPH_MAX_NODES then the graph will not be shown at all. Also note that
+# the size of a graph can be further restricted by MAX_DOT_GRAPH_DEPTH.
+# Minimum value: 0, maximum value: 10000, default value: 50.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+DOT_GRAPH_MAX_NODES    = 50
+
+# The MAX_DOT_GRAPH_DEPTH tag can be used to set the maximum depth of the graphs
+# generated by dot. A depth value of 3 means that only nodes reachable from the
+# root by following a path via at most 3 edges will be shown. Nodes that lay
+# further from the root node will be omitted. Note that setting this option to 1
+# or 2 may greatly reduce the computation time needed for large code bases. Also
+# note that the size of a graph can be further restricted by
+# DOT_GRAPH_MAX_NODES. Using a depth of 0 means no depth restriction.
+# Minimum value: 0, maximum value: 1000, default value: 0.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+MAX_DOT_GRAPH_DEPTH    = 0
+
+# Set the DOT_TRANSPARENT tag to YES to generate images with a transparent
+# background. This is disabled by default, because dot on Windows does not seem
+# to support this out of the box.
+#
+# Warning: Depending on the platform used, enabling this option may lead to
+# badly anti-aliased labels on the edges of a graph (i.e. they become hard to
+# read).
+# The default value is: NO.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+DOT_TRANSPARENT        = NO
+
+# Set the DOT_MULTI_TARGETS tag to YES to allow dot to generate multiple output
+# files in one run (i.e. multiple -o and -T options on the command line). This
+# makes dot run faster, but since only newer versions of dot (>1.8.10) support
+# this, this feature is disabled by default.
+# The default value is: NO.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+DOT_MULTI_TARGETS      = NO
+
+# If the GENERATE_LEGEND tag is set to YES doxygen will generate a legend page
+# explaining the meaning of the various boxes and arrows in the dot generated
+# graphs.
+# The default value is: YES.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+GENERATE_LEGEND        = YES
+
+# If the DOT_CLEANUP tag is set to YES, doxygen will remove the intermediate dot
+# files that are used to generate the various graphs.
+# The default value is: YES.
+# This tag requires that the tag HAVE_DOT is set to YES.
+
+DOT_CLEANUP            = YES
diff --git a/src/ggml-cann/acl_tensor.cpp b/src/ggml-cann/acl_tensor.cpp
new file mode 100644 (file)
index 0000000..d120ce6
--- /dev/null
@@ -0,0 +1,175 @@
+/*
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+ * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
+ * IN THE SOFTWARE.
+ */
+
+#include "acl_tensor.h"
+
+#include <algorithm>
+#include <cstring>
+
+aclDataType ggml_cann_type_mapping(ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_F32:
+            return ACL_FLOAT;
+        case GGML_TYPE_F16:
+            return ACL_FLOAT16;
+        case GGML_TYPE_I8:
+            return ACL_INT8;
+        case GGML_TYPE_I16:
+            return ACL_INT16;
+        case GGML_TYPE_I32:
+            return ACL_INT32;
+        case GGML_TYPE_Q4_0:
+            return ACL_INT4;
+        case GGML_TYPE_Q8_0:
+            return ACL_INT8;
+        default:
+            return ACL_DT_UNDEFINED;
+    }
+    return ACL_DT_UNDEFINED;
+}
+
+aclTensor* ggml_cann_create_tensor(const ggml_tensor* tensor, int64_t* ne,
+                                   size_t* nb, int64_t dims, aclFormat format,
+                                   size_t offset) {
+    // If tensor is bcasted, Up to GGML_MAX_DIMS additional dimensions will be
+    // added.
+    int64_t acl_ne[GGML_MAX_DIMS * 2], acl_stride[GGML_MAX_DIMS * 2];
+
+    int64_t acl_storage_len = 0;
+    if (ne == nullptr) {
+        acl_storage_len = ggml_nbytes(tensor);
+        for (int i = 0; i < GGML_MAX_DIMS; i++) {
+            acl_ne[i] = tensor->ne[i];
+            // The step size of acl is in elements.
+            acl_stride[i] = tensor->nb[i] / ggml_element_size(tensor);
+        }
+    } else {
+        // With bcast
+        for (int i = 0; i < dims; i++) {
+            acl_storage_len += (ne[i] - 1) * nb[i];
+            acl_ne[i] = ne[i];
+            acl_stride[i] = nb[i] / ggml_element_size(tensor);
+        }
+    }
+
+    // Reverse ne and stride.
+    int64_t final_dims = (dims == 0 ? GGML_MAX_DIMS : dims);
+    std::reverse(acl_ne, acl_ne + final_dims);
+    std::reverse(acl_stride, acl_stride + final_dims);
+
+    aclTensor* acl_tensor = aclCreateTensor(
+        acl_ne, final_dims, ggml_cann_type_mapping(tensor->type), acl_stride,
+        offset / ggml_element_size(tensor), format, &acl_storage_len, 1,
+        tensor->data);
+
+    return acl_tensor;
+}
+
+bool ggml_cann_need_bcast(const ggml_tensor* t0, const ggml_tensor* t1) {
+    for (int i = 0; i < GGML_MAX_DIMS; i++) {
+        if (t1->ne[i] != t0->ne[i] && t1->ne[i] != 1) {
+            return true;
+        }
+    }
+    return false;
+}
+
+int64_t ggml_cann_get_bcast_shape(const ggml_tensor* src0,
+                                  const ggml_tensor* src1,
+                                  int64_t* bcast_src0_ne,
+                                  int64_t* bcast_src1_ne, size_t* bcast_src0_nb,
+                                  size_t* bcast_src1_nb) {
+    GGML_ASSERT(ggml_can_repeat(src1, src0));
+    int bcast_dim_cnt = 0;
+    for (int i = 0; i < GGML_MAX_DIMS; i++) {
+        int64_t nr = src0->ne[i] / src1->ne[i];
+        bcast_src0_ne[bcast_dim_cnt] = src0->ne[i] / nr;
+        bcast_src1_ne[bcast_dim_cnt] = src1->ne[i];
+        bcast_src0_nb[bcast_dim_cnt] = src0->nb[i];
+        bcast_src1_nb[bcast_dim_cnt] = src1->nb[i];
+        bcast_dim_cnt++;
+        if (nr != 1) {
+            // Need to add an extra dim.
+            bcast_src0_ne[bcast_dim_cnt] = nr;
+            bcast_src1_ne[bcast_dim_cnt] = 1;
+            bcast_src0_nb[bcast_dim_cnt] = bcast_src0_nb[bcast_dim_cnt - 1] *
+                                           bcast_src0_ne[bcast_dim_cnt - 1];
+            bcast_src1_nb[bcast_dim_cnt] = bcast_src1_nb[bcast_dim_cnt - 1] *
+                                           bcast_src1_ne[bcast_dim_cnt - 1];
+            bcast_dim_cnt++;
+        }
+    }
+    return bcast_dim_cnt;
+}
+
+int64_t ggml_cann_get_mulmat_bcast_shape(
+    const int64_t* input_ne, const int64_t* weight_ne, const int64_t* dst_ne,
+    const size_t* input_nb, const size_t* weight_nb, const size_t* dst_nb,
+    int64_t* bcast_input_ne, int64_t* bcast_weight_ne, int64_t* bcast_dst_ne,
+    size_t* bcast_input_nb, size_t* bcast_weight_nb, size_t* bcast_dst_nb) {
+    // input and dst shoule in same shape, except first two dims.
+    GGML_ASSERT(input_ne[2] == dst_ne[2]);
+    GGML_ASSERT(input_ne[3] == dst_ne[3]);
+
+    int bcast_dim_cnt = 0;
+
+    // For mul_mat, a dimension needs to be added before the dimension that
+    // weight needs to be expanded to satisfy the bcast rule of matrix
+    // multiplication.
+    for (int i = 0; i < GGML_MAX_DIMS; i++) {
+        int64_t nr = input_ne[i] / weight_ne[i];
+        // Do not use bcast in the first two dimensions because we only support
+        // the bcast batch dimension. Just copy them.
+        if (i < 2 || nr == 1) {
+            bcast_input_ne[bcast_dim_cnt] = input_ne[i];
+            bcast_weight_ne[bcast_dim_cnt] = weight_ne[i];
+            bcast_dst_ne[bcast_dim_cnt] = dst_ne[i];
+
+            bcast_input_nb[bcast_dim_cnt] = input_nb[i];
+            bcast_weight_nb[bcast_dim_cnt] = weight_nb[i];
+            bcast_dst_nb[bcast_dim_cnt] = dst_nb[i];
+            bcast_dim_cnt++;
+        } else {
+            // Need to add an extra dim.
+            bcast_input_ne[bcast_dim_cnt] = nr;
+            bcast_dst_ne[bcast_dim_cnt] = nr;
+            bcast_weight_ne[bcast_dim_cnt] = 1;
+            bcast_input_nb[bcast_dim_cnt] = input_nb[i];
+            bcast_dst_nb[bcast_dim_cnt] = dst_nb[i];
+            bcast_weight_nb[bcast_dim_cnt] = weight_nb[i];
+            bcast_dim_cnt++;
+
+            bcast_input_ne[bcast_dim_cnt] = input_ne[i] / nr;
+            bcast_dst_ne[bcast_dim_cnt] = dst_ne[i] / nr;
+            bcast_weight_ne[bcast_dim_cnt] = weight_ne[i];
+            bcast_input_nb[bcast_dim_cnt] = bcast_input_nb[bcast_dim_cnt - 1] *
+                                            bcast_input_ne[bcast_dim_cnt - 1];
+            bcast_dst_nb[bcast_dim_cnt] = bcast_dst_nb[bcast_dim_cnt - 1] *
+                                          bcast_dst_ne[bcast_dim_cnt - 1];
+            bcast_weight_nb[bcast_dim_cnt] =
+                bcast_weight_nb[bcast_dim_cnt - 1] *
+                bcast_weight_ne[bcast_dim_cnt - 1];
+            bcast_dim_cnt++;
+        }
+    }
+    return bcast_dim_cnt;
+}
diff --git a/src/ggml-cann/acl_tensor.h b/src/ggml-cann/acl_tensor.h
new file mode 100644 (file)
index 0000000..4734a9c
--- /dev/null
@@ -0,0 +1,258 @@
+/*
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+ * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
+ * IN THE SOFTWARE.
+ */
+
+#ifndef CANN_ACL_TENSOR_H
+#define CANN_ACL_TENSOR_H
+
+#include <algorithm>
+#include <cstring>
+
+#include <aclnn/aclnn_base.h>
+#include "common.h"
+
+/**
+ * @brief      Maps a ggml_type to its corresponding aclDataType.
+ *
+ * @details    This function takes a ggml_type as input and returns the corresponding
+ *                     aclDataType. It supports mapping for various ggml_types. If the input type
+ *                     does not match any of the predefined ggml_types, the function returns
+ *          ACL_DT_UNDEFINED.
+ *
+ * @param      type    The ggml_type to be mapped.
+ * @return     The corresponding aclDataType. If the input type is not recognized,
+ *                     ACL_DT_UNDEFINED is returned.
+ */
+aclDataType ggml_cann_type_mapping(ggml_type type);
+
+/**
+ * @brief   Creates an ACL tensor from a ggml_tensor with optional shape.
+ *
+ * @details This function creates an ACL tensor based on the properties of the
+ *          provided ggml_tensor. It supports customer shape by adjusting dimensions
+ *          and strides accordingly. If customer shape is applied, additional
+ *          dimensions and strides are calculated based on the provided parameters.
+ *
+ * @param   tensor      Pointer to the ggml_tensor to be converted to ACL tensor.
+ * @param   ne          Pointer to an array containing dimensions. Defaults to nullptr
+ *                      if no customer shape is applied.
+ * @param   nb          Pointer to an array containing strides. Defaults to nullptr
+ *                      if no customer shape is applied.
+ * @param   dims        Number of dimensions in the tensor. Defaults to 0 if no customer
+ *                      shape is applied.
+ * @param   format      ACL tensor format. Defaults to ACL_FORMAT_ND.
+ * @param   offset      Offset in bytes for the ACL tensor data. Defaults to 0.
+ * @return  Pointer to the created ACL tensor.
+ */
+aclTensor* ggml_cann_create_tensor(const ggml_tensor* tensor, int64_t* ne = nullptr,
+                             size_t* nb = nullptr, int64_t dims = 0,
+                             aclFormat format = ACL_FORMAT_ND,
+                             size_t offset = 0);
+
+/**
+ * @brief   Template for creating an ACL tensor from provided parameters. typename TYPE
+ *          should be size_t or float.
+ *
+ * @details This function creates an ACL tensor using the provided data pointer,
+ *          data type, dimensions, strides, format, offset, and additional parameters.
+ *          It calculates necessary dimensions and strides based on the provided ne and nb
+ *          arrays, adjusting them for the ACL tensor creation. The ACL storage length
+ *          is also calculated based on the provided dimensions and strides.
+ *
+ * @param   data_ptr    Pointer to the data buffer for the ACL tensor.
+ * @param   dtype       ACL data type of the tensor.
+ * @param   type_size   Size of each element in the tensor data buffer.
+ * @param   ne          Pointer to an array containing tensor dimensions.
+ * @param   nb          Pointer to an array containing tensor strides.
+ * @param   dims        Number of dimensions of the tensor.
+ * @param   format      ACL tensor format. Defaults to ACL_FORMAT_ND.
+ * @param   offset      Offset in bytes for the ACL tensor data. Defaults to 0.
+ * @return  Pointer to the created ACL tensor.
+ */
+template<typename TYPE>
+aclTensor* ggml_cann_create_tensor(void* data_ptr, aclDataType dtype,
+                                   TYPE type_size, int64_t* ne, TYPE* nb,
+                                   int64_t dims,
+                                   aclFormat format = ACL_FORMAT_ND,
+                                   size_t offset = 0) {
+    int64_t tmp_ne[GGML_MAX_DIMS * 2];
+    int64_t tmp_stride[GGML_MAX_DIMS * 2];
+
+    memcpy(tmp_ne, ne, dims * sizeof(int64_t));
+    for (int i = 0; i < dims; i++) {
+        tmp_stride[i] = nb[i] / type_size;
+    }
+
+    std::reverse(tmp_ne, tmp_ne + dims);
+    std::reverse(tmp_stride, tmp_stride + dims);
+
+    int64_t acl_storage_len = 0;
+    for (int i = 0; i < dims; i++) {
+        acl_storage_len += (ne[i] - 1) * nb[i];
+    }
+
+    aclTensor* acl_tensor =
+        aclCreateTensor(tmp_ne, dims, dtype, tmp_stride, offset / type_size,
+                        format, &acl_storage_len, 1, data_ptr);
+
+    return acl_tensor;
+}
+
+/**
+ * @brief   Checks if tensors require broadcasting based on their shapes.
+ *
+ * @details This function determines if two ggml_tensors need to be broadcasted for
+ *          element-wise operations. Broadcasting is necessary if the shapes of the
+ *          tensors are not identical and no dimension in either tensor equals 1.
+ *
+ * @param   t0      Pointer to the first ggml_tensor.
+ * @param   t1      Pointer to the second ggml_tensor.
+ * @return  True if broadcasting is needed, False otherwise.
+ *
+ * @remarks This function iterates over the dimensions of t0 and t1. It checks if each
+ *          dimension in t1 differs from t0's corresponding dimension and is not equal
+ *          to 1. If such a dimension is found, broadcasting is required to align t1
+ *          with t0 for element-wise operations.
+ */
+bool ggml_cann_need_bcast(const ggml_tensor* t0, const ggml_tensor* t1);
+
+/**
+ * @brief   Computes broadcast shapes and strides for two ggml_tensors.
+ *
+ * @details This function calculates the broadcast shapes and strides for two ggml_tensors,
+ *          following the broadcasting rules similar to numpy. It adjusts dimensions and
+ *          strides to ensure compatibility for element-wise operations where one tensor
+ *          can be broadcasted to match the shape of another tensor.
+ *
+ * @param   src0                Pointer to the first ggml_tensor.
+ * @param   src1                Pointer to the second ggml_tensor.
+ * @param   bcast_ne_src0       Output array to store broadcasted dimensions for src0.
+ * @param   bcast_ne_src1       Output array to store broadcasted dimensions for src1.
+ * @param   bcast_nb_src0       Output array to store broadcasted strides for src0.
+ * @param   bcast_nb_src1       Output array to store broadcasted strides for src1.
+ * @return  Number of dimensions in the broadcasted shape.
+ *
+ * @pre     ggml_can_repeat(src1, src0) must return true, indicating src1 can be broadcasted
+ *          to match src0.
+ *
+ * @remarks This function iterates over the dimensions of src0 and src1, calculating the
+ *          necessary broadcast dimensions and strides. If a dimension requires broadcasting
+ *          (i.e., its size in src1 is smaller than in src0), an additional dimension is
+ *          added with size calculated to match src0's dimension. This adjustment ensures
+ *          that src1 can be element-wise broadcasted to src0's shape.
+ *
+ *  How it works:
+ *
+ *  if dim0 has padding.
+ *  a -> (2, 2) padding = 2
+ *   a: [[1, 2, *, *]
+ *       [2, 3, *, *]]
+ *  nb = (8, 4, 2)
+ *
+ *  if a should bcast with b -> (2, 4)
+ *  b' -> (2, 2, 2)
+ *  b : [[1, 2, 3, 4, *, *]
+ *       [5, 6, 7, 8, *, *]]
+ *  nb = (12, 6, 1)
+ *
+ *  after bcast:
+ *  a' -> (2, 1, 2)
+ *  a': [[[1, 2], *, *]
+ *       [[2, 3], *, *]]
+ *  nb = (8, 4, 2, 1)
+ *
+ *  b' : [[[1, 2], [3, 4], *, *]
+ *        [[5, 6], [7, 8], *, *]]
+ *  nb = (12, 6, 2, 1)
+ *  \endcode
+ *
+ *  dim1 in a inserted dim, should add nb for dim1,
+ *  and all other nb moves to next in order.
+ */
+int64_t ggml_cann_get_bcast_shape(const ggml_tensor* src0, const ggml_tensor* src1,
+                        int64_t* bcast_ne_src0, int64_t* bcast_ne_src1,
+                        size_t* bcast_nb_src0, size_t* bcast_nb_src1);
+
+// Bcast macro to avoid duplicate code.
+#define BCAST_SHAPE(src0, src1)                                              \
+    int64_t bcast_##src0##_ne[GGML_MAX_DIMS * 2];                            \
+    int64_t bcast_##src1##_ne[GGML_MAX_DIMS * 2];                            \
+    size_t bcast_##src0##_nb[GGML_MAX_DIMS * 2];                             \
+    size_t bcast_##src1##_nb[GGML_MAX_DIMS * 2];                             \
+    int64_t bcast_dims = ggml_cann_get_bcast_shape(                          \
+        src0, src1, bcast_##src0##_ne, bcast_##src1##_ne, bcast_##src0##_nb, \
+        bcast_##src1##_nb);
+
+#define BCAST_PARAM(tensor) bcast_##tensor##_ne, bcast_##tensor##_nb, bcast_dims
+
+/**
+ * @brief Calculates broadcast shapes for matrix multiplication.
+ *
+ * @details This function computes the broadcast shapes required for matrix multiplication
+ *          based on the input, weight, and destination tensor shapes. It ensures that the
+ *          dimensions of weight tensors are expanded appropriately to satisfy matrix
+ *          multiplication broadcast rules.
+ *
+ * @param input_ne      Array containing the dimensions of the input tensor.
+ * @param weight_ne     Array containing the dimensions of the weight tensor.
+ * @param dst_ne        Array containing the dimensions of the destination tensor.
+ * @param input_nb      Array containing the strides of the input tensor.
+ * @param weight_nb     Array containing the strides of the weight tensor.
+ * @param dst_nb        Array containing the strides of the destination tensor.
+ * @param bcast_input_ne    Output array for broadcasted input tensor dimensions.
+ * @param bcast_weight_ne   Output array for broadcasted weight tensor dimensions.
+ * @param bcast_dst_ne      Output array for broadcasted destination tensor dimensions.
+ * @param bcast_input_nb    Output array for broadcasted input tensor strides.
+ * @param bcast_weight_nb   Output array for broadcasted weight tensor strides.
+ * @param bcast_dst_nb      Output array for broadcasted destination tensor strides.
+ * @return The number of dimensions in the broadcasted tensors.
+ *
+ * @remarks This function iterates over the tensor dimensions and calculates the broadcast
+ *          shapes needed for matrix multiplication. It ensures that dimensions where
+ *          weight tensor requires expansion are appropriately handled to conform with
+ *          broadcasting rules.
+ * @note compare with ggml_cann_get_bcast_shape, mul_mat broadcast need add this new dim
+ *       before cast dim.
+ * @sa ggml_cann_get_bcast_shape
+ */
+int64_t ggml_cann_get_mulmat_bcast_shape(
+    const int64_t* input_ne, const int64_t* weight_ne, const int64_t* dst_ne,
+    const size_t* input_nb, const size_t* weight_nb, const size_t* dst_nb,
+    int64_t* bcast_input_ne, int64_t* bcast_weight_ne, int64_t* bcast_dst_ne,
+    size_t* bcast_input_nb, size_t* bcast_weight_nb, size_t* bcast_dst_nb);
+
+// Bcast macro to avoid duplicate code.
+#define BCAST_MUL_MAT_SHAPE(input, weight, dst)                         \
+    int64_t bcast_##input##_ne[GGML_MAX_DIMS * 2];                      \
+    int64_t bcast_##weight##_ne[GGML_MAX_DIMS * 2];                     \
+    int64_t bcast_##dst##_ne[GGML_MAX_DIMS * 2];                        \
+    size_t bcast_##input##_nb[GGML_MAX_DIMS * 2];                       \
+    size_t bcast_##weight##_nb[GGML_MAX_DIMS * 2];                      \
+    size_t bcast_##dst##_nb[GGML_MAX_DIMS * 2];                         \
+    int64_t bcast_dims = ggml_cann_get_mulmat_bcast_shape(              \
+        input->ne, weight->ne, dst->ne, input->nb, weight->nb, dst->nb, \
+        bcast_##input##_ne, bcast_##weight##_ne, bcast_##dst##_ne,      \
+        bcast_##input##_nb, bcast_##weight##_nb, bcast_##dst##_nb);
+
+#define BCAST_MUL_MAT_PARAM(tensor) \
+    bcast_##tensor##_ne, bcast_##tensor##_nb, bcast_dims
+
+#endif  // CANN_ACL_TENSOR_H
diff --git a/src/ggml-cann/aclnn_ops.cpp b/src/ggml-cann/aclnn_ops.cpp
new file mode 100644 (file)
index 0000000..8c4132f
--- /dev/null
@@ -0,0 +1,3082 @@
+/*
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+ * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
+ * IN THE SOFTWARE.
+ */
+
+#include "aclnn_ops.h"
+
+#include <aclnnop/aclnn_avgpool2d.h>
+#include <aclnnop/aclnn_cast.h>
+#include <aclnnop/aclnn_constant_pad_nd.h>
+#include <aclnnop/aclnn_copy.h>
+#include <aclnnop/aclnn_cos.h>
+#include <aclnnop/aclnn_exp.h>
+#include <aclnnop/aclnn_fill_scalar.h>
+#include <aclnnop/aclnn_group_norm.h>
+#include <aclnnop/aclnn_index_fill_tensor.h>
+#include <aclnnop/aclnn_layer_norm.h>
+#include <aclnnop/aclnn_matmul.h>
+#include <aclnnop/aclnn_max_pool.h>
+#include <aclnnop/aclnn_permute.h>
+#include <aclnnop/aclnn_pow_tensor_tensor.h>
+#include <aclnnop/aclnn_reduce_sum.h>
+#include <aclnnop/aclnn_repeat.h>
+#include <aclnnop/aclnn_repeat_interleave.h>
+#include <aclnnop/aclnn_roll.h>
+#include <aclnnop/aclnn_sin.h>
+#include <aclnnop/aclnn_softmax.h>
+#include <aclnnop/aclnn_tril.h>
+#include <aclnnop/aclnn_triu.h>
+#include <aclnnop/aclnn_upsample_nearest_2d.h>
+#include <aclnnop/aclnn_weight_quant_batch_matmul_v2.h>
+#include <float.h>
+
+#include <cmath>
+#include <cstring>
+#include <exception>
+#include <vector>
+
+#include "kernels/ascendc_kernels.h"
+
+#define GGML_COMMON_DECL_C
+
+#include "../ggml-common.h"
+
+/**
+ * @brief Repeats elements of a tensor along each dimension according to the
+ * specified repeat array.
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param acl_src The source tensor to be repeated.
+ * @param acl_dst The destination tensor after repeating.
+ * @param repeat_array The array specifying the number of repetitions along each
+ * dimension.
+ */
+static void aclnn_repeat(ggml_backend_cann_context& ctx, aclTensor* acl_src,
+                         aclTensor* acl_dst, int64_t* repeat_array) {
+    // repeat tensor along each dim with repeat_array
+    aclIntArray* repeats = aclCreateIntArray(repeat_array, GGML_MAX_DIMS);
+
+    uint64_t workspaceSize = 0;
+    aclOpExecutor* executor;
+    void* workspaceAddr = nullptr;
+
+    ACL_CHECK(aclnnRepeatGetWorkspaceSize(acl_src, repeats, acl_dst,
+                                          &workspaceSize, &executor));
+
+    if (workspaceSize > 0) {
+        // Memory from allocator will "free" immediately, and this memory
+        // will be alloced to other pointers, but it won't access before
+        // this async task end because all tasks in same stream will execute
+        // in queue.
+        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+        workspaceAddr = workspace_allocator.get();
+    }
+    ACL_CHECK(
+        aclnnRepeat(workspaceAddr, workspaceSize, executor, ctx.stream()));
+    ACL_CHECK(aclDestroyIntArray(repeats));
+}
+
+void ggml_cann_repeat(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+    ggml_tensor* src = dst->src[0];
+    GGML_ASSERT(ggml_can_repeat(src, dst));
+
+    aclTensor* acl_src = ggml_cann_create_tensor(src);
+    aclTensor* acl_dst = ggml_cann_create_tensor(dst);
+
+    int64_t repeatsArray[] = {dst->ne[3] / src->ne[3], dst->ne[2] / src->ne[2],
+                              dst->ne[1] / src->ne[1], dst->ne[0] / src->ne[0]};
+
+    aclnn_repeat(ctx, acl_src, acl_dst, repeatsArray);
+    ACL_CHECK(aclDestroyTensor(acl_src));
+    ACL_CHECK(aclDestroyTensor(acl_dst));
+}
+
+/**
+ * @brief Adds two tensors element-wise and stores the result in a destination
+ * tensor.
+ *
+ * This function performs the operation:
+ * \f[
+ *    dst = acl\_src0 + alpha \times acl\_src1
+ * \f]
+ * where alpha is a scalar value and defaults to 1.0f.
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param acl_src0 The first source tensor.
+ * @param acl_src1 The second source tensor.
+ * @param acl_dst The destination tensor where the result will be stored.
+ */
+static void aclnn_add(ggml_backend_cann_context& ctx, aclTensor* acl_src0,
+                      aclTensor* acl_src1, aclTensor* acl_dst) {
+    aclScalar* alpha = nullptr;
+    float alphaValue = 1.0f;
+    alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT);
+
+    uint64_t workspaceSize = 0;
+    aclOpExecutor* executor;
+    void* workspaceAddr = nullptr;
+
+    ACL_CHECK(aclnnAddGetWorkspaceSize(acl_src0, acl_src1, alpha, acl_dst,
+                                       &workspaceSize, &executor));
+    if (workspaceSize > 0) {
+        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+        workspaceAddr = workspace_allocator.get();
+    }
+
+    ACL_CHECK(aclnnAdd(workspaceAddr, workspaceSize, executor, ctx.stream()));
+
+    ACL_CHECK(aclDestroyScalar(alpha));
+}
+
+void ggml_cann_add(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+    ggml_tensor* src0 = dst->src[0];
+    ggml_tensor* src1 = dst->src[1];
+    GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
+
+    aclTensor* acl_src0;
+    aclTensor* acl_src1;
+    aclTensor* acl_dst;
+
+    // Need bcast
+    if (!ggml_are_same_shape(src0, src1) && ggml_cann_need_bcast(src0, src1)) {
+        BCAST_SHAPE(src0, src1)
+        acl_src0 = ggml_cann_create_tensor(src0, BCAST_PARAM(src0));
+        acl_src1 = ggml_cann_create_tensor(src1, BCAST_PARAM(src1));
+        acl_dst = ggml_cann_create_tensor(dst, BCAST_PARAM(src0));
+    } else {
+        acl_src0 = ggml_cann_create_tensor(src0);
+        acl_src1 = ggml_cann_create_tensor(src1);
+        acl_dst = ggml_cann_create_tensor(dst);
+    }
+
+    aclnn_add(ctx, acl_src0, acl_src1, acl_dst);
+
+    ACL_CHECK(aclDestroyTensor(acl_src0));
+    ACL_CHECK(aclDestroyTensor(acl_src1));
+    ACL_CHECK(aclDestroyTensor(acl_dst));
+}
+
+void ggml_cann_leaky_relu(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+    ggml_tensor* src = dst->src[0];
+
+    GGML_ASSERT(src->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    aclTensor* acl_src = ggml_cann_create_tensor(src);
+    aclTensor* acl_dst = ggml_cann_create_tensor(dst);
+
+    float negative_slope;
+    memcpy(&negative_slope, dst->op_params, sizeof(float));
+    aclScalar* acl_negative_slope =
+        aclCreateScalar(&negative_slope, aclDataType::ACL_FLOAT);
+
+    uint64_t workspaceSize = 0;
+    aclOpExecutor* executor;
+    void* workspaceAddr = nullptr;
+
+    ACL_CHECK(aclnnLeakyReluGetWorkspaceSize(
+        acl_src, acl_negative_slope, acl_dst, &workspaceSize, &executor));
+    if (workspaceSize > 0) {
+        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+        workspaceAddr = workspace_allocator.get();
+    }
+
+    ACL_CHECK(
+        aclnnLeakyRelu(workspaceAddr, workspaceSize, executor, ctx.stream()));
+
+    ACL_CHECK(aclDestroyScalar(acl_negative_slope));
+    ACL_CHECK(aclDestroyTensor(acl_src));
+    ACL_CHECK(aclDestroyTensor(acl_dst));
+}
+
+/**
+ * @brief Concatenates a list of tensors along a specified dimension and stores
+ * the result in a destination tensor.
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param tensorList The list of tensors to be concatenated.
+ * @param acl_dst The destination tensor where the concatenated result will be
+ * stored.
+ * @param concat_dim The dimension along which the tensors will be concatenated.
+ */
+static void aclnn_concat(ggml_backend_cann_context& ctx,
+                         aclTensorList* tensorList, aclTensor* acl_dst,
+                         int64_t concat_dim) {
+    uint64_t workspaceSize = 0;
+    aclOpExecutor* executor;
+    void* workspaceAddr = nullptr;
+
+    ACL_CHECK(aclnnCatGetWorkspaceSize(tensorList, concat_dim, acl_dst,
+                                       &workspaceSize, &executor));
+    if (workspaceSize > 0) {
+        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+        workspaceAddr = workspace_allocator.get();
+    }
+
+    ACL_CHECK(aclnnCat(workspaceAddr, workspaceSize, executor, ctx.stream()));
+}
+
+void ggml_cann_concat(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+    ggml_tensor* src0 = dst->src[0];
+    ggml_tensor* src1 = dst->src[1];
+    aclTensor* acl_src0 = ggml_cann_create_tensor(src0);
+    aclTensor* acl_src1 = ggml_cann_create_tensor(src1);
+    aclTensor* acl_dst = ggml_cann_create_tensor(dst);
+
+    int64_t concat_dim = 1;
+    aclTensor* tensors[] = {acl_src0, acl_src1};
+    aclTensorList* tensorList = aclCreateTensorList(tensors, 2);
+    aclnn_concat(ctx, tensorList, acl_dst, concat_dim);
+
+    ACL_CHECK(aclDestroyTensorList(tensorList));
+    ACL_CHECK(aclDestroyTensor(acl_dst));
+}
+
+/**
+ * @brief Creates a tensor with values starting from `start`, incremented by
+ * `step`, and ending before `stop`.
+ *
+ * This function performs the operation:
+ * \f[
+ *    \text {out }_{i+1}=\text {out }_i+\text {step}
+ * \f]
+ * the range is [start, stop).
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param acl_dst The destination tensor where the values will be stored.
+ * @param start The starting value of the range.
+ * @param stop The ending value of the range (exclusive).
+ * @param step The step size between consecutive values.
+ * @param n_elements The number of elements in the destination tensor.
+ */
+static void aclnn_arange(ggml_backend_cann_context& ctx, aclTensor* acl_dst,
+                         float start, float stop, float step,
+                         int64_t n_elements) {
+    int64_t steps = (int64_t)std::ceil((stop - start) / step);
+    GGML_ASSERT(n_elements == steps);
+
+    uint64_t workspaceSize = 0;
+    aclOpExecutor* executor;
+    void* workspaceAddr = nullptr;
+
+    aclScalar* acl_start = aclCreateScalar(&start, aclDataType::ACL_FLOAT);
+    aclScalar* acl_end = aclCreateScalar(&stop, aclDataType::ACL_FLOAT);
+    aclScalar* acl_step = aclCreateScalar(&step, aclDataType::ACL_FLOAT);
+
+    ACL_CHECK(aclnnArangeGetWorkspaceSize(acl_start, acl_end, acl_step, acl_dst,
+                                          &workspaceSize, &executor));
+    if (workspaceSize > 0) {
+        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+        workspaceAddr = workspace_allocator.get();
+    }
+
+    ACL_CHECK(
+        aclnnArange(workspaceAddr, workspaceSize, executor, ctx.stream()));
+
+    ACL_CHECK(aclDestroyScalar(acl_start));
+    ACL_CHECK(aclDestroyScalar(acl_end));
+    ACL_CHECK(aclDestroyScalar(acl_step));
+}
+
+void ggml_cann_arange(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    aclTensor* acl_dst = ggml_cann_create_tensor(dst);
+
+    int64_t n_elements = ggml_nelements(dst);
+    float start;
+    float stop;
+    float step;
+    memcpy(&start, (float*)dst->op_params + 0, sizeof(float));
+    memcpy(&stop, (float*)dst->op_params + 1, sizeof(float));
+    memcpy(&step, (float*)dst->op_params + 2, sizeof(float));
+
+    aclnn_arange(ctx, acl_dst, start, stop, step, n_elements);
+    ACL_CHECK(aclDestroyTensor(acl_dst));
+}
+
+void ggml_cann_sqr(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+    dst->src[1] = dst->src[0];
+    ggml_cann_mul_div<aclnnMulGetWorkspaceSize, aclnnMul>(ctx, dst);
+}
+
+void ggml_cann_clamp(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+    ggml_tensor* src = dst->src[0];
+    GGML_ASSERT(src->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    float min;
+    float max;
+    memcpy(&min, dst->op_params, sizeof(float));
+    memcpy(&max, (float*)dst->op_params + 1, sizeof(float));
+
+    aclTensor* acl_src = ggml_cann_create_tensor(src);
+    aclTensor* acl_dst = ggml_cann_create_tensor(dst);
+
+    aclScalar* acl_min = aclCreateScalar(&min, aclDataType::ACL_FLOAT);
+    aclScalar* acl_max = aclCreateScalar(&max, aclDataType::ACL_FLOAT);
+
+    uint64_t workspaceSize = 0;
+    aclOpExecutor* executor;
+    void* workspaceAddr = nullptr;
+
+    ACL_CHECK(aclnnClampGetWorkspaceSize(acl_src, acl_min, acl_max, acl_dst,
+                                         &workspaceSize, &executor));
+    if (workspaceSize > 0) {
+        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+        workspaceAddr = workspace_allocator.get();
+    }
+
+    ACL_CHECK(aclnnClamp(workspaceAddr, workspaceSize, executor, ctx.stream()));
+
+    ACL_CHECK(aclDestroyScalar(acl_min));
+    ACL_CHECK(aclDestroyScalar(acl_max));
+    ACL_CHECK(aclDestroyTensor(acl_src));
+    ACL_CHECK(aclDestroyTensor(acl_dst));
+}
+
+void ggml_cann_scale(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+    ggml_tensor* src = dst->src[0];
+
+    // scale factor
+    float v;
+    memcpy(&v, dst->op_params, sizeof(float));
+
+    aclScalar* scale = aclCreateScalar(&v, aclDataType::ACL_FLOAT);
+    aclTensor* acl_src = ggml_cann_create_tensor(src);
+    aclTensor* acl_dst = ggml_cann_create_tensor(dst);
+
+    uint64_t workspaceSize = 0;
+    aclOpExecutor* executor;
+    void* workspaceAddr = nullptr;
+
+    ACL_CHECK(aclnnMulsGetWorkspaceSize(acl_src, scale, acl_dst, &workspaceSize,
+                                        &executor));
+    if (workspaceSize > 0) {
+        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+        workspaceAddr = workspace_allocator.get();
+    }
+
+    ACL_CHECK(aclnnMuls(workspaceAddr, workspaceSize, executor, ctx.stream()));
+
+    ACL_CHECK(aclDestroyScalar(scale));
+    ACL_CHECK(aclDestroyTensor(acl_src));
+    ACL_CHECK(aclDestroyTensor(acl_dst));
+}
+
+void ggml_cann_argsort(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+    ggml_tensor* src = dst->src[0];
+    enum ggml_sort_order order = (enum ggml_sort_order)dst->op_params[0];
+
+    aclTensor* acl_src = ggml_cann_create_tensor(src);
+    aclTensor* acl_dst = ggml_cann_create_tensor(dst);
+    ggml_cann_pool_alloc temp_buffer_allocator(
+        ctx.pool(), ggml_nelements(dst) * sizeof(int64_t));
+    void* buffer = temp_buffer_allocator.get();
+    aclTensor* tmp_tensor =
+        ggml_cann_create_tensor(buffer, ACL_INT64, ggml_type_size(dst->type),
+                                dst->ne, dst->nb, GGML_MAX_DIMS);
+
+    uint64_t workspaceSize = 0;
+    aclOpExecutor* executor;
+    void* workspaceAddr = nullptr;
+
+    ACL_CHECK(aclnnArgsortGetWorkspaceSize(
+        acl_src, -1, (order == GGML_SORT_ORDER_DESC ? true : false), tmp_tensor,
+        &workspaceSize, &executor));
+    if (workspaceSize > 0) {
+        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+        workspaceAddr = workspace_allocator.get();
+    }
+
+    ACL_CHECK(
+        aclnnArgsort(workspaceAddr, workspaceSize, executor, ctx.stream()));
+
+    workspaceSize = 0;
+    ACL_CHECK(aclnnCastGetWorkspaceSize(tmp_tensor,
+                                        ggml_cann_type_mapping(dst->type),
+                                        acl_dst, &workspaceSize, &executor));
+    if (workspaceSize > 0) {
+        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+        workspaceAddr = workspace_allocator.get();
+    }
+
+    ACL_CHECK(aclnnCast(workspaceAddr, workspaceSize, executor, ctx.stream()));
+
+    ACL_CHECK(aclDestroyTensor(acl_src));
+    ACL_CHECK(aclDestroyTensor(tmp_tensor));
+    ACL_CHECK(aclDestroyTensor(acl_dst));
+}
+
+void ggml_cann_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+    ggml_tensor* src = dst->src[0];
+
+    aclTensor* acl_src = ggml_cann_create_tensor(src);
+    aclTensor* acl_dst = ggml_cann_create_tensor(dst);
+
+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
+
+    uint64_t workspaceSize = 0;
+    aclOpExecutor* executor;
+    void* workspaceAddr = nullptr;
+
+    std::vector<int64_t> normData = {dst->ne[0]};
+    aclIntArray* norm = aclCreateIntArray(normData.data(), normData.size());
+    ACL_CHECK(aclnnLayerNormGetWorkspaceSize(acl_src, norm, nullptr, nullptr,
+                                             eps, acl_dst, nullptr, nullptr,
+                                             &workspaceSize, &executor));
+
+    if (workspaceSize > 0) {
+        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+        workspaceAddr = workspace_allocator.get();
+    }
+
+    ACL_CHECK(
+        aclnnLayerNorm(workspaceAddr, workspaceSize, executor, ctx.stream()));
+
+    ACL_CHECK(aclDestroyIntArray(norm));
+    ACL_CHECK(aclDestroyTensor(acl_src));
+    ACL_CHECK(aclDestroyTensor(acl_dst));
+}
+
+void ggml_cann_group_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+    ggml_tensor* src = dst->src[0];
+
+    aclTensor* acl_src = ggml_cann_create_tensor(src);
+    aclTensor* acl_dst = ggml_cann_create_tensor(dst);
+
+    int n_groups = dst->op_params[0];
+
+    float eps;
+    memcpy(&eps, dst->op_params + 1, sizeof(float));
+
+    uint64_t workspaceSize = 0;
+    aclOpExecutor* executor;
+    void* workspaceAddr = nullptr;
+
+    int64_t N = src->ne[3];
+    int64_t C = src->ne[2];
+    int64_t HxW = src->ne[1] * src->ne[0];
+
+    size_t type_size = ggml_type_size(src->type);
+    int64_t ne[] = {n_groups, N};
+    size_t nb[] = {type_size, type_size * n_groups};
+    size_t n_bytes = N * n_groups;
+
+    ggml_cann_pool_alloc temp_buffer_allocator(ctx.pool(), n_bytes * 2);
+    void* buffer = temp_buffer_allocator.get();
+    aclTensor* acl_mean_out = ggml_cann_create_tensor(
+        buffer, ACL_FLOAT, type_size, ne, nb, ACL_FORMAT_ND);
+    aclTensor* acl_rstd_out = ggml_cann_create_tensor(
+        (char*)buffer + n_bytes, ACL_FLOAT, type_size, ne, nb, ACL_FORMAT_ND);
+
+    ACL_CHECK(aclnnGroupNormGetWorkspaceSize(
+        acl_src, nullptr, nullptr, N, C, HxW, n_groups, eps, acl_dst,
+        acl_mean_out, acl_rstd_out, &workspaceSize, &executor));
+
+    if (workspaceSize > 0) {
+        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+        workspaceAddr = workspace_allocator.get();
+    }
+
+    ACL_CHECK(
+        aclnnGroupNorm(workspaceAddr, workspaceSize, executor, ctx.stream()));
+
+    ACL_CHECK(aclDestroyTensor(acl_src));
+    ACL_CHECK(aclDestroyTensor(acl_dst));
+    ACL_CHECK(aclDestroyTensor(acl_mean_out));
+    ACL_CHECK(aclDestroyTensor(acl_rstd_out));
+}
+
+void ggml_cann_acc(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+    ggml_tensor* src0 = dst->src[0];
+    ggml_tensor* src1 = dst->src[1];
+
+    size_t nb1 = ((int32_t*)dst->op_params)[0];
+    size_t nb2 = ((int32_t*)dst->op_params)[1];
+    size_t nb3 = ((int32_t*)dst->op_params)[2];
+    size_t offset = ((int32_t*)dst->op_params)[3];
+    bool inplace = (bool)((int32_t*)dst->op_params)[4];
+
+    size_t param_nb[] = {ggml_element_size(src0), nb1, nb2, nb3};
+
+    aclTensor* acl_dst = ggml_cann_create_tensor(
+        dst, src1->ne, param_nb, GGML_MAX_DIMS, ACL_FORMAT_ND, offset);
+    aclTensor* acl_src1 = ggml_cann_create_tensor(src1);
+
+    aclScalar* alpha = nullptr;
+    float alphaValue = 1.0f;
+    alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT);
+
+    uint64_t workspaceSize = 0;
+    aclOpExecutor* executor;
+    void* workspaceAddr = nullptr;
+
+    if (!inplace) {
+        size_t cpy_size = ggml_nbytes(dst);
+        ACL_CHECK(aclrtMemcpyAsync(dst->data, cpy_size, src0->data, cpy_size,
+                                   ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
+        aclTensor* acl_src0 = ggml_cann_create_tensor(
+            src0, src1->ne, src0->nb, GGML_MAX_DIMS, ACL_FORMAT_ND, offset);
+        ACL_CHECK(aclnnAddGetWorkspaceSize(acl_src0, acl_src1, alpha, acl_dst,
+                                           &workspaceSize, &executor));
+        if (workspaceSize > 0) {
+            ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+            workspaceAddr = workspace_allocator.get();
+        }
+        ACL_CHECK(
+            aclnnAdd(workspaceAddr, workspaceSize, executor, ctx.stream()));
+        ACL_CHECK(aclDestroyTensor(acl_src0));
+    } else {
+        ACL_CHECK(aclnnInplaceAddGetWorkspaceSize(acl_dst, acl_src1, alpha,
+                                                  &workspaceSize, &executor));
+        if (workspaceSize > 0) {
+            ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+            workspaceAddr = workspace_allocator.get();
+        }
+        ACL_CHECK(aclnnInplaceAdd(workspaceAddr, workspaceSize, executor,
+                                  ctx.stream()));
+    }
+
+    ACL_CHECK(aclDestroyTensor(acl_src1));
+    ACL_CHECK(aclDestroyTensor(acl_dst));
+}
+
+void ggml_cann_sum_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+    ggml_tensor* src = dst->src[0];
+
+    aclTensor* acl_src = ggml_cann_create_tensor(src);
+
+    GGML_ASSERT(dst->ne[0] == 1);
+    aclTensor* acl_dst = ggml_cann_create_tensor(dst);
+
+    int64_t reduce_dims_host[] = {3};
+    aclIntArray* reduce_dims = aclCreateIntArray(reduce_dims_host, 1);
+
+    uint64_t workspaceSize = 0;
+    aclOpExecutor* executor;
+    void* workspaceAddr = nullptr;
+
+    ACL_CHECK(aclnnReduceSumGetWorkspaceSize(
+        acl_src, reduce_dims, true, ggml_cann_type_mapping(src->type), acl_dst,
+        &workspaceSize, &executor));
+    if (workspaceSize > 0) {
+        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+        workspaceAddr = workspace_allocator.get();
+    }
+
+    ACL_CHECK(
+        aclnnReduceSum(workspaceAddr, workspaceSize, executor, ctx.stream()));
+
+    ACL_CHECK(aclDestroyTensor(acl_src));
+    ACL_CHECK(aclDestroyTensor(acl_dst));
+}
+
+void ggml_cann_upsample_nearest2d(ggml_backend_cann_context& ctx,
+                                  ggml_tensor* dst) {
+    ggml_tensor* src = dst->src[0];
+    aclTensor* acl_src =
+        ggml_cann_create_tensor(src, nullptr, nullptr, 0, ACL_FORMAT_NCHW);
+    aclTensor* acl_dst =
+        ggml_cann_create_tensor(dst, nullptr, nullptr, 0, ACL_FORMAT_NCHW);
+
+    std::vector<int64_t> output_size{dst->ne[1], dst->ne[0]};
+    auto output_size_array = aclCreateIntArray(output_size.data(), 2);
+
+    uint64_t workspaceSize = 0;
+    aclOpExecutor* executor;
+    void* workspaceAddr = nullptr;
+
+    ACL_CHECK(aclnnUpsampleNearest2dGetWorkspaceSize(
+        acl_src, output_size_array, acl_dst, &workspaceSize, &executor));
+    if (workspaceSize > 0) {
+        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+        workspaceAddr = workspace_allocator.get();
+    }
+
+    ACL_CHECK(aclnnUpsampleNearest2d(workspaceAddr, workspaceSize, executor,
+                                     ctx.stream()));
+
+    ACL_CHECK(aclDestroyIntArray(output_size_array));
+    ACL_CHECK(aclDestroyTensor(acl_src));
+    ACL_CHECK(aclDestroyTensor(acl_dst));
+}
+
+/**
+ * @brief Pads a tensor with a specified value along each dimension.
+ *
+ * This function performs padding of the source tensor `acl_src` and stores the
+ * result in the destination tensor `acl_dst`. The padding values for each
+ * dimension are specified in the `paddings` array.
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param acl_src The source tensor to be padded.
+ * @param acl_dst The destination tensor where the padded result will be stored.
+ * @param paddings An array specifying the padding values for each dimension.
+ * The size of the array should be twice the number of dimensions of the tensor.
+ * @param value The value to be used for padding. The default value is 0.0.
+ */
+static void aclnn_pad(ggml_backend_cann_context& ctx, aclTensor* acl_src,
+                      aclTensor* acl_dst, int64_t* paddings,
+                      float value = 0.0f) {
+    aclIntArray* acl_pad = aclCreateIntArray(paddings, GGML_MAX_DIMS * 2);
+    aclScalar* acl_value = aclCreateScalar(&value, aclDataType::ACL_FLOAT);
+
+    uint64_t workspaceSize = 0;
+    aclOpExecutor* executor;
+    void* workspaceAddr = nullptr;
+
+    ACL_CHECK(aclnnConstantPadNdGetWorkspaceSize(
+        acl_src, acl_pad, acl_value, acl_dst, &workspaceSize, &executor));
+
+    if (workspaceSize > 0) {
+        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+        workspaceAddr = workspace_allocator.get();
+    }
+
+    ACL_CHECK(aclnnConstantPadNd(workspaceAddr, workspaceSize, executor,
+                                 ctx.stream()));
+
+    ACL_CHECK(aclDestroyIntArray(acl_pad));
+    ACL_CHECK(aclDestroyScalar(acl_value));
+}
+
+void ggml_cann_pad(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+    ggml_tensor* src = dst->src[0];
+    aclTensor* acl_src = ggml_cann_create_tensor(src);
+    aclTensor* acl_dst = ggml_cann_create_tensor(dst);
+
+    // padding: value in the array means how much distance will be padding.
+    // the position of elements in the array means which dirction to padding,
+    // each position means: [dim0.front, dim0.behind, dim1.front, dim1.behind,
+    //                       dim2.front, dim2.behind, dim3.front, dim3.behind]
+    int64_t paddings[] = {
+        0, dst->ne[0] - src->ne[0], 0, dst->ne[1] - src->ne[1],
+        0, dst->ne[2] - src->ne[2], 0, dst->ne[3] - src->ne[3]};
+    aclnn_pad(ctx, acl_src, acl_dst, paddings);
+
+    ACL_CHECK(aclDestroyTensor(acl_dst));
+    ACL_CHECK(aclDestroyTensor(acl_src));
+}
+
+/**
+ * @brief Performs 2D average pooling on the input tensor and stores the result
+ * in the destination tensor.
+ *
+ * This function performs average pooling on the source tensor and stores the
+ * result in the destination tensor. The pooling parameters (kernel size,
+ * strides, padding) are specified in the `op_params` of the destination tensor.
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param dst The destination tensor where the result will be stored. The source
+ * tensor is referenced by `dst->src[0]`.
+ */
+static void ggml_cann_avg_pool2d(ggml_backend_cann_context& ctx,
+                                 ggml_tensor* dst) {
+    ggml_tensor* src = dst->src[0];
+    GGML_ASSERT(src->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    aclTensor* acl_src =
+        ggml_cann_create_tensor(src, nullptr, nullptr, 0, ACL_FORMAT_NCHW);
+    aclTensor* acl_dst =
+        ggml_cann_create_tensor(dst, nullptr, nullptr, 0, ACL_FORMAT_NCHW);
+
+    const int32_t* opts = (const int32_t*)dst->op_params;
+    const int k0 = opts[1];
+    const int k1 = opts[2];
+    const int s0 = opts[3];
+    const int s1 = opts[4];
+    const int p0 = opts[5];
+    const int p1 = opts[6];
+
+    std::vector<int64_t> kernel_dims = {k1, k0};
+    std::vector<int64_t> stride_dims = {s1, s0};
+    std::vector<int64_t> padding_avg_dims = {p1, p0};  // (padH, padW)
+
+    auto* kernel_size = aclCreateIntArray(kernel_dims.data(), 2);
+    auto* strides = aclCreateIntArray(stride_dims.data(), 2);
+    auto* paddings_avg = aclCreateIntArray(padding_avg_dims.data(), 2);
+
+    bool ceil_mode = false;
+    bool count_include_pad = true;
+    int64_t divisor_override = 0;
+    int8_t cube_math_type = 0;
+
+    uint64_t workspaceSize = 0;
+    aclOpExecutor* executor;
+    void* workspaceAddr = nullptr;
+
+    ACL_CHECK(aclnnAvgPool2dGetWorkspaceSize(
+        acl_src, kernel_size, strides, paddings_avg, ceil_mode,
+        count_include_pad, divisor_override, cube_math_type, acl_dst,
+        &workspaceSize, &executor));
+
+    if (workspaceSize > 0) {
+        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+        workspaceAddr = workspace_allocator.get();
+    }
+    ACL_CHECK(
+        aclnnAvgPool2d(workspaceAddr, workspaceSize, executor, ctx.stream()));
+
+    ACL_CHECK(aclDestroyTensor(acl_src));
+    ACL_CHECK(aclDestroyTensor(acl_dst));
+    ACL_CHECK(aclDestroyIntArray(kernel_size));
+    ACL_CHECK(aclDestroyIntArray(strides));
+    ACL_CHECK(aclDestroyIntArray(paddings_avg));
+}
+
+/**
+ * @brief Performs 2D max pooling on the input tensor and stores the result in
+ * the destination tensor.
+ *
+ * This function performs max pooling on the source tensor and stores the result
+ * in the destination tensor. The pooling parameters (kernel size, strides,
+ * padding) are specified in the `op_params` of the destination tensor.
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param dst The destination tensor where the result will be stored. The source
+ * tensor is referenced by `dst->src[0]`.
+ */
+static void ggml_cann_max_pool2d(ggml_backend_cann_context& ctx,
+                                 ggml_tensor* dst) {
+    ggml_tensor* src = dst->src[0];
+    GGML_ASSERT(src->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    aclTensor* acl_src =
+        ggml_cann_create_tensor(src, nullptr, nullptr, 0, ACL_FORMAT_NCHW);
+    aclTensor* acl_dst =
+        ggml_cann_create_tensor(dst, nullptr, nullptr, 0, ACL_FORMAT_NCHW);
+
+    const int32_t* opts = (const int32_t*)dst->op_params;
+    const int k0 = opts[1];
+    const int k1 = opts[2];
+    const int s0 = opts[3];
+    const int s1 = opts[4];
+    const int p0 = opts[5];
+    const int p1 = opts[6];
+
+    int64_t temp_ne[] = {src->ne[0] + p0 * 2, src->ne[1] + p1 * 2, src->ne[2],
+                         src->ne[3]};
+    size_t temp_nb[GGML_MAX_DIMS];
+
+    temp_nb[0] = ggml_element_size(src);
+    for (int i = 1; i < GGML_MAX_DIMS; i++) {
+        temp_nb[i] = temp_nb[i - 1] * temp_ne[i - 1];
+    }
+
+    ggml_cann_pool_alloc temp_buffer_allocator(
+        ctx.pool(), ggml_nbytes(src) + p0 * 2 + p1 * 2 * src->nb[1]);
+    void* buffer = temp_buffer_allocator.get();
+    aclTensor* tmp_tensor = ggml_cann_create_tensor(
+        buffer, ACL_FLOAT, ggml_element_size(src), temp_ne, temp_nb,
+        GGML_MAX_DIMS, ACL_FORMAT_NCHW);
+
+    // pad: see padding in ggml_cann_pad()
+    int64_t paddings[] = {p0, p0, p1, p1, 0, 0, 0, 0};
+    float value = -FLT_MAX;
+    aclnn_pad(ctx, acl_src, tmp_tensor, paddings, value);
+
+    // max_pool
+    std::vector<int64_t> kernel_dims = {k1, k0};
+    std::vector<int64_t> stride_dims = {s1, s0};
+    // padding_max_dims: [dim0_start, dim0_end, dim1_start, dim1_end]
+    std::vector<int64_t> padding_max_dims = {0, 0, 0, 0};
+    std::vector<int64_t> dilation_size = {1, 1};
+    auto* kernel_size = aclCreateIntArray(kernel_dims.data(), 2);
+    auto* strides = aclCreateIntArray(stride_dims.data(), 2);
+    auto* paddings_max = aclCreateIntArray(padding_max_dims.data(), 4);
+    auto* dilations = aclCreateIntArray(dilation_size.data(), 2);
+
+    bool ceil_mode = false;
+    int64_t auto_pads = 0;
+
+    uint64_t workspaceSize = 0;
+    aclOpExecutor* executor;
+    void* workspaceAddr = nullptr;
+
+    ACL_CHECK(aclnnMaxPoolGetWorkspaceSize(
+        tmp_tensor, kernel_size, strides, auto_pads, paddings_max, dilations,
+        ceil_mode, acl_dst, &workspaceSize, &executor));
+    if (workspaceSize > 0) {
+        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+        workspaceAddr = workspace_allocator.get();
+    }
+
+    ACL_CHECK(
+        aclnnMaxPool(workspaceAddr, workspaceSize, executor, ctx.stream()));
+
+    ACL_CHECK(aclDestroyTensor(acl_src));
+    ACL_CHECK(aclDestroyTensor(acl_dst));
+    ACL_CHECK(aclDestroyTensor(tmp_tensor));
+    ACL_CHECK(aclDestroyIntArray(kernel_size));
+    ACL_CHECK(aclDestroyIntArray(strides));
+    ACL_CHECK(aclDestroyIntArray(paddings_max));
+    ACL_CHECK(aclDestroyIntArray(dilations));
+}
+
+void ggml_cann_pool2d(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+    const int32_t* opts = (const int32_t*)dst->op_params;
+    enum ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
+    switch (op) {
+        case GGML_OP_POOL_AVG:
+            ggml_cann_avg_pool2d(ctx, dst);
+            break;
+        case GGML_OP_POOL_MAX:
+            ggml_cann_max_pool2d(ctx, dst);
+            break;
+        case GGML_OP_POOL_COUNT:
+            GGML_ABORT("fatal error");
+            break;
+    }
+}
+
+/**
+ * @brief Copies data from the source tensor to the destination tensor.
+ *
+ * This function copies data from the source tensor `acl_src` to the destination
+ * tensor `acl_dst`.
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param acl_src The source tensor from which data will be copied.
+ * @param acl_dst The destination tensor where the data will be copied to.
+ */
+static void cann_copy(ggml_backend_cann_context& ctx, aclTensor* acl_src,
+                      aclTensor* acl_dst) {
+    uint64_t workspaceSize = 0;
+    aclOpExecutor* executor;
+    void* workspaceAddr = nullptr;
+
+    ACL_CHECK(aclnnInplaceCopyGetWorkspaceSize(acl_dst, acl_src, &workspaceSize,
+                                               &executor));
+
+    if (workspaceSize > 0) {
+        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+        workspaceAddr = workspace_allocator.get();
+    }
+
+    ACL_CHECK(
+        aclnnInplaceCopy(workspaceAddr, workspaceSize, executor, ctx.stream()));
+}
+
+void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+    ggml_tensor* src = dst->src[0];
+
+    aclTensor* acl_src = ggml_cann_create_tensor(src);
+    aclTensor* acl_dst = ggml_cann_create_tensor(dst);
+
+    ggml_cann_pool_alloc src_extra_allocator(ctx.pool(), sizeof(ggml_tensor));
+    ggml_cann_pool_alloc dst_extra_allocator(ctx.pool(), sizeof(ggml_tensor));
+    src->extra = src_extra_allocator.get();
+    dst->extra = dst_extra_allocator.get();
+    ACL_CHECK(aclrtMemcpyAsync(src->extra, sizeof(ggml_tensor), src,
+                               sizeof(ggml_tensor), ACL_MEMCPY_HOST_TO_DEVICE,
+                               ctx.stream()));
+    ACL_CHECK(aclrtMemcpyAsync(dst->extra, sizeof(ggml_tensor), dst,
+                               sizeof(ggml_tensor), ACL_MEMCPY_HOST_TO_DEVICE,
+                               ctx.stream()));
+
+    if ((dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32) &&
+        ggml_are_same_shape(src, dst)) {
+        cann_copy(ctx, acl_src, acl_dst);
+        ACL_CHECK(aclDestroyTensor(acl_src));
+        ACL_CHECK(aclDestroyTensor(acl_dst));
+        return;
+    }
+    // TODO: simplify
+    if (src->type == GGML_TYPE_F16) {
+        if (dst->type == GGML_TYPE_Q8_0) {
+            aclrtlaunch_ascendc_quantize_f16_q8_0(
+                24, ctx.stream(), src->data, dst->data,
+                ((ggml_tensor*)src->extra)->ne, ((ggml_tensor*)src->extra)->nb,
+                ((ggml_tensor*)dst->extra)->ne);
+            return;
+        }
+        if (dst->type == GGML_TYPE_Q4_0) {
+            aclrtlaunch_ascendc_quantize_f16_to_q4_0(
+                24, ctx.stream(), src->data, dst->data,
+                ((ggml_tensor*)src->extra)->ne, ((ggml_tensor*)src->extra)->nb,
+                ((ggml_tensor*)dst->extra)->ne);
+            return;
+        }
+        if (dst->type == GGML_TYPE_F16) {
+            if (ggml_are_same_shape(src, dst)) {
+                cann_copy(ctx, acl_src, acl_dst);
+                ACL_CHECK(aclDestroyTensor(acl_src));
+                ACL_CHECK(aclDestroyTensor(acl_dst));
+                return;
+            }
+            if (ggml_is_contiguous(dst)) {
+                const size_t src_type_size = ggml_type_size(src->type);
+                if (src->nb[0] == src_type_size) {
+                    // src0 is contigous on first dimension, copy by rows
+                    int64_t rows_num = ggml_nrows(src);
+
+                    aclrtlaunch_ascendc_dup_by_rows_fp16(
+                        rows_num, ctx.stream(), src->data, dst->data,
+                        ((ggml_tensor*)src->extra)->ne,
+                        ((ggml_tensor*)src->extra)->nb,
+                        ((ggml_tensor*)dst->extra)->ne,
+                        ((ggml_tensor*)dst->extra)->nb);
+                    return;
+                }
+                GGML_ABORT("fatal error");
+            }
+            GGML_ABORT("fatal error");
+        }
+        if (dst->type == GGML_TYPE_F32) {
+            if (ggml_are_same_shape(src, dst)) {
+                cann_copy(ctx, acl_src, acl_dst);
+                ACL_CHECK(aclDestroyTensor(acl_src));
+                ACL_CHECK(aclDestroyTensor(acl_dst));
+                return;
+            }
+            if (ggml_is_contiguous(dst)) {
+                const size_t src_type_size = ggml_type_size(src->type);
+                if (src->nb[0] == src_type_size) {
+                    // src0 is contigous on first dimension, copy by rows
+                    int64_t rows_num = ggml_nrows(src);
+                    aclrtlaunch_ascendc_dup_by_rows_fp16_to_fp32(
+                        rows_num, ctx.stream(), src->data, dst->data,
+                        ((ggml_tensor*)src->extra)->ne,
+                        ((ggml_tensor*)src->extra)->nb,
+                        ((ggml_tensor*)dst->extra)->ne,
+                        ((ggml_tensor*)dst->extra)->nb);
+                    return;
+                }
+                GGML_ABORT("fatal error");
+            }
+            GGML_ABORT("fatal error");
+        }
+        // TODO
+        GGML_ABORT("fatal error");
+    } else if (src->type == GGML_TYPE_F32) {
+        // TODO: if (src0->type == dst->type && ne00 == ne0 && nb00 == type_size
+        //          && nb0 == type_size)
+        if (dst->type == GGML_TYPE_Q8_0) {
+            aclrtlaunch_ascendc_quantize_f32_q8_0(
+                24, ctx.stream(), src->data, dst->data,
+                ((ggml_tensor*)src->extra)->ne, ((ggml_tensor*)src->extra)->nb,
+                ((ggml_tensor*)dst->extra)->ne);
+            return;
+        }
+        if (dst->type == GGML_TYPE_Q4_0) {
+            aclrtlaunch_ascendc_quantize_f32_to_q4_0(
+                24, ctx.stream(), src->data, dst->data,
+                ((ggml_tensor*)src->extra)->ne, ((ggml_tensor*)src->extra)->nb,
+                ((ggml_tensor*)dst->extra)->ne);
+            return;
+        }
+        if (dst->type == GGML_TYPE_F32) {
+            if (ggml_are_same_shape(src, dst)) {
+                cann_copy(ctx, acl_src, acl_dst);
+                ACL_CHECK(aclDestroyTensor(acl_src));
+                ACL_CHECK(aclDestroyTensor(acl_dst));
+                return;
+            }
+            if (ggml_is_contiguous(dst)) {
+                const size_t src_type_size = ggml_type_size(src->type);
+                if (src->nb[0] == src_type_size) {
+                    // src0 is contigous on first dimension, copy by rows
+                    int64_t rows_num = ggml_nrows(src);
+                    aclrtlaunch_ascendc_dup_by_rows_fp32(
+                        rows_num, ctx.stream(), src->data, dst->data,
+                        ((ggml_tensor*)src->extra)->ne,
+                        ((ggml_tensor*)src->extra)->nb,
+                        ((ggml_tensor*)dst->extra)->ne,
+                        ((ggml_tensor*)dst->extra)->nb);
+                    return;
+                }
+                GGML_ABORT("fatal error");
+            } else {
+                // TODO: dst not contiguous
+                GGML_ABORT("fatal error");
+            }
+        }
+        if (dst->type == GGML_TYPE_F16) {
+            if (ggml_are_same_shape(src, dst)) {
+                cann_copy(ctx, acl_src, acl_dst);
+                ACL_CHECK(aclDestroyTensor(acl_src));
+                ACL_CHECK(aclDestroyTensor(acl_dst));
+                return;
+            }
+            if (ggml_is_contiguous(dst)) {
+                const size_t src_type_size = ggml_type_size(src->type);
+                if (src->nb[0] == src_type_size) {
+                    // src0 is contigous on first dimension, copy by rows
+                    int64_t rows_num = ggml_nrows(src);
+                    aclrtlaunch_ascendc_dup_by_rows_fp32_to_fp16(
+                        rows_num, ctx.stream(), src->data, dst->data,
+                        ((ggml_tensor*)src->extra)->ne,
+                        ((ggml_tensor*)src->extra)->nb,
+                        ((ggml_tensor*)dst->extra)->ne,
+                        ((ggml_tensor*)dst->extra)->nb);
+                    return;
+                }
+                GGML_ABORT("fatal error");
+            }
+        }
+        // TODO
+        GGML_ABORT("fatal error");
+    } else {
+        if (ggml_are_same_shape(src, dst)) {
+            cann_copy(ctx, acl_src, acl_dst);
+            ACL_CHECK(aclDestroyTensor(acl_src));
+            ACL_CHECK(aclDestroyTensor(acl_dst));
+            return;
+        }
+        GGML_ABORT("fatal error");
+    }
+}
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+aclnnStatus aclnnRmsNormGetWorkspaceSize(const aclTensor* x,
+                                         const aclTensor* gamma, double epsilon,
+                                         const aclTensor* yOut,
+                                         const aclTensor* rstdOout,
+                                         uint64_t* workspaceSize,
+                                         aclOpExecutor** executor);
+aclnnStatus aclnnRmsNorm(void* workspace, uint64_t workspaceSize,
+                         aclOpExecutor* executor, aclrtStream stream);
+#ifdef __cplusplus
+}
+#endif
+
+/**
+ * @brief Creates an ACL tensor initialized with zeros using a provided buffer.
+ *
+ * This function initializes a tensor with zeros using the specified buffer and
+ * tensor parameters.
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param buffer The buffer to be used for the tensor data.
+ * @param n_bytes The size of the buffer in bytes.
+ * @param ne An array specifying the extents (sizes) of each dimension of the
+ * tensor.
+ * @param dims The number of dimensions of the tensor.
+ * @param type The data type of the tensor.
+ * @param type_size The size of each element in the tensor data type.
+ * @return An ACL tensor initialized with zeros.
+ */
+static aclTensor* aclnn_zero(ggml_backend_cann_context& ctx, void* buffer,
+                             size_t n_bytes, int64_t* ne, int64_t dims,
+                             aclDataType type, size_t type_size) {
+    size_t nb[GGML_MAX_DIMS];
+    nb[0] = type_size;
+    for (int i = 1; i < dims; i++) {
+        nb[i] = nb[i - 1] * ne[i - 1];
+    }
+
+    ACL_CHECK(aclrtMemsetAsync(buffer, n_bytes, 0, n_bytes, ctx.stream()));
+    aclTensor* zero =
+        ggml_cann_create_tensor(buffer, type, type_size, ne, nb, dims);
+    return zero;
+}
+
+/**
+ * @brief Creates an ACL tensor initialized with ones using a provided buffer.
+ *
+ * This function initializes a tensor with ones using the specified buffer and
+ * tensor parameters.
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param buffer The buffer to be used for the tensor data.
+ * @param n_bytes The size of the buffer in bytes.
+ * @param ne An array specifying the extents (sizes) of each dimension of the
+ * tensor.
+ * @param dims The number of dimensions of the tensor.
+ * @param type The data type of the tensor.
+ * @param type_size The size of each element in the tensor data type.
+ * @param value The value to be used for initializing the tensor (default
+ * is 1.0).
+ * @return An ACL tensor initialized with ones.
+ */
+static aclTensor* aclnn_ones(ggml_backend_cann_context& ctx, void* buffer,
+                             size_t n_bytes, int64_t* ne, int64_t dims,
+                             aclDataType type, size_t type_size,
+                             float value = 1.0f) {
+    aclTensor* acl_tensor =
+        aclnn_zero(ctx, buffer, n_bytes, ne, dims, type, type_size);
+    float alpha_host = 1.0f;
+    aclScalar* alpha = aclCreateScalar(&alpha_host, aclDataType::ACL_FLOAT);
+    aclScalar* other = aclCreateScalar(&value, aclDataType::ACL_FLOAT);
+
+    uint64_t workspaceSize = 0;
+    aclOpExecutor* executor;
+    void* workspaceAddr = nullptr;
+
+    ACL_CHECK(aclnnInplaceAddsGetWorkspaceSize(acl_tensor, other, alpha,
+                                               &workspaceSize, &executor));
+
+    if (workspaceSize > 0) {
+        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+        workspaceAddr = workspace_allocator.get();
+    }
+    ACL_CHECK(
+        aclnnInplaceAdds(workspaceAddr, workspaceSize, executor, ctx.stream()));
+
+    return acl_tensor;
+}
+
+void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+    ggml_tensor* src = dst->src[0];
+
+    aclTensor* acl_src = ggml_cann_create_tensor(src);
+    aclTensor* acl_dst = ggml_cann_create_tensor(dst);
+
+    float eps;
+    memcpy(&eps, dst->op_params, sizeof(float));
+
+    GGML_ASSERT(eps > 0.0f);
+
+    uint64_t workspaceSize = 0;
+    aclOpExecutor* executor;
+    void* workspaceAddr = nullptr;
+
+    size_t one_tensor_n_bytes = src->ne[0] * ggml_element_size(src);
+    ggml_cann_pool_alloc one_tensor_allocator(ctx.pool(), one_tensor_n_bytes);
+
+    aclTensor* acl_gamma = aclnn_ones(
+        ctx, one_tensor_allocator.get(), one_tensor_n_bytes, src->ne, 1,
+        ggml_cann_type_mapping(src->type), ggml_element_size(src));
+
+    size_t zero_tensor_n_bytes =
+        src->ne[1] * src->ne[2] * src->ne[3] * ggml_element_size(src);
+    ggml_cann_pool_alloc zero_tensor_allocator(ctx.pool(), zero_tensor_n_bytes);
+    aclTensor* acl_rstd =
+        aclnn_zero(ctx, zero_tensor_allocator.get(), zero_tensor_n_bytes,
+                   src->ne, GGML_MAX_DIMS, ggml_cann_type_mapping(src->type),
+                   ggml_element_size(src));
+
+    ACL_CHECK(aclnnRmsNormGetWorkspaceSize(
+        acl_src, acl_gamma, eps, acl_dst, acl_rstd, &workspaceSize, &executor));
+
+    if (workspaceSize > 0) {
+        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+        workspaceAddr = workspace_allocator.get();
+    }
+
+    ACL_CHECK(
+        aclnnRmsNorm(workspaceAddr, workspaceSize, executor, ctx.stream()));
+
+    ACL_CHECK(aclDestroyTensor(acl_src));
+    ACL_CHECK(aclDestroyTensor(acl_dst));
+    ACL_CHECK(aclDestroyTensor(acl_gamma));
+    ACL_CHECK(aclDestroyTensor(acl_rstd));
+}
+
+// TODO: performace is low.
+void ggml_cann_diag_mask(ggml_backend_cann_context& ctx, ggml_tensor* dst,
+                         float value) {
+    ggml_tensor* src = dst->src[0];
+
+    aclTensor* acl_src = ggml_cann_create_tensor(src);
+    aclTensor* acl_dst = ggml_cann_create_tensor(dst);
+
+    const int n_past = ((int32_t*)dst->op_params)[0];
+
+    size_t one_tensor_n_bytes = src->ne[0] * src->ne[1] * src->ne[2] *
+                                src->ne[3] * ggml_element_size(src);
+    ggml_cann_pool_alloc one_tensor_allocator(ctx.pool(), one_tensor_n_bytes);
+
+    aclTensor* mask_tensor =
+        aclnn_ones(ctx, one_tensor_allocator.get(), one_tensor_n_bytes, src->ne,
+                   GGML_MAX_DIMS, ggml_cann_type_mapping(src->type),
+                   ggml_element_size(src), value);
+
+    uint64_t workspaceSize = 0;
+    aclOpExecutor* executor;
+    void* workspaceAddr = nullptr;
+
+    ACL_CHECK(aclnnInplaceTriuGetWorkspaceSize(mask_tensor, n_past + 1,
+                                               &workspaceSize, &executor));
+    if (workspaceSize > 0) {
+        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+        workspaceAddr = workspace_allocator.get();
+    }
+
+    ACL_CHECK(
+        aclnnInplaceTriu(workspaceAddr, workspaceSize, executor, ctx.stream()));
+
+    ACL_CHECK(aclnnTrilGetWorkspaceSize(acl_src, n_past + 1, acl_dst,
+                                        &workspaceSize, &executor));
+    if (workspaceSize > 0) {
+        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+        workspaceAddr = workspace_allocator.get();
+    }
+
+    ACL_CHECK(aclnnTril(workspaceAddr, workspaceSize, executor, ctx.stream()));
+
+    aclScalar* alpha = nullptr;
+    float alphaValue = 1.0f;
+    alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT);
+
+    ACL_CHECK(aclnnInplaceAddGetWorkspaceSize(acl_dst, mask_tensor, alpha,
+                                              &workspaceSize, &executor));
+    if (workspaceSize > 0) {
+        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+        workspaceAddr = workspace_allocator.get();
+    }
+    ACL_CHECK(
+        aclnnInplaceAdd(workspaceAddr, workspaceSize, executor, ctx.stream()));
+
+    ACL_CHECK(aclDestroyScalar(alpha));
+    ACL_CHECK(aclDestroyTensor(mask_tensor));
+    ACL_CHECK(aclDestroyTensor(acl_src));
+    ACL_CHECK(aclDestroyTensor(acl_dst));
+}
+
+/**
+ * @brief Casts the data type of a source tensor to a destination tensor.
+ *
+ * This function casts the data type of the source tensor `acl_src` to the
+ * specified data type `cast_data_type` and stores the result in the destination
+ * tensor `acl_dst`.
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param acl_src The source tensor whose data type will be casted.
+ * @param acl_dst The destination tensor where the casted result will be stored.
+ * @param cast_data_type The target data type to which the source tensor will be
+ * casted.
+ */
+static void aclnn_cast(ggml_backend_cann_context& ctx, aclTensor* acl_src,
+                       aclTensor* acl_dst, aclDataType cast_data_type) {
+    uint64_t workspaceSize = 0;
+    aclOpExecutor* executor;
+    void* workspaceAddr = nullptr;
+
+    ACL_CHECK(aclnnCastGetWorkspaceSize(acl_src, cast_data_type, acl_dst,
+                                        &workspaceSize, &executor));
+    if (workspaceSize > 0) {
+        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+        workspaceAddr = workspace_allocator.get();
+    }
+
+    ACL_CHECK(aclnnCast(workspaceAddr, workspaceSize, executor, ctx.stream()));
+}
+
+/**
+ * @brief Permutes the dimensions of a tensor according to a specified order.
+ *
+ * This function permutes the dimensions of the source tensor `acl_src`
+ * according to the order specified in the `new_dim` array and stores the result
+ * in the destination tensor `acl_dst`.
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param acl_src The source tensor whose dimensions will be permuted.
+ * @param acl_dst The destination tensor where the permuted result will be
+ * stored.
+ * @param new_dim An array specifying the new order of dimensions for the
+ * tensor.
+ * @param dims The number of dimensions in the tensor.
+ */
+static void aclnn_permute(ggml_backend_cann_context& ctx, aclTensor* acl_src,
+                          aclTensor* acl_dst, int64_t* new_dim, uint64_t dims) {
+    aclIntArray* acl_dims = aclCreateIntArray(new_dim, dims);
+
+    uint64_t workspaceSize = 0;
+    aclOpExecutor* executor;
+    void* workspaceAddr = nullptr;
+
+    ACL_CHECK(aclnnPermuteGetWorkspaceSize(acl_src, acl_dims, acl_dst,
+                                           &workspaceSize, &executor));
+    if (workspaceSize > 0) {
+        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+        workspaceAddr = workspace_allocator.get();
+    }
+
+    ACL_CHECK(
+        aclnnPermute(workspaceAddr, workspaceSize, executor, ctx.stream()));
+
+    ACL_CHECK(aclDestroyIntArray(acl_dims));
+}
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+aclnnStatus aclnnIm2colGetWorkspaceSize(const aclTensor* self,
+                                        const aclIntArray* kernelSize,
+                                        const aclIntArray* dilation,
+                                        const aclIntArray* padding,
+                                        const aclIntArray* stride,
+                                        aclTensor* out, uint64_t* workspaceSize,
+                                        aclOpExecutor** executor);
+aclnnStatus aclnnIm2col(void* workspace, uint64_t workspaceSize,
+                        aclOpExecutor* executor, aclrtStream stream);
+#ifdef __cplusplus
+}
+#endif
+
+static void ggml_cann_im2col_2d_post_process(ggml_backend_cann_context& ctx,
+                                             ggml_tensor* dst,
+                                             ggml_tensor* src1,
+                                             aclTensor* tmp_cast_tensor,
+                                             aclTensor* tmp_im2col_tensor) {
+    // Permute: [N, IC * KH * KW, OW * OH] -> [N, OW * OH, IC * KH * KW]
+    int64_t dst_ne[] = {dst->ne[0], dst->ne[1] * dst->ne[2], dst->ne[3]};
+    size_t dst_nb[] = {dst->nb[0], dst->nb[1], dst->nb[3]};
+    aclTensor* acl_dst =
+        ggml_cann_create_tensor(dst, dst_ne, dst_nb, GGML_MAX_DIMS - 1);
+
+    int64_t permute_dim[] = {0, 2, 1};
+    if (src1->type != dst->type) {
+        aclnn_permute(ctx, tmp_cast_tensor, acl_dst, permute_dim, 3);
+    } else {
+        aclnn_permute(ctx, tmp_im2col_tensor, acl_dst, permute_dim, 3);
+    }
+
+    // release
+    ACL_CHECK(aclDestroyTensor(acl_dst));
+}
+
+static void ggml_cann_im2col_1d_post_process(
+    ggml_backend_cann_context& ctx, ggml_tensor* dst, ggml_tensor* src1,
+    aclTensor* tmp_cast_tensor, aclTensor* tmp_im2col_tensor,
+    const std::vector<int64_t>& im2col_op_params) {
+    // get params
+    const int64_t KH = im2col_op_params[0];
+    const int64_t KW = im2col_op_params[1];
+    const int64_t IW = im2col_op_params[2];
+    const int64_t IC = im2col_op_params[3];
+    const int64_t N = im2col_op_params[4];
+    const int64_t OH = im2col_op_params[5];
+    const int64_t OW = im2col_op_params[6];
+    const int64_t s0 = im2col_op_params[7];
+    const int64_t p0 = im2col_op_params[8];
+    const int64_t d0 = im2col_op_params[9];
+    const int64_t n_bytes_factor = im2col_op_params[10];
+
+    // Permute: [N, IC * KH * KW, OW * OH] ->
+    // [N, OW * OH * n_bytes_factor, IC * KH * KW]
+    aclTensor* tmp_permute_tensor = nullptr;
+    ggml_cann_pool_alloc tmp_permute_allocator(ctx.pool());
+    tmp_permute_allocator.alloc(ggml_nbytes(dst) * n_bytes_factor);
+    void* tmp_permute_buffer = tmp_permute_allocator.get();
+
+    int64_t tmp_permute_ne[] = {IC * KH * KW, OW * OH * n_bytes_factor, N};
+    size_t tmp_permute_nb[GGML_MAX_DIMS - 1];
+    tmp_permute_nb[0] = ggml_type_size(dst->type);
+    for (int i = 1; i < GGML_MAX_DIMS - 1; i++) {
+        tmp_permute_nb[i] = tmp_permute_nb[i - 1] * tmp_permute_ne[i - 1];
+    }
+
+    tmp_permute_tensor = ggml_cann_create_tensor(
+        tmp_permute_buffer, ggml_cann_type_mapping(dst->type),
+        ggml_type_size(dst->type), tmp_permute_ne, tmp_permute_nb,
+        GGML_MAX_DIMS - 1, ACL_FORMAT_ND);
+
+    int64_t permute_dim[] = {0, 2, 1};
+    if (src1->type != dst->type) {
+        aclnn_permute(ctx, tmp_cast_tensor, tmp_permute_tensor, permute_dim, 3);
+    } else {
+        aclnn_permute(ctx, tmp_im2col_tensor, tmp_permute_tensor, permute_dim,
+                      3);
+    }
+
+    // number of times the kernel moves in W dimension
+    const int n_step_w = (IW + 2 * p0 - d0 * (KW - 1) - 1) / s0 + 1;
+    size_t offset;
+    void *cur_dst_buffer = dst->data, *cur_permute_buffer = tmp_permute_buffer;
+
+    // memory copy with offset to restore 1D im2col from 2d
+    if (IC > 1) {
+        offset = IC * KH * KW * n_step_w * ggml_type_size(dst->type);
+        size_t size_cpy = KH * KW * ggml_type_size(dst->type);
+
+        for (int c = 0; c < IC; c++) {
+            cur_permute_buffer = (char*)tmp_permute_buffer + offset +
+                                 KH * KW * c * ggml_type_size(dst->type);
+            cur_dst_buffer = (char*)dst->data +
+                             c * KH * KW * n_step_w * ggml_type_size(dst->type);
+
+            for (int i = 0; i < n_step_w; i++) {
+                ACL_CHECK(aclrtMemcpyAsync(
+                    cur_dst_buffer, size_cpy, cur_permute_buffer, size_cpy,
+                    ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
+                cur_dst_buffer =
+                    (char*)cur_dst_buffer + KH * KW * ggml_type_size(dst->type);
+                cur_permute_buffer = (char*)cur_permute_buffer +
+                                     KH * KW * IC * ggml_type_size(dst->type);
+            }
+        }
+    } else {
+        offset = KH * KW * n_step_w *
+                 ggml_type_size(dst->type);  // equal to ggml_nbytes(dst)
+        ACL_CHECK(aclrtMemcpyAsync(dst->data, offset,
+                                   (char*)tmp_permute_buffer + offset, offset,
+                                   ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
+    }
+
+    // release
+    ACL_CHECK(aclDestroyTensor(tmp_permute_tensor));
+}
+
+void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+    ggml_tensor* src0 = dst->src[0];  // kernel
+    ggml_tensor* src1 = dst->src[1];  // input
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
+
+    GGML_TENSOR_BINARY_OP_LOCALS;
+
+    // aclnnIm2col only works on 2D. set s1, p1, d1 to 1 to perform 2D
+    // im2col and do post-processing to restore it to 1D.
+    const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;
+    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
+    const int32_t s1 = is_2D ? ((const int32_t*)(dst->op_params))[1] : 1;
+    const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
+    const int32_t p1 = is_2D ? ((const int32_t*)(dst->op_params))[3] : 1;
+    const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
+    const int32_t d1 = is_2D ? ((const int32_t*)(dst->op_params))[5] : 1;
+
+    const int64_t N = ne13;
+    const int64_t IC = ne12;
+    const int64_t KH = ne01;
+    const int64_t KW = ne00;
+    const int64_t IW = ne10;
+
+    const int64_t OH = is_2D ? ne2 : 1;
+    const int64_t OW = ne1;
+
+    GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    // memory allocated increased to 3x when is_2D == false
+    const int64_t n_bytes_factor = is_2D ? 1 : 3;
+
+    // im2col: [N,C,H,W] -> [N, IC * KH * KW, OW * OH * n_bytes_factor]
+    aclTensor* acl_src1 = ggml_cann_create_tensor(src1);
+    int64_t tmp_im2col_ne[] = {OW * OH * n_bytes_factor, IC * KH * KW, N};
+    size_t tmp_im2col_nb[GGML_MAX_DIMS - 1];
+
+    tmp_im2col_nb[0] = ggml_type_size(src1->type);
+    for (int i = 1; i < GGML_MAX_DIMS - 1; i++) {
+        tmp_im2col_nb[i] = tmp_im2col_nb[i - 1] * tmp_im2col_ne[i - 1];
+    }
+
+    // Calculate im2col.
+    // If dst is f16, tmp_buffer is f32, we need alloc src.typesize *
+    // dst.elemcount.
+    ggml_cann_pool_alloc im2col_allocator(
+        ctx.pool(),
+        ggml_nelements(dst) * ggml_element_size(src1) * n_bytes_factor);
+    void* tmp_im2col_buffer = im2col_allocator.get();
+
+    aclTensor* tmp_im2col_tensor = ggml_cann_create_tensor(
+        tmp_im2col_buffer, ggml_cann_type_mapping(src1->type),
+        ggml_type_size(src1->type), tmp_im2col_ne, tmp_im2col_nb,
+        GGML_MAX_DIMS - 1, ACL_FORMAT_ND);
+
+    std::vector<int64_t> kernel_dims = {KH, KW};
+    std::vector<int64_t> dilation_size = {d1, d0};
+    std::vector<int64_t> padding_dims = {p1, p0};
+    std::vector<int64_t> stride_dims = {s1, s0};
+    auto* kernel_size = aclCreateIntArray(kernel_dims.data(), 2);
+    auto* dilations = aclCreateIntArray(dilation_size.data(), 2);
+    auto* paddings = aclCreateIntArray(padding_dims.data(), 2);
+    auto* strides = aclCreateIntArray(stride_dims.data(), 2);
+
+    uint64_t workspaceSize = 0;
+    aclOpExecutor* executor;
+    void* workspaceAddr = nullptr;
+
+    ACL_CHECK(aclnnIm2colGetWorkspaceSize(acl_src1, kernel_size, dilations,
+                                          paddings, strides, tmp_im2col_tensor,
+                                          &workspaceSize, &executor));
+
+    ggml_cann_pool_alloc workspace_allocator(ctx.pool());
+    if (workspaceSize > 0) {
+        workspace_allocator.alloc(workspaceSize);
+        workspaceAddr = workspace_allocator.get();
+    }
+
+    ACL_CHECK(
+        aclnnIm2col(workspaceAddr, workspaceSize, executor, ctx.stream()));
+
+    // Cast if dst is f16.
+    aclTensor* tmp_cast_tensor = nullptr;
+    ggml_cann_pool_alloc tmp_cast_allocator(ctx.pool());
+    void* tmp_cast_buffer = nullptr;
+    if (src1->type != dst->type) {
+        tmp_cast_allocator.alloc(ggml_nbytes(dst) * n_bytes_factor);
+        tmp_cast_buffer = tmp_cast_allocator.get();
+        size_t temp_cast_nb[GGML_MAX_DIMS - 1];
+        temp_cast_nb[0] = ggml_type_size(dst->type);
+        for (int i = 1; i < GGML_MAX_DIMS - 1; i++) {
+            temp_cast_nb[i] = temp_cast_nb[i - 1] * tmp_im2col_ne[i - 1];
+        }
+
+        tmp_cast_tensor = ggml_cann_create_tensor(
+            tmp_cast_buffer, ggml_cann_type_mapping(dst->type),
+            ggml_type_size(dst->type), tmp_im2col_ne, temp_cast_nb,
+            GGML_MAX_DIMS - 1, ACL_FORMAT_ND);
+        aclnn_cast(ctx, tmp_im2col_tensor, tmp_cast_tensor,
+                   ggml_cann_type_mapping(dst->type));
+    }
+
+    // post-processing
+    if (is_2D) {
+        ggml_cann_im2col_2d_post_process(ctx, dst, src1, tmp_cast_tensor,
+                                         tmp_im2col_tensor);
+    } else {
+        std::vector<int64_t> im2col_op_params = {
+            KH, KW, IW, IC, N, OH, OW, s0, p0, d0, n_bytes_factor};
+        ggml_cann_im2col_1d_post_process(ctx, dst, src1, tmp_cast_tensor,
+                                         tmp_im2col_tensor, im2col_op_params);
+    }
+
+    // release
+    ACL_CHECK(aclDestroyTensor(acl_src1));
+    ACL_CHECK(aclDestroyTensor(tmp_im2col_tensor));
+    ACL_CHECK(aclDestroyTensor(tmp_cast_tensor));
+    ACL_CHECK(aclDestroyIntArray(kernel_size));
+    ACL_CHECK(aclDestroyIntArray(dilations));
+    ACL_CHECK(aclDestroyIntArray(paddings));
+    ACL_CHECK(aclDestroyIntArray(strides));
+}
+
+/**
+ * @brief Applies element-wise exponential function to the elements of a tensor.
+ *
+ * This function computes the exponential of each element in the source tensor
+ * `acl_src` and stores the result back into the same tensor.
+ * The operation is defined as:
+ * \f[
+ *     \text {acl_src }_i=e^{acl\_src_i}
+ * \f]
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param acl_src The tensor on which the exponential function will be applied.
+ */
+static void aclnn_exp(ggml_backend_cann_context& ctx, aclTensor* acl_src) {
+    uint64_t workspaceSize = 0;
+    aclOpExecutor* executor;
+    void* workspaceAddr = nullptr;
+
+    ACL_CHECK(
+        aclnnInplaceExpGetWorkspaceSize(acl_src, &workspaceSize, &executor));
+    if (workspaceSize > 0) {
+        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+        workspaceAddr = workspace_allocator.get();
+    }
+
+    ACL_CHECK(
+        aclnnInplaceExp(workspaceAddr, workspaceSize, executor, ctx.stream()));
+}
+
+/**
+ * @brief Multiplies elements of a tensor by a scalar value, optionally
+ * in-place.
+ *
+ * This function multiplies each element of the source tensor `acl_src` by the
+ * scalar `scale` and stores the result in the destination tensor `acl_dst`. If
+ * `inplace` is true, `acl_dst` will not be used and the operation is performed
+ *  in-place on `acl_src`.
+ * The operation is defined as:
+ * \f[
+ *     \text {acl_dst }_i=\text {acl_src }_i \times \text {scale}
+ * \f]
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param acl_src The source tensor whose elements will be multiplied.
+ * @param scale The scalar value by which each element of `acl_src` will be
+ * multiplied.
+ * @param acl_dst The destination tensor where the result will be stored if
+ * `inplace` is false.
+ * @param inplace Flag indicating whether to perform the operation in-place on
+ * `acl_src`.
+ */
+static void aclnn_muls(ggml_backend_cann_context& ctx, aclTensor* acl_src,
+                       float scale, aclTensor* acl_dst, bool inplace) {
+    aclScalar* acl_scale = aclCreateScalar(&scale, aclDataType::ACL_FLOAT);
+
+    uint64_t workspaceSize = 0;
+    aclOpExecutor* executor;
+    void* workspaceAddr = nullptr;
+
+    if (inplace) {
+        ACL_CHECK(aclnnInplaceMulsGetWorkspaceSize(acl_src, acl_scale,
+                                                   &workspaceSize, &executor));
+        if (workspaceSize > 0) {
+            ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+            workspaceAddr = workspace_allocator.get();
+        }
+
+        ACL_CHECK(aclnnInplaceMuls(workspaceAddr, workspaceSize, executor,
+                                   ctx.stream()));
+    } else {
+        ACL_CHECK(aclnnMulsGetWorkspaceSize(acl_src, acl_scale, acl_dst,
+                                            &workspaceSize, &executor));
+        if (workspaceSize > 0) {
+            ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+            workspaceAddr = workspace_allocator.get();
+        }
+
+        ACL_CHECK(
+            aclnnMuls(workspaceAddr, workspaceSize, executor, ctx.stream()));
+    }
+
+    ACL_CHECK(aclDestroyScalar(acl_scale));
+}
+
+/**
+ * @brief Performs an in-place element-wise multiplication of two tensors.
+ *
+ * This function performs an element-wise multiplication of the tensors
+ * `acl_src` and `acl_other` and stores the result in `acl_src`.
+ * The operation is defined as:
+ * \f[
+ *     \text {acl_src }_i=\text {acl_src }_i \times \text {acl_other }_i
+ * \f]
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param acl_src The source tensor where the multiplication result will be
+ * stored.
+ * @param acl_other The tensor whose elements will be multiplied with `acl_src`.
+ */
+static void aclnn_inplace_mul(ggml_backend_cann_context& ctx,
+                              aclTensor* acl_src, aclTensor* acl_other) {
+    uint64_t workspaceSize = 0;
+    aclOpExecutor* executor;
+    void* workspaceAddr = nullptr;
+
+    ACL_CHECK(aclnnInplaceMulGetWorkspaceSize(acl_src, acl_other,
+                                              &workspaceSize, &executor));
+    if (workspaceSize > 0) {
+        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+        workspaceAddr = workspace_allocator.get();
+    }
+
+    ACL_CHECK(
+        aclnnInplaceMul(workspaceAddr, workspaceSize, executor, ctx.stream()));
+}
+
+/**
+ * @brief Performs element-wise multiplication of two tensors and stores the
+ * result in a destination tensor.
+ *
+ * This function performs element-wise multiplication of the tensors `acl_src`
+ * and `acl_other` and stores the result in the destination tensor `acl_dst`.
+ * The operation is defined as:
+ * \f[
+ *     \text {acl_dst }_i=\text {acl_src }_i \times \text {acl_other }_i
+ * \f]
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param acl_src The first tensor for element-wise multiplication.
+ * @param acl_other The second tensor for element-wise multiplication.
+ * @param acl_dst The destination tensor where the result will be stored.
+ */
+static void aclnn_mul(ggml_backend_cann_context& ctx, aclTensor* acl_src,
+                      aclTensor* acl_other, aclTensor* acl_dst) {
+    uint64_t workspaceSize = 0;
+    aclOpExecutor* executor;
+    void* workspaceAddr = nullptr;
+
+    ACL_CHECK(aclnnMulGetWorkspaceSize(acl_src, acl_other, acl_dst,
+                                       &workspaceSize, &executor));
+    if (workspaceSize > 0) {
+        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+        workspaceAddr = workspace_allocator.get();
+    }
+
+    ACL_CHECK(aclnnMul(workspaceAddr, workspaceSize, executor, ctx.stream()));
+}
+
+/**
+ * @brief Applies element-wise cosine function to the elements of a tensor.
+ *
+ * This function computes the cosine of each element in the source tensor
+ * `acl_src` and stores the result in the destination tensor `acl_dst`. The
+ * operation is defined as: \f[ \text {acl_dst }_i=\cos \left(\text {acl_src
+ * }_i\right) \f]
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param acl_src The source tensor on which the cosine function will be
+ * applied.
+ * @param acl_dst The destination tensor where the cosine results will be
+ * stored.
+ */
+static void aclnn_cos(ggml_backend_cann_context& ctx, aclTensor* acl_src,
+                      aclTensor* acl_dst) {
+    uint64_t workspaceSize = 0;
+    aclOpExecutor* executor;
+    void* workspaceAddr = nullptr;
+
+    ACL_CHECK(
+        aclnnCosGetWorkspaceSize(acl_src, acl_dst, &workspaceSize, &executor));
+    if (workspaceSize > 0) {
+        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+        workspaceAddr = workspace_allocator.get();
+    }
+
+    ACL_CHECK(aclnnCos(workspaceAddr, workspaceSize, executor, ctx.stream()));
+}
+
+/**
+ * @brief Applies element-wise sine function to the elements of a tensor.
+ *
+ * This function computes the sine of each element in the source tensor
+ `acl_src`
+ * and stores the result in the destination tensor `acl_dst`.
+ * The operation is defined as:
+ * \f[
+ *     \text {acl_dst }_i=\sin \left(\text {acl_src }_i\right)
+ * \f]
+
+ * @param ctx The context for the CANN backend operations.
+ * @param acl_src The source tensor on which the sine function will be applied.
+ * @param acl_dst The destination tensor where the sine results will be stored.
+ */
+static void aclnn_sin(ggml_backend_cann_context& ctx, aclTensor* acl_src,
+                      aclTensor* acl_dst) {
+    uint64_t workspaceSize = 0;
+    aclOpExecutor* executor;
+    void* workspaceAddr = nullptr;
+
+    ACL_CHECK(
+        aclnnSinGetWorkspaceSize(acl_src, acl_dst, &workspaceSize, &executor));
+    if (workspaceSize > 0) {
+        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+        workspaceAddr = workspace_allocator.get();
+    }
+
+    ACL_CHECK(aclnnSin(workspaceAddr, workspaceSize, executor, ctx.stream()));
+}
+
+void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx,
+                                  ggml_tensor* dst) {
+    const ggml_tensor* src = dst->src[0];
+
+    GGML_ASSERT(src->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    const int dim = dst->op_params[0];
+    const int max_period = dst->op_params[1];
+    int half = dim / 2;
+
+    aclTensor* acl_src = ggml_cann_create_tensor(src);
+
+    // arange: [0, ..., half)
+    float start = 0;
+    float stop = half;
+    float step = 1;
+    int64_t n_elements_arange = half;
+    int64_t tmp_arange_ne[] = {half};
+    size_t tmp_arange_nb[] = {sizeof(dst->type)};
+
+    ggml_cann_pool_alloc arange_allocator(ctx.pool(), half * sizeof(dst->type));
+    void* tmp_arange_buffer = arange_allocator.get();
+    aclTensor* tmp_arange_tensor = ggml_cann_create_tensor(
+        tmp_arange_buffer, ggml_cann_type_mapping(dst->type),
+        ggml_type_size(dst->type), tmp_arange_ne, tmp_arange_nb,
+        GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+
+    aclnn_arange(ctx, tmp_arange_tensor, start, stop, step, n_elements_arange);
+
+    // freq
+    float freq_param = -logf(max_period) / half;
+    bool inplace = true;
+    aclnn_muls(ctx, tmp_arange_tensor, freq_param, nullptr, inplace);
+    aclnn_exp(ctx, tmp_arange_tensor);
+
+    // permute: src [0,1,2,3]->[0,1,3,2]
+    int64_t tmp_permute_ne[] = {src->ne[1], src->ne[0], src->ne[2], src->ne[3]};
+    size_t tmp_permute_nb[GGML_MAX_DIMS];
+    tmp_permute_nb[0] = ggml_type_size(src->type);
+    for (int i = 1; i < GGML_MAX_DIMS; i++) {
+        tmp_permute_nb[i] = tmp_permute_nb[i - 1] * tmp_permute_ne[i - 1];
+    }
+
+    ggml_cann_pool_alloc permute_allocator(ctx.pool(), ggml_nbytes(src));
+    void* tmp_permute_buffer = permute_allocator.get();
+    aclTensor* tmp_permute_tenosr = ggml_cann_create_tensor(
+        tmp_permute_buffer, ggml_cann_type_mapping(src->type),
+        ggml_type_size(src->type), tmp_permute_ne, tmp_permute_nb,
+        GGML_MAX_DIMS, ACL_FORMAT_ND);
+    int64_t permute_dim[] = {0, 1, 3, 2};
+    int64_t num_dims = 4;
+    aclnn_permute(ctx, acl_src, tmp_permute_tenosr, permute_dim, num_dims);
+
+    // timestep * freq
+    int64_t tmp_mul_ne[] = {src->ne[1] * half, src->ne[0], src->ne[2],
+                            src->ne[3]};
+    size_t tmp_mul_nb[GGML_MAX_DIMS];
+    tmp_mul_nb[0] = ggml_type_size(src->type);
+    for (int i = 1; i < GGML_MAX_DIMS; i++) {
+        tmp_mul_nb[i] = tmp_mul_nb[i - 1] * tmp_mul_ne[i - 1];
+    }
+
+    int mul_nelements =
+        src->ne[1] * half * src->ne[0] * src->ne[2] * src->ne[3];
+
+    ggml_cann_pool_alloc mul_allocator(
+        ctx.pool(), mul_nelements * ggml_type_size(src->type));
+    void* tmp_mul_buffer = mul_allocator.get();
+    aclTensor* tmp_mul_tensor = ggml_cann_create_tensor(
+        tmp_mul_buffer, ggml_cann_type_mapping(src->type),
+        ggml_type_size(src->type), tmp_mul_ne, tmp_mul_nb, GGML_MAX_DIMS,
+        ACL_FORMAT_ND);
+    aclnn_mul(ctx, tmp_permute_tenosr, tmp_arange_tensor, tmp_mul_tensor);
+
+    // cos
+    ggml_cann_pool_alloc cos_allocator(
+        ctx.pool(), mul_nelements * ggml_type_size(src->type));
+    void* tmp_cos_buffer = cos_allocator.get();
+    aclTensor* tmp_cos_tensor = ggml_cann_create_tensor(
+        tmp_cos_buffer, ggml_cann_type_mapping(dst->type),
+        ggml_type_size(dst->type), tmp_mul_ne, tmp_mul_nb, GGML_MAX_DIMS,
+        ACL_FORMAT_ND);
+
+    aclnn_cos(ctx, tmp_mul_tensor, tmp_cos_tensor);
+
+    // sin
+    ggml_cann_pool_alloc sin_allocator(
+        ctx.pool(), mul_nelements * ggml_type_size(src->type));
+    void* tmp_sin_buffer = sin_allocator.get();
+    aclTensor* tmp_sin_tensor = ggml_cann_create_tensor(
+        tmp_sin_buffer, ggml_cann_type_mapping(dst->type),
+        ggml_type_size(dst->type), tmp_mul_ne, tmp_mul_nb, GGML_MAX_DIMS,
+        ACL_FORMAT_ND);
+
+    aclnn_sin(ctx, tmp_mul_tensor, tmp_sin_tensor);
+
+    // concat
+    int64_t concat_dim = 3;
+    aclTensor* acl_dst = ggml_cann_create_tensor(dst);
+    aclTensor* tensors[] = {tmp_cos_tensor, tmp_sin_tensor};
+    aclTensorList* tensorList = aclCreateTensorList(tensors, 2);
+    aclnn_concat(ctx, tensorList, acl_dst, concat_dim);
+
+    // release
+    // segmentation fault when delete both tensorList and his elements.
+    ACL_CHECK(aclDestroyTensorList(tensorList));
+    ACL_CHECK(aclDestroyTensor(acl_src));
+    ACL_CHECK(aclDestroyTensor(tmp_arange_tensor));
+    ACL_CHECK(aclDestroyTensor(tmp_permute_tenosr));
+    ACL_CHECK(aclDestroyTensor(tmp_mul_tensor));
+    ACL_CHECK(aclDestroyTensor(acl_dst));
+}
+
+/**
+ * @brief Fills a tensor with a scalar value.
+ *
+ * This function fills the destination tensor `acl_dst` with the scalar value
+ * `scalar`.
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param scalar The scalar value used to fill the tensor.
+ * @param acl_dst The destination tensor to be filled with the scalar value.
+ */
+static void aclnn_fill_scalar(ggml_backend_cann_context& ctx, float scalar,
+                              aclTensor* acl_dst) {
+    auto acl_scalar = aclCreateScalar(&scalar, aclDataType::ACL_FLOAT);
+
+    uint64_t workspaceSize = 0;
+    aclOpExecutor* executor;
+    void* workspaceAddr = nullptr;
+
+    ACL_CHECK(aclnnInplaceFillScalarGetWorkspaceSize(
+        acl_dst, acl_scalar, &workspaceSize, &executor));
+    if (workspaceSize > 0) {
+        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+        workspaceAddr = workspace_allocator.get();
+    }
+
+    ACL_CHECK(aclnnInplaceFillScalar(workspaceAddr, workspaceSize, executor,
+                                     ctx.stream()));
+    ACL_CHECK(aclDestroyScalar(acl_scalar));
+}
+
+/**
+ * @brief Raises each element of a tensor to the power of the corresponding
+ * element in another tensor.
+ *
+ * This function computes the element-wise power of the destination tensor
+ * `acl_dst` raised to the power of the exponent tensor `acl_exp`.
+ * The operation is defined as:
+ * \f[
+ *     \text {acl_dst }_i=acl\_dst_i^{\text {acl_exp }_i}
+ * \f]
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param acl_dst The destination tensor, which also serves as the base tensor.
+ * @param acl_exp The exponent tensor, each element of which is used to raise
+ * the corresponding element in the destination tensor.
+ */
+static void aclnn_pow_tensor_tensor(ggml_backend_cann_context& ctx,
+                                    aclTensor* acl_dst, aclTensor* acl_exp) {
+    uint64_t workspaceSize = 0;
+    aclOpExecutor* executor;
+    void* workspaceAddr = nullptr;
+
+    ACL_CHECK(aclnnInplacePowTensorTensorGetWorkspaceSize(
+        acl_dst, acl_exp, &workspaceSize, &executor));
+    if (workspaceSize > 0) {
+        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+        workspaceAddr = workspace_allocator.get();
+    }
+
+    ACL_CHECK(aclnnInplacePowTensorTensor(workspaceAddr, workspaceSize,
+                                          executor, ctx.stream()));
+}
+
+/**
+ * @brief   Applies the Alibi (Attention with Linear Biases) mechanism to the
+ * @details This function implements the Alibi mechanism, which introduces
+ *          learnable biases into the attention scores to simulate relative
+ *          position encoding without the need for explicit positional
+ *          embeddings.
+ *
+ * @param ctx          The backend CANN context for executing operations.
+ * @param acl_src      The source tensor representing the query or key.
+ * @param acl_position The position tensor containing relative positions.
+ * @param acl_dst      The destination tensor where the result will be stored.
+ * @param n_head       The number of attention heads.
+ * @param src_ne       The dimensions of the source tensor.
+ * @param src_nb0      The byte size of the first dimension of the source
+ tensor.
+ * @param max_bias     The maximum bias value used in the Alibi mechanism.
+ * @param dst          The destination tensor object for additional metadata.
+ *
+ * The function performs the following steps:
+ * 1. Calculates the logarithm floor of the number of heads to determine the
+      base for bias calculation.
+ * 2. Initializes arrays with arithmetic sequences and fills them with bias
+      values.
+ * 3. Computes the bias tensor based on the calculated biases and arithmetic
+      sequences.
+ * 4. Reshapes the bias tensor to match the dimensions of the input tensors.
+ * 5. Multiplies the position tensor by the bias tensor.
+ * 6. Adds the result of the multiplication to the source tensor to produce the
+      final output.
+ */
+static void aclnn_alibi(ggml_backend_cann_context& ctx, aclTensor* acl_src,
+                        aclTensor* acl_position, aclTensor* acl_dst,
+                        const int n_head, int64_t* src_ne, const size_t src_nb0,
+                        float max_bias, ggml_tensor* dst) {
+    const int64_t ne2_ne3 = src_ne[2] * src_ne[3];
+    GGML_ASSERT(src_nb0 == sizeof(float));
+    GGML_ASSERT(n_head == src_ne[2]);
+
+    const int n_heads_log2_floor = 1u << (uint32_t)floor(log2(n_head));
+
+    float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
+    float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
+
+    // init arange
+    ggml_cann_pool_alloc arange_allocator(ctx.pool(),
+                                          ne2_ne3 * ggml_type_size(dst->type));
+    void* tmp_arange_buffer = arange_allocator.get();
+
+    // arange1: [1, ..., n_heads_log2_floor+1)
+    float start = 1;
+    float stop = n_heads_log2_floor + 1;
+    float step = 1;
+    int64_t n_elements_arange = n_heads_log2_floor;
+
+    int64_t tmp_arange1_ne[] = {n_heads_log2_floor};
+    size_t tmp_arange1_nb[] = {sizeof(dst->type)};
+    aclTensor* tmp_arange1_tensor = ggml_cann_create_tensor(
+        tmp_arange_buffer, ggml_cann_type_mapping(dst->type),
+        ggml_type_size(dst->type), tmp_arange1_ne, tmp_arange1_nb,
+        GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+
+    aclnn_arange(ctx, tmp_arange1_tensor, start, stop, step, n_elements_arange);
+
+    aclTensor* tmp_arange2_tensor = nullptr;
+    if (n_heads_log2_floor < ne2_ne3) {
+        // arange2: [1, ..., 2 * (k - n_heads_log2_floor) + 1)
+        start = 1;
+        stop = 2 * (ne2_ne3 - n_heads_log2_floor) + 1;
+        step = 2;
+        n_elements_arange = ne2_ne3 - n_heads_log2_floor;
+        int64_t tmp_arange2_ne[] = {ne2_ne3 - n_heads_log2_floor};
+        size_t tmp_arange2_nb[] = {sizeof(dst->type)};
+
+        aclTensor* tmp_arange2_tensor = ggml_cann_create_tensor(
+            (char*)tmp_arange_buffer +
+                n_heads_log2_floor * ggml_type_size(dst->type),
+            ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
+            tmp_arange2_ne, tmp_arange2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+        aclnn_arange(ctx, tmp_arange2_tensor, start, stop, step,
+                     n_elements_arange);
+    }
+
+    // init mk_base
+    ggml_cann_pool_alloc mk_base_allocator(ctx.pool(),
+                                           ne2_ne3 * ggml_type_size(dst->type));
+    void* tmp_mk_base_buffer = mk_base_allocator.get();
+    int64_t tmp_mk_base1_ne[] = {n_heads_log2_floor};
+    size_t tmp_mk_base1_nb[] = {sizeof(dst->type)};
+    aclTensor* tmp_mk_base1_tensor = ggml_cann_create_tensor(
+        tmp_mk_base_buffer, ggml_cann_type_mapping(dst->type),
+        ggml_type_size(dst->type), tmp_mk_base1_ne, tmp_mk_base1_nb,
+        GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+
+    aclnn_fill_scalar(ctx, m0, tmp_mk_base1_tensor);
+
+    aclTensor* tmp_mk_base2_tensor = nullptr;
+    if (n_heads_log2_floor < ne2_ne3) {
+        int64_t tmp_mk_base2_ne[] = {ne2_ne3 - n_heads_log2_floor};
+        size_t tmp_mk_base2_nb[] = {sizeof(dst->type)};
+        aclTensor* tmp_mk_base2_tensor = ggml_cann_create_tensor(
+            (char*)tmp_mk_base_buffer +
+                n_heads_log2_floor * ggml_type_size(dst->type),
+            ggml_cann_type_mapping(dst->type), ggml_type_size(dst->type),
+            tmp_mk_base2_ne, tmp_mk_base2_nb, GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+        aclnn_fill_scalar(ctx, m1, tmp_mk_base2_tensor);
+    }
+
+    // init mk
+    int64_t tmp_mk_base_ne[] = {ne2_ne3};
+    size_t tmp_mk_base_nb[] = {sizeof(dst->type)};
+    aclTensor* tmp_mk_base_tensor = ggml_cann_create_tensor(
+        tmp_mk_base_buffer, ggml_cann_type_mapping(dst->type),
+        ggml_type_size(dst->type), tmp_mk_base_ne, tmp_mk_base_nb,
+        GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+    aclTensor* tmp_arange_tensor = ggml_cann_create_tensor(
+        tmp_arange_buffer, ggml_cann_type_mapping(dst->type),
+        ggml_type_size(dst->type), tmp_mk_base_ne, tmp_mk_base_nb,
+        GGML_MAX_DIMS - 3, ACL_FORMAT_ND);
+    aclnn_pow_tensor_tensor(ctx, tmp_mk_base_tensor, tmp_arange_tensor);
+
+    // reshape mk
+    int64_t tmp_mk_ne[] = {1, 1, src_ne[2], src_ne[3]};
+    size_t tmp_mk_nb[GGML_MAX_DIMS];
+    tmp_mk_nb[0] = ggml_type_size(dst->type);
+    for (int i = 1; i < GGML_MAX_DIMS; i++) {
+        tmp_mk_nb[i] = tmp_mk_nb[i - 1] * tmp_mk_ne[i - 1];
+    }
+    aclTensor* tmp_mk_tensor = ggml_cann_create_tensor(
+        tmp_mk_base_buffer, ggml_cann_type_mapping(dst->type),
+        ggml_type_size(dst->type), tmp_mk_ne, tmp_mk_nb, GGML_MAX_DIMS,
+        ACL_FORMAT_ND);
+
+    // acl_position * mk
+    int64_t tmp_output_ne[] = {src_ne[0], src_ne[1], src_ne[2], src_ne[3]};
+    size_t tmp_output_nb[GGML_MAX_DIMS];
+    tmp_output_nb[0] = ggml_type_size(dst->type);
+    for (int i = 1; i < GGML_MAX_DIMS; i++) {
+        tmp_output_nb[i] = tmp_output_nb[i - 1] * tmp_output_ne[i - 1];
+    }
+    ggml_cann_pool_alloc output_allocator(ctx.pool(), ggml_nbytes(dst));
+    void* tmp_output_buffer = output_allocator.get();
+    aclTensor* tmp_output_tensor = ggml_cann_create_tensor(
+        tmp_output_buffer, ggml_cann_type_mapping(dst->type),
+        ggml_type_size(dst->type), tmp_output_ne, tmp_output_nb, GGML_MAX_DIMS,
+        ACL_FORMAT_ND);
+    aclnn_mul(ctx, acl_position, tmp_mk_tensor, tmp_output_tensor);
+
+    // add
+    aclnn_add(ctx, tmp_output_tensor, acl_src, acl_dst);
+
+    ACL_CHECK(aclDestroyTensor(tmp_arange1_tensor));
+    ACL_CHECK(aclDestroyTensor(tmp_arange2_tensor));
+    ACL_CHECK(aclDestroyTensor(tmp_mk_base1_tensor));
+    ACL_CHECK(aclDestroyTensor(tmp_mk_base2_tensor));
+    ACL_CHECK(aclDestroyTensor(tmp_mk_base_tensor));
+    ACL_CHECK(aclDestroyTensor(tmp_arange_tensor));
+    ACL_CHECK(aclDestroyTensor(tmp_mk_tensor));
+    ACL_CHECK(aclDestroyTensor(tmp_output_tensor));
+}
+
+void ggml_cann_cpy(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+    ggml_cann_dup(ctx, dst);
+}
+
+/**
+ * @brief Performs element-wise addition of two tensors in place.
+ *
+ * This function adds the source tensor `acl_src` to the destination tensor
+ * `acl_dst` element-wise and stores the result in the destination tensor
+ * `acl_dst`.
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param acl_src The source tensor to be added.
+ * @param acl_dst The destination tensor which will hold the result of the
+ * addition.
+ */
+static void aclnn_inplace_add(ggml_backend_cann_context& ctx,
+                              aclTensor* acl_src, aclTensor* acl_dst) {
+    aclScalar* alpha = nullptr;
+    float alphaValue = 1.0f;
+    alpha = aclCreateScalar(&alphaValue, aclDataType::ACL_FLOAT);
+
+    uint64_t workspaceSize = 0;
+    aclOpExecutor* executor;
+    void* workspaceAddr = nullptr;
+
+    ACL_CHECK(aclnnInplaceAddGetWorkspaceSize(acl_dst, acl_src, alpha,
+                                              &workspaceSize, &executor));
+    if (workspaceSize > 0) {
+        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+        workspaceAddr = workspace_allocator.get();
+    }
+
+    ACL_CHECK(
+        aclnnInplaceAdd(workspaceAddr, workspaceSize, executor, ctx.stream()));
+
+    ACL_CHECK(aclDestroyScalar(alpha));
+}
+
+/**
+ * @brief Applies the softmax function to a tensor along a specified dimension.
+ *
+ * This function computes the softmax of the source tensor `acl_src` along the
+ * specified dimension `dim` and stores the result in the destination tensor
+ * `acl_dst`.
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param acl_src The source tensor on which the softmax function will be
+ * applied.
+ * @param dim The dimension along which the softmax function will be computed.
+ * @param acl_dst The destination tensor where the softmax results will be
+ * stored.
+ */
+static void aclnn_softmax(ggml_backend_cann_context& ctx, aclTensor* acl_src,
+                          int64_t dim, aclTensor* acl_dst) {
+    uint64_t workspaceSize = 0;
+    aclOpExecutor* executor;
+    void* workspaceAddr = nullptr;
+
+    ACL_CHECK(aclnnSoftmaxGetWorkspaceSize(acl_src, dim, acl_dst,
+                                           &workspaceSize, &executor));
+
+    if (workspaceSize > 0) {
+        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+        workspaceAddr = workspace_allocator.get();
+    }
+
+    aclrtStream stream = ctx.stream();
+    ACL_CHECK(aclnnSoftmax(workspaceAddr, workspaceSize, executor, stream));
+}
+
+void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+    ggml_tensor* src0 = dst->src[0];
+    ggml_tensor* src1 = dst->src[1];  // mask
+
+    aclTensor* acl_src0 = ggml_cann_create_tensor(src0);
+    aclTensor* acl_dst = ggml_cann_create_tensor(dst);
+
+    float scale = 1.0f;
+    float max_bias = 0.0f;
+
+    memcpy(&scale, (float*)dst->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (float*)dst->op_params + 1, sizeof(float));
+
+    // input mul scale
+    aclScalar* acl_scale = aclCreateScalar(&scale, aclDataType::ACL_FLOAT);
+
+    size_t n_bytes = ggml_nbytes(src0);
+    ggml_cann_pool_alloc mul_scale_allocator(ctx.pool(), n_bytes);
+    void* input_mul_scale_buffer = mul_scale_allocator.get();
+    aclTensor* acl_input_mul_scale_tensor = ggml_cann_create_tensor(
+        input_mul_scale_buffer, ACL_FLOAT, ggml_type_size(src0->type), src0->ne,
+        src0->nb, GGML_MAX_DIMS);
+
+    bool inplace = false;
+    aclnn_muls(ctx, acl_src0, scale, acl_input_mul_scale_tensor, inplace);
+
+    // mask
+    aclTensor* acl_src1_fp32_tensor = nullptr;
+    aclTensor* tmp_mask_tensor = nullptr;
+    ggml_cann_pool_alloc src1_fp32_allocator(ctx.pool());
+    if (src1) {
+        const bool use_f16 = src1->type == GGML_TYPE_F16;
+        if (use_f16) {
+            // cast to fp32
+            size_t n_bytes = ggml_nelements(src1) * sizeof(float_t);
+            size_t src1_fp32_nb[GGML_MAX_DIMS];
+            src1_fp32_nb[0] = sizeof(float_t);
+            for (int i = 1; i < GGML_MAX_DIMS; i++) {
+                src1_fp32_nb[i] = src1_fp32_nb[i - 1] * src1->ne[i - 1];
+            }
+            src1_fp32_allocator.alloc(n_bytes);
+            void* src1_fp32_buffer = src1_fp32_allocator.get();
+            acl_src1_fp32_tensor = ggml_cann_create_tensor(
+                src1_fp32_buffer, ACL_FLOAT, sizeof(float), src1->ne,
+                src1_fp32_nb, GGML_MAX_DIMS);
+            aclTensor* acl_src1 = ggml_cann_create_tensor(src1);
+            aclnn_cast(ctx, acl_src1, acl_src1_fp32_tensor, ACL_FLOAT);
+
+            ACL_CHECK(aclDestroyTensor(acl_src1));
+        } else {
+            acl_src1_fp32_tensor = ggml_cann_create_tensor(src1);
+        }
+
+        // broadcast the mask across rows, only use ne11 of ne01 in mask
+        if (src1->ne[1] != src0->ne[1]) {
+            // mask shape: [1,1,ne11,ne10]
+            int64_t tmp_mask_ne[] = {src0->ne[0], src0->ne[1], 1, 1};
+            size_t tmp_mask_nb[GGML_MAX_DIMS];
+            tmp_mask_nb[0] = sizeof(float_t);
+            for (int i = 1; i < GGML_MAX_DIMS; i++) {
+                tmp_mask_nb[i] = tmp_mask_nb[i - 1] * tmp_mask_ne[i - 1];
+            }
+            tmp_mask_tensor = ggml_cann_create_tensor(
+                src1->data, ACL_FLOAT, sizeof(float), tmp_mask_ne, tmp_mask_nb,
+                GGML_MAX_DIMS, ACL_FORMAT_ND);
+        }
+
+        // alibi
+        const int n_head = src0->ne[2];
+        const size_t src_nb0 = src0->nb[0];
+
+        n_bytes = ggml_nbytes(dst);
+        ggml_cann_pool_alloc output_allocator(ctx.pool(), n_bytes);
+        void* output_buffer = output_allocator.get();
+        aclTensor* alibi_output_tensor = ggml_cann_create_tensor(
+            output_buffer, ACL_FLOAT, ggml_type_size(dst->type), dst->ne,
+            dst->nb, GGML_MAX_DIMS);
+        if (max_bias <= 0.0f) {
+            // slope = 1.0
+            if (tmp_mask_tensor) {
+                aclnn_add(ctx, tmp_mask_tensor, acl_input_mul_scale_tensor,
+                          alibi_output_tensor);
+            } else {
+                aclnn_add(ctx, acl_src1_fp32_tensor, acl_input_mul_scale_tensor,
+                          alibi_output_tensor);
+            }
+        } else {
+            // slope != 1.0
+            if (tmp_mask_tensor) {
+                aclnn_alibi(ctx, acl_input_mul_scale_tensor, tmp_mask_tensor,
+                            alibi_output_tensor, n_head, src0->ne, src_nb0,
+                            max_bias, dst);
+            } else {
+                aclnn_alibi(ctx, acl_input_mul_scale_tensor,
+                            acl_src1_fp32_tensor, alibi_output_tensor, n_head,
+                            src0->ne, src_nb0, max_bias, dst);
+            }
+        }
+
+        // softmax
+        aclnn_softmax(ctx, alibi_output_tensor, 3, acl_dst);
+        ACL_CHECK(aclDestroyTensor(alibi_output_tensor));
+    } else {
+        aclnn_softmax(ctx, acl_input_mul_scale_tensor, 3, acl_dst);
+    }
+
+    ACL_CHECK(aclDestroyTensor(acl_src0));
+    ACL_CHECK(aclDestroyTensor(acl_src1_fp32_tensor));
+    ACL_CHECK(aclDestroyTensor(acl_dst));
+    ACL_CHECK(aclDestroyScalar(acl_scale));
+    ACL_CHECK(aclDestroyTensor(acl_input_mul_scale_tensor));
+    ACL_CHECK(aclDestroyTensor(tmp_mask_tensor));
+}
+
+void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+    ggml_tensor* src0 = dst->src[0];
+    ggml_tensor* src1 = dst->src[1];
+
+    ggml_cann_pool_alloc src0_extra_allocator(ctx.pool(), sizeof(ggml_tensor));
+    ggml_cann_pool_alloc src1_extra_allocator(ctx.pool(), sizeof(ggml_tensor));
+    ggml_cann_pool_alloc dst_extra_allocator(ctx.pool(), sizeof(ggml_tensor));
+    src0->extra = src0_extra_allocator.get();
+    src1->extra = src1_extra_allocator.get();
+    dst->extra = dst_extra_allocator.get();
+    ACL_CHECK(aclrtMemcpyAsync(src0->extra, sizeof(ggml_tensor), src0,
+                               sizeof(ggml_tensor), ACL_MEMCPY_HOST_TO_DEVICE,
+                               ctx.stream()));
+    ACL_CHECK(aclrtMemcpyAsync(src1->extra, sizeof(ggml_tensor), src1,
+                               sizeof(ggml_tensor), ACL_MEMCPY_HOST_TO_DEVICE,
+                               ctx.stream()));
+    ACL_CHECK(aclrtMemcpyAsync(dst->extra, sizeof(ggml_tensor), dst,
+                               sizeof(ggml_tensor), ACL_MEMCPY_HOST_TO_DEVICE,
+                               ctx.stream()));
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            aclrtlaunch_ascendc_get_row_f32(
+                24, ctx.stream(), src0->data, src1->data, dst->data,
+                ((ggml_tensor*)src0->extra)->ne,
+                ((ggml_tensor*)src0->extra)->nb,
+                ((ggml_tensor*)src1->extra)->ne,
+                ((ggml_tensor*)src1->extra)->nb, ((ggml_tensor*)dst->extra)->ne,
+                ((ggml_tensor*)dst->extra)->nb);
+            break;
+        case GGML_TYPE_F16:
+            aclrtlaunch_ascendc_get_row_f16(
+                24, ctx.stream(), src0->data, src1->data, dst->data,
+                ((ggml_tensor*)src0->extra)->ne,
+                ((ggml_tensor*)src0->extra)->nb,
+                ((ggml_tensor*)src1->extra)->ne,
+                ((ggml_tensor*)src1->extra)->nb, ((ggml_tensor*)dst->extra)->ne,
+                ((ggml_tensor*)dst->extra)->nb);
+            break;
+        case GGML_TYPE_Q4_0:
+            aclrtlaunch_ascendc_get_row_q4_0(
+                24, ctx.stream(), src0->data, src1->data, dst->data,
+                ((ggml_tensor*)src0->extra)->ne,
+                ((ggml_tensor*)src1->extra)->ne,
+                ((ggml_tensor*)src1->extra)->nb, ((ggml_tensor*)dst->extra)->ne,
+                ((ggml_tensor*)dst->extra)->nb);
+            break;
+        case GGML_TYPE_Q8_0:
+            aclrtlaunch_ascendc_get_row_q8_0(
+                24, ctx.stream(), src0->data, src1->data, dst->data,
+                ((ggml_tensor*)src0->extra)->ne,
+                ((ggml_tensor*)src1->extra)->ne,
+                ((ggml_tensor*)src1->extra)->nb, ((ggml_tensor*)dst->extra)->ne,
+                ((ggml_tensor*)dst->extra)->nb);
+            break;
+        default:
+            GGML_ABORT("fatal error");
+            break;
+    }
+}
+
+/**
+ * @brief Repeats elements of a tensor along a specified dimension.
+ *
+ * This function repeats each element of the source tensor `acl_src` a specified
+ * number of times (`repeats`) along the specified dimension `dim` and stores
+ * the result in the destination tensor `acl_dst`.
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param acl_src The source tensor whose elements will be repeated.
+ * @param acl_dst The destination tensor where the repeated elements will be
+ * stored.
+ * @param dim The dimension along which the elements will be repeated.
+ * @param repeats The number of times each element will be repeated.
+ * @param output_size The size of the output tensor.
+ */
+static void aclnn_repeat_interleave(ggml_backend_cann_context& ctx,
+                                    aclTensor* acl_src, aclTensor* acl_dst,
+                                    int64_t dim, int64_t repeats,
+                                    int64_t output_size) {
+    uint64_t workspaceSize = 0;
+    aclOpExecutor* executor;
+    void* workspaceAddr = nullptr;
+
+    ACL_CHECK(aclnnRepeatInterleaveIntWithDimGetWorkspaceSize(
+        acl_src, repeats, dim, output_size, acl_dst, &workspaceSize,
+        &executor));
+    if (workspaceSize > 0) {
+        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+        workspaceAddr = workspace_allocator.get();
+    }
+
+    ACL_CHECK(aclnnRepeatInterleaveIntWithDim(workspaceAddr, workspaceSize,
+                                              executor, ctx.stream()));
+}
+
+/**
+ * @brief Performs matrix multiplication of two tensors.
+ *
+ * This function computes the matrix multiplication of the input tensor
+ * `acl_input` and the weight tensor `acl_weight`, and stores the result in the
+ * destination tensor `acl_dst`.
+ * The operation is defined as:
+ * \f[
+ *     \text {acl_dst}=\text {acl_input@acl_weight}
+ * \f]
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param acl_input The input tensor for the matrix multiplication.
+ * @param acl_weight The weight tensor for the matrix multiplication.
+ * @param acl_dst The destination tensor where the result of the matrix
+ * multiplication will be stored.
+ */
+static void aclnn_mat_mul(ggml_backend_cann_context& ctx, aclTensor* acl_input,
+                          aclTensor* acl_weight, aclTensor* acl_dst) {
+    int8_t cube_math_type = 1;  // ALLOW_FP32_DOWN_PRECISION, when input is
+                                // fp32, atlas a2 will transpose it to HFLOAT32.
+
+    uint64_t workspaceSize = 0;
+    aclOpExecutor* executor;
+    void* workspaceAddr = nullptr;
+
+    ACL_CHECK(aclnnMatmulGetWorkspaceSize(acl_input, acl_weight, acl_dst,
+                                          cube_math_type, &workspaceSize,
+                                          &executor));
+
+    if (workspaceSize > 0) {
+        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+        workspaceAddr = workspace_allocator.get();
+    }
+
+    ACL_CHECK(
+        aclnnMatmul(workspaceAddr, workspaceSize, executor, ctx.stream()));
+}
+
+/**
+ * @brief Performs matrix multiplication with floating-point precision on
+ * tensors using the CANN backend.
+ *
+ * This function performs matrix multiplication of the input tensor and the
+ * weight tensor, handling broadcasting and transposing as needed, and stores
+ * the result in the destination tensor `dst`.
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param dst The destination tensor where the result of the matrix
+ * multiplication will be stored.
+ */
+static void ggml_cann_mat_mul_fp(ggml_backend_cann_context& ctx,
+                                 ggml_tensor* dst) {
+    ggml_tensor* weight = dst->src[0];  // weight
+    ggml_tensor* input = dst->src[1];   // input
+
+    // when weight ne2 or ne3 is 1, aclnnMatmulGetWorkspaceSize will auto
+    // broadcast, when weight ne2 or ne3 is not 1, weight need repeat.
+    BCAST_MUL_MAT_SHAPE(input, weight, dst);
+
+    // transpose weight: [1,2,3,4] -> [1,2,4,3]
+    int64_t transpose_ne[] = {bcast_weight_ne[1], bcast_weight_ne[0],
+                              bcast_weight_ne[2], bcast_weight_ne[3],
+                              bcast_weight_ne[4], bcast_weight_ne[5]};
+    size_t transpose_nb[] = {bcast_weight_nb[1], bcast_weight_nb[0],
+                             bcast_weight_nb[2], bcast_weight_nb[3],
+                             bcast_weight_nb[4], bcast_weight_nb[5]};
+
+    aclTensor* acl_weight_tensor =
+        ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, bcast_dims);
+    aclTensor* acl_input_tensor =
+        ggml_cann_create_tensor(input, BCAST_MUL_MAT_PARAM(input));
+    aclTensor* acl_dst = ggml_cann_create_tensor(dst, BCAST_MUL_MAT_PARAM(dst));
+    aclnn_mat_mul(ctx, acl_input_tensor, acl_weight_tensor, acl_dst);
+
+    ACL_CHECK(aclDestroyTensor(acl_weight_tensor));
+    ACL_CHECK(aclDestroyTensor(acl_input_tensor));
+    ACL_CHECK(aclDestroyTensor(acl_dst));
+}
+
+/**
+ * @brief Performs matrix multiplication with quantized weights and
+ * floating-point inputs using the CANN backend.
+ *
+ * This function performs matrix multiplication of the input tensor `src1` and
+ * the weight tensor `src0`, handling broadcasting, transposing, and
+ * quantization as needed, and stores the result in the destination tensor
+ * `dst`.
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param dst The destination tensor where the result of the matrix
+ * multiplication will be stored.
+ */
+static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx,
+                                   ggml_tensor* dst,
+                                   const enum ggml_type type) {
+    ggml_tensor* src0 = dst->src[0];  // weight
+    ggml_tensor* src1 = dst->src[1];  // input
+
+    // The shape of the weight is NCHW. Matrix multiplication uses HW dims. HC
+    // is regarded as batch. weight need transpose.
+    int64_t weight_ne[] = {src0->ne[1], src0->ne[0]};
+    float weight_elem_size;
+    if (type == GGML_TYPE_Q4_0) {
+        weight_elem_size = float(sizeof(uint8_t)) / 2;
+    }
+    else if (type == GGML_TYPE_Q8_0) {
+        weight_elem_size = float(sizeof(uint8_t));
+    }
+    else {
+        GGML_ABORT("Only support Q4_0 and Q8_0 MUL_MAT");
+    }
+    float weight_nb[] = {weight_elem_size * src0->ne[0], weight_elem_size};
+
+    // size of one matrix is element_size * height * width.
+    size_t weight_stride = weight_elem_size * src0->ne[0] * src0->ne[1];
+    size_t weight_size = weight_stride * src0->ne[2] * src0->ne[3];
+
+    // scale stored at the end of weight. Also need transpose.
+    GGML_ASSERT(QK4_0 == QK8_0);
+    int64_t scale_ne[] = {src0->ne[1], src0->ne[0] / QK8_0};
+    size_t scale_elem_size = sizeof(uint16_t);
+    size_t scale_nb[] = {src0->ne[0] / QK8_0 * scale_elem_size,
+                         scale_elem_size};
+    size_t scale_stride = scale_elem_size * src0->ne[0] * src0->ne[1] / QK8_0;
+    char* scale_offset = (char*)src0->data + weight_size;
+
+    // input
+    void* input_buffer;
+    size_t input_elem_size = sizeof(uint16_t);
+    int64_t input_ne[] = {src1->ne[0], src1->ne[1]};
+    size_t input_nb[] = {input_elem_size, input_elem_size * src1->ne[0]};
+    size_t input_stride = input_elem_size * src1->ne[0] * src1->ne[1];
+
+    ggml_cann_pool_alloc input_alloctor(ctx.pool());
+    if (src1->type != GGML_TYPE_F16) {
+        aclTensor* acl_src1_tensor = ggml_cann_create_tensor(src1);
+        input_alloctor.alloc(ggml_nelements(src1) * input_elem_size);
+        input_buffer = input_alloctor.get();
+
+        int64_t* input_cast_ne = src1->ne;
+        size_t input_cast_nb[GGML_MAX_DIMS];
+        input_cast_nb[0] = sizeof(uint16_t);
+        for (int i = 1; i < GGML_MAX_DIMS; i++) {
+            input_cast_nb[i] = input_cast_nb[i - 1] * input_cast_ne[i - 1];
+        }
+
+        aclTensor* acl_input_tensor = ggml_cann_create_tensor(
+            input_buffer, ACL_FLOAT16, input_elem_size, input_cast_ne,
+            input_cast_nb, GGML_MAX_DIMS);
+        aclnn_cast(ctx, acl_src1_tensor, acl_input_tensor, ACL_FLOAT16);
+        ACL_CHECK(aclDestroyTensor(acl_input_tensor));
+        ACL_CHECK(aclDestroyTensor(acl_src1_tensor));
+    } else {
+        input_buffer = src1->data;
+    }
+
+    // output
+    size_t output_elem_size = sizeof(uint16_t);
+    int64_t output_ne[] = {dst->ne[0], dst->ne[1]};
+    size_t output_nb[] = {output_elem_size, output_elem_size * dst->ne[0]};
+    ggml_cann_pool_alloc output_alloctor(
+        ctx.pool(), ggml_nelements(dst) * output_elem_size);
+    void* output_buffer = output_alloctor.get();
+    size_t output_stride = output_elem_size * dst->ne[0] * dst->ne[1];
+
+    // aclnn
+    uint64_t workspaceSize = 0;
+    aclOpExecutor* executor;
+    void* workspaceAddr = nullptr;
+
+    for (int64_t n1 = 0; n1 < src1->ne[3]; n1++) {
+        for (int64_t c1 = 0; c1 < src1->ne[2]; c1++) {
+            int64_t n0 = n1 / (src1->ne[3] / src0->ne[3]);
+            int64_t c0 = c1 / (src1->ne[2] / src0->ne[2]);
+
+            int64_t batch1 = n1 * src1->ne[2] + c1;
+            int64_t batch0 = n0 * src0->ne[2] + c0;
+
+            aclTensor* acl_input_tensor = ggml_cann_create_tensor(
+                (char*)input_buffer + batch1 * input_stride, ACL_FLOAT16,
+                input_elem_size, input_ne, input_nb, 2);
+            aclTensor* acl_weight_tensor = ggml_cann_create_tensor(
+                (char*)src0->data + batch0 * weight_stride,
+                ggml_cann_type_mapping(type), weight_elem_size, weight_ne,
+                weight_nb, 2);
+            aclTensor* acl_scale_tensor = ggml_cann_create_tensor(
+                scale_offset + batch0 * scale_stride, ACL_FLOAT16,
+                scale_elem_size, scale_ne, scale_nb, 2);
+            aclTensor* acl_output_tensor = ggml_cann_create_tensor(
+                (char*)output_buffer + batch1 * output_stride, ACL_FLOAT16,
+                output_elem_size, output_ne, output_nb, 2);
+
+            ACL_CHECK(aclnnWeightQuantBatchMatmulV2GetWorkspaceSize(
+                acl_input_tensor, acl_weight_tensor, acl_scale_tensor, nullptr,
+                nullptr, nullptr, nullptr, QK8_0, acl_output_tensor,
+                &workspaceSize, &executor));
+
+            if (workspaceSize > 0 && workspaceAddr == nullptr) {
+                ggml_cann_pool_alloc workspace_allocator(ctx.pool(),
+                                                         workspaceSize);
+                workspaceAddr = workspace_allocator.get();
+            }
+
+            ACL_CHECK(aclnnWeightQuantBatchMatmulV2(
+                workspaceAddr, workspaceSize, executor, ctx.stream()));
+
+            ACL_CHECK(aclDestroyTensor(acl_input_tensor));
+            ACL_CHECK(aclDestroyTensor(acl_weight_tensor));
+            ACL_CHECK(aclDestroyTensor(acl_scale_tensor));
+            ACL_CHECK(aclDestroyTensor(acl_output_tensor));
+        }
+    }
+
+    // cast out
+    int64_t* output_cast_ne = dst->ne;
+    size_t output_cast_nb[GGML_MAX_DIMS];
+    output_cast_nb[0] = sizeof(uint16_t);
+    for (int i = 1; i < GGML_MAX_DIMS; i++) {
+        output_cast_nb[i] = output_cast_nb[i - 1] * output_cast_ne[i - 1];
+    }
+
+    aclTensor* acl_output_tensor =
+        ggml_cann_create_tensor(output_buffer, ACL_FLOAT16, output_elem_size,
+                                output_cast_ne, output_cast_nb, GGML_MAX_DIMS);
+    aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst);
+    aclnn_cast(ctx, acl_output_tensor, acl_dst_tensor, ACL_FLOAT);
+
+    ACL_CHECK(aclDestroyTensor(acl_output_tensor));
+    ACL_CHECK(aclDestroyTensor(acl_dst_tensor));
+}
+
+void ggml_cann_mul_mat(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+    const enum ggml_type type = dst->src[0]->type;
+    switch (type) {
+        case GGML_TYPE_F32:
+        case GGML_TYPE_F16:
+            ggml_cann_mat_mul_fp(ctx, dst);
+            break;
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q8_0:
+            ggml_cann_mul_mat_quant(ctx, dst, type);
+            break;
+        default:
+            GGML_ABORT("fatal error");
+            break;
+    }
+}
+
+/**
+ * @brief Rolls the elements of a tensor along a specified dimension.
+ *
+ * This function rolls the elements of the source tensor `acl_src` by the
+ * specified shifts `shifts` along the specified dimensions `dims`, and stores
+ * the result in the destination tensor `acl_dst`.
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param acl_src The source tensor whose elements will be rolled.
+ * @param acl_dst The destination tensor where the rolled elements will be
+ * stored.
+ * @param shifts An array specifying the number of positions by which elements
+ * are shifted.
+ * @param dims An array specifying the dimensions along which elements are
+ * shifted.
+ */
+static void aclnn_roll(ggml_backend_cann_context& ctx, aclTensor* acl_src,
+                       aclTensor* acl_dst, int64_t* shifts, int64_t* dims) {
+    aclIntArray* acl_shifts = aclCreateIntArray(shifts, 1);
+    aclIntArray* acl_dims = aclCreateIntArray(dims, 1);
+
+    uint64_t workspaceSize = 0;
+    aclOpExecutor* executor;
+    void* workspaceAddr = nullptr;
+
+    ACL_CHECK(aclnnRollGetWorkspaceSize(acl_src, acl_shifts, acl_dims, acl_dst,
+                                        &workspaceSize, &executor));
+    if (workspaceSize > 0) {
+        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+        workspaceAddr = workspace_allocator.get();
+    }
+
+    ACL_CHECK(aclnnRoll(workspaceAddr, workspaceSize, executor, ctx.stream()));
+
+    ACL_CHECK(aclDestroyIntArray(acl_shifts));
+    ACL_CHECK(aclDestroyIntArray(acl_dims));
+}
+
+/**
+ * @brief Fills specified positions of a tensor with a scalar value.
+ *
+ * This function fills the positions in the source tensor `acl_src` specified by
+ * `index` along the dimension `dim` with the scalar value `value`.
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param acl_src The source tensor where the positions will be filled.
+ * @param dim The dimension along which the positions are specified.
+ * @param index An array specifying the positions to be filled.
+ * @param index_num The number of positions specified in the index array.
+ * @param value The scalar value used to fill the specified positions.
+ */
+static void aclnn_index_fill_tensor(ggml_backend_cann_context& ctx,
+                                    aclTensor* acl_src, int64_t dim,
+                                    int64_t* index, int64_t index_num,
+                                    float value) {
+    aclIntArray* acl_index = aclCreateIntArray(index, index_num);
+    aclScalar* acl_value = aclCreateScalar(&value, aclDataType::ACL_FLOAT);
+
+    uint64_t workspaceSize = 0;
+    aclOpExecutor* executor;
+    void* workspaceAddr = nullptr;
+
+    ACL_CHECK(aclnnInplaceIndexFillTensorGetWorkspaceSize(
+        acl_src, dim, acl_index, acl_value, &workspaceSize, &executor));
+    if (workspaceSize > 0) {
+        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+        workspaceAddr = workspace_allocator.get();
+    }
+
+    ACL_CHECK(aclnnInplaceIndexFillTensor(workspaceAddr, workspaceSize,
+                                          executor, ctx.stream()));
+
+    ACL_CHECK(aclDestroyIntArray(acl_index));
+    ACL_CHECK(aclDestroyScalar(acl_value));
+}
+
+static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
+                             aclTensor* acl_cos_repeat_tensor,
+                             aclTensor* acl_sin_repeat_tensor,
+                             float theta_scale, bool is_neox) {
+    // int sin/cos cache, cache has different repeat method depond on
+    // @param.is_neox
+
+    ggml_tensor* src0 = dst->src[0];  // input
+    ggml_tensor* src1 = dst->src[1];  // position
+
+    // arange, [0,1,...,ne0/2]
+    int64_t arange_length = src0->ne[0] / 2;
+    ggml_cann_pool_alloc arange_allocator(ctx.pool(),
+                                          arange_length * sizeof(float_t));
+    void* arange_buffer = arange_allocator.get();
+    int64_t arange_ne[] = {arange_length, 1, 1, 1};
+    size_t arange_nb[] = {sizeof(float_t), sizeof(float_t), sizeof(float_t),
+                          arange_length * sizeof(float_t)};
+
+    aclTensor* acl_arange_tensor =
+        ggml_cann_create_tensor(arange_buffer, ACL_FLOAT, sizeof(float_t),
+                                arange_ne, arange_nb, GGML_MAX_DIMS);
+    float start = 0;
+    float step = 1;
+    float stop = src0->ne[0] / 2;
+    float n_elements = src0->ne[0] / 2;
+    aclnn_arange(ctx, acl_arange_tensor, start, stop, step, n_elements);
+
+    // power
+    // aclnnPowScalarTensor(): @param self is tensor which should be scalar, so
+    // use aclnn_pow_tensor_tensor() until fixed. aclScalar* acl_theta_scale =
+    // aclCreateScalar(&theta_scale, aclDataType::ACL_FLOAT);
+    // aclnn_power_scalar_tensor(ctx, acl_theta_scale, acl_arange_tensor,
+    // acl_power_tensor);
+    ggml_cann_pool_alloc theta_scale_allocator(ctx.pool(),
+                                               arange_length * sizeof(float_t));
+    void* theta_scale_buffer = theta_scale_allocator.get();
+    aclTensor* acl_theta_scale_tensor = aclnn_ones(
+        ctx, theta_scale_buffer, arange_length * sizeof(float_t), arange_ne,
+        GGML_MAX_DIMS, ACL_FLOAT, sizeof(float_t), theta_scale);
+    aclnn_pow_tensor_tensor(ctx, acl_theta_scale_tensor, acl_arange_tensor);
+
+    // position
+    GGML_ASSERT(src1->type == GGML_TYPE_I32);
+    int64_t position_length = src1->ne[0];
+    int64_t position_ne[] = {1, position_length, 1, 1};
+    size_t position_nb[] = {sizeof(int32_t), sizeof(int32_t),
+                            sizeof(int32_t) * position_length,
+                            sizeof(int32_t) * position_length};
+    aclTensor* acl_position_tensor = ggml_cann_create_tensor(
+        src1->data, ggml_cann_type_mapping(src1->type),
+        ggml_type_size(src1->type), position_ne, position_nb, GGML_MAX_DIMS);
+
+    // power * position
+    int64_t theta_length = arange_length * position_length;
+    ggml_cann_pool_alloc theta_allocator(ctx.pool(),
+                                         theta_length * sizeof(float_t));
+    void* theta_buffer = theta_allocator.get();
+    int64_t theta_ne[] = {arange_length, position_length, 1, 1};
+    size_t theta_nb[GGML_MAX_DIMS];
+    theta_nb[0] = sizeof(float_t);
+    for (int i = 1; i < GGML_MAX_DIMS; i++) {
+        theta_nb[i] = theta_nb[i - 1] * theta_ne[i - 1];
+    }
+    aclTensor* acl_theta_tensor =
+        ggml_cann_create_tensor(theta_buffer, ACL_FLOAT, sizeof(float_t),
+                                theta_ne, theta_nb, GGML_MAX_DIMS);
+    aclnn_mul(ctx, acl_position_tensor, acl_theta_scale_tensor,
+              acl_theta_tensor);
+
+    // permute: [0,1,2,3]->[0,2,1,3]
+    int64_t permute_ne[] = {arange_length, 1, position_length, 1};
+    size_t permute_nb[GGML_MAX_DIMS];
+    permute_nb[0] = sizeof(float_t);
+    for (int i = 1; i < GGML_MAX_DIMS; i++) {
+        permute_nb[i] = permute_nb[i - 1] * permute_ne[i - 1];
+    }
+    ggml_cann_pool_alloc permute_allocator(ctx.pool(),
+                                           theta_length * sizeof(float_t));
+    void* permute_buffer = permute_allocator.get();
+    aclTensor* acl_permute_tensor = ggml_cann_create_tensor(
+        permute_buffer, ACL_FLOAT, sizeof(float_t), permute_ne, permute_nb,
+        GGML_MAX_DIMS, ACL_FORMAT_ND);
+    int64_t permute_dim[] = {0, 2, 1, 3};
+    int64_t num_dims = 4;
+    aclnn_permute(ctx, acl_theta_tensor, acl_permute_tensor, permute_dim,
+                  num_dims);
+
+    // sin/cos
+    ggml_cann_pool_alloc sin_allocator(ctx.pool(),
+                                       theta_length * sizeof(float_t));
+    void* sin_buffer = sin_allocator.get();
+    aclTensor* acl_sin_tensor = ggml_cann_create_tensor(
+        sin_buffer, ACL_FLOAT, sizeof(float_t), permute_ne, permute_nb,
+        GGML_MAX_DIMS, ACL_FORMAT_ND);
+    aclnn_sin(ctx, acl_permute_tensor, acl_sin_tensor);
+
+    ggml_cann_pool_alloc cos_allocator(ctx.pool(),
+                                       theta_length * sizeof(float_t));
+    void* cos_buffer = cos_allocator.get();
+    aclTensor* acl_cos_tensor = ggml_cann_create_tensor(
+        cos_buffer, ACL_FLOAT, sizeof(float_t), permute_ne, permute_nb,
+        GGML_MAX_DIMS, ACL_FORMAT_ND);
+    aclnn_cos(ctx, acl_permute_tensor, acl_cos_tensor);
+
+    // repeat
+    if (is_neox) {
+        int64_t repeatsArray[] = {1, 1, 1, 2};
+        aclnn_repeat(ctx, acl_sin_tensor, acl_sin_repeat_tensor, repeatsArray);
+        aclnn_repeat(ctx, acl_cos_tensor, acl_cos_repeat_tensor, repeatsArray);
+    } else {
+        int64_t num_repeats = 2;
+        int64_t dim = 3;
+        int64_t output_size = arange_length * num_repeats;
+        aclnn_repeat_interleave(ctx, acl_sin_tensor, acl_sin_repeat_tensor, dim,
+                                num_repeats, output_size);
+        aclnn_repeat_interleave(ctx, acl_cos_tensor, acl_cos_repeat_tensor, dim,
+                                num_repeats, output_size);
+    }
+
+    // release
+    ACL_CHECK(aclDestroyTensor(acl_arange_tensor));
+    ACL_CHECK(aclDestroyTensor(acl_theta_scale_tensor));
+    ACL_CHECK(aclDestroyTensor(acl_position_tensor));
+    ACL_CHECK(aclDestroyTensor(acl_theta_tensor));
+    ACL_CHECK(aclDestroyTensor(acl_permute_tensor));
+    ACL_CHECK(aclDestroyTensor(acl_sin_tensor));
+    ACL_CHECK(aclDestroyTensor(acl_cos_tensor));
+}
+
+void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+    // TODO: use ascendc
+    // Only test with LLAMA model.
+    ggml_tensor* src0 = dst->src[0];  // input
+    ggml_tensor* src2 = dst->src[2];  // freq_factors
+
+    // TODO: with freq_factors
+    GGML_ASSERT(src2 == NULL);
+
+    // param
+    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
+    // const int n_past     = ((int32_t *) dst->op_params)[0];
+    const int n_dims = ((int32_t*)dst->op_params)[1];
+    const int mode = ((int32_t*)dst->op_params)[2];
+    // const int n_ctx      = ((int32_t *) dst->op_params)[3];
+    const int n_ctx_orig = ((int32_t*)dst->op_params)[4];
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    memcpy(&freq_base, (int32_t*)dst->op_params + 5, sizeof(float));
+    memcpy(&freq_scale, (int32_t*)dst->op_params + 6, sizeof(float));
+    memcpy(&ext_factor, (int32_t*)dst->op_params + 7, sizeof(float));
+    memcpy(&attn_factor, (int32_t*)dst->op_params + 8, sizeof(float));
+    memcpy(&beta_fast, (int32_t*)dst->op_params + 9, sizeof(float));
+    memcpy(&beta_slow, (int32_t*)dst->op_params + 10, sizeof(float));
+
+    GGML_ASSERT(n_dims <= ne0);
+    GGML_ASSERT(n_dims % 2 == 0);
+
+    // TODO: ext_factor != 0
+    GGML_ASSERT(ext_factor == 0);
+    // TODO: freq_scale != 1
+    GGML_ASSERT(freq_scale == 1);
+
+    const float theta_scale = powf(freq_base, -2.0f / n_dims);
+
+    float corr_dims[2];
+    ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast,
+                             beta_slow, corr_dims);
+
+    const bool is_neox = mode & 2;
+
+    // init cos/sin cache
+    ggml_cann_pool_alloc sin_allocator(
+        ctx.pool(), src0->ne[0] * src0->ne[2] * sizeof(float_t));
+    ggml_cann_pool_alloc cos_allocator(
+        ctx.pool(), src0->ne[0] * src0->ne[2] * sizeof(float_t));
+    void* sin_buffer = sin_allocator.get();
+    void* cos_buffer = cos_allocator.get();
+
+    int64_t sin_reshape_ne[4] = {src0->ne[0], 1, src0->ne[2], 1};
+    size_t sin_reshape_nb[GGML_MAX_DIMS];
+    sin_reshape_nb[0] = sizeof(float_t);
+    for (int i = 1; i < GGML_MAX_DIMS; i++) {
+        sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1];
+    }
+    aclTensor* acl_sin_reshape_tensor =
+        ggml_cann_create_tensor(sin_buffer, ACL_FLOAT, sizeof(float_t),
+                                sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
+    aclTensor* acl_cos_reshape_tensor =
+        ggml_cann_create_tensor(cos_buffer, ACL_FLOAT, sizeof(float_t),
+                                sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
+    aclnn_cache_init(ctx, dst, acl_cos_reshape_tensor, acl_sin_reshape_tensor,
+                     theta_scale, is_neox);
+
+    // roll input
+    void* input_roll_buffer;
+    aclTensor* acl_minus_one_tensor;
+    void* minus_one_scale_buffer = nullptr;
+    ggml_cann_pool_alloc roll_allocator(ctx.pool(), ggml_nbytes(src0));
+    ggml_cann_pool_alloc minus_one_scale_allocator(
+        ctx.pool(), sizeof(float_t) * src0->ne[0]);
+    if (!is_neox) {
+        // roll input: [q0,q1,q2,q3,...] -> [q1,q0,q3,q2,...]
+        input_roll_buffer = roll_allocator.get();
+        int64_t input_roll_ne[4] = {2, src0->ne[1] * (src0->ne[0] / 2),
+                                    src0->ne[2], src0->ne[3]};
+        size_t input_roll_nb[GGML_MAX_DIMS];
+        input_roll_nb[0] = ggml_type_size(src0->type);
+        for (int i = 1; i < GGML_MAX_DIMS; i++) {
+            input_roll_nb[i] = input_roll_nb[i - 1] * input_roll_ne[i - 1];
+        }
+        aclTensor* acl_input_roll_tensor = ggml_cann_create_tensor(
+            input_roll_buffer, ggml_cann_type_mapping(src0->type),
+            ggml_type_size(src0->type), input_roll_ne, input_roll_nb,
+            GGML_MAX_DIMS);
+        aclTensor* acl_input_tensor = ggml_cann_create_tensor(
+            src0->data, ggml_cann_type_mapping(src0->type),
+            ggml_type_size(src0->type), input_roll_ne, input_roll_nb,
+            GGML_MAX_DIMS);
+
+        int64_t shifts[] = {1};
+        int64_t dims[] = {3};
+        aclnn_roll(ctx, acl_input_tensor, acl_input_roll_tensor, shifts, dims);
+        ACL_CHECK(aclDestroyTensor(acl_input_roll_tensor));
+        ACL_CHECK(aclDestroyTensor(acl_input_tensor));
+
+        // init [-1, 1, -1, 1, ...]
+        minus_one_scale_buffer = minus_one_scale_allocator.get();
+
+        int64_t minus_one_ne[4] = {src0->ne[0], 1, 1, 1};
+        size_t minus_one_nb[GGML_MAX_DIMS];
+        minus_one_nb[0] = sizeof(float_t);
+        for (int i = 1; i < GGML_MAX_DIMS; i++) {
+            minus_one_nb[i] = minus_one_nb[i - 1] * minus_one_ne[i - 1];
+        }
+        acl_minus_one_tensor = aclnn_ones(
+            ctx, minus_one_scale_buffer, sizeof(float_t) * src0->ne[0],
+            minus_one_ne, GGML_MAX_DIMS, ACL_FLOAT, sizeof(float_t), 1);
+        int64_t dim = 3;
+        int64_t* index = new int64_t[src0->ne[0]];
+        for (int i = 0; i < src0->ne[0]; i++) {
+            index[i] = i / 2 * 2;
+        }
+        int64_t index_num = src0->ne[0];
+        float value = -1;
+        aclnn_index_fill_tensor(ctx, acl_minus_one_tensor, dim, index,
+                                index_num, value);
+    } else {
+        // roll input: [q0,q1,q2,...] ->
+        // [q_half,q_half+1,...,q_end,q0,q1,...q_half-1]
+        input_roll_buffer = roll_allocator.get();
+        aclTensor* acl_input_roll_tensor = ggml_cann_create_tensor(
+            input_roll_buffer, ggml_cann_type_mapping(src0->type),
+            ggml_type_size(src0->type), src0->ne, src0->nb, GGML_MAX_DIMS);
+        aclTensor* acl_input_tensor = ggml_cann_create_tensor(src0);
+
+        int64_t shifts[] = {src0->ne[0] / 2};
+        int64_t dims[] = {3};
+        aclnn_roll(ctx, acl_input_tensor, acl_input_roll_tensor, shifts, dims);
+
+        ACL_CHECK(aclDestroyTensor(acl_input_roll_tensor));
+        ACL_CHECK(aclDestroyTensor(acl_input_tensor));
+
+        // init [-1, -1, -1, 1, 1,1,...]
+        minus_one_scale_buffer = minus_one_scale_allocator.get();
+
+        int64_t minus_one_ne[4] = {src0->ne[0], 1, 1, 1};
+        size_t minus_one_nb[GGML_MAX_DIMS];
+        minus_one_nb[0] = sizeof(float_t);
+        for (int i = 1; i < GGML_MAX_DIMS; i++) {
+            minus_one_nb[i] = minus_one_nb[i - 1] * minus_one_ne[i - 1];
+        }
+        acl_minus_one_tensor = aclnn_ones(
+            ctx, minus_one_scale_buffer, sizeof(float_t) * src0->ne[0],
+            minus_one_ne, GGML_MAX_DIMS, ACL_FLOAT, sizeof(float_t), 1);
+        // -1 * first half
+        int64_t first_half_ne[4] = {src0->ne[0] / 2, 1, 1, 1};
+        size_t first_half_nb[GGML_MAX_DIMS];
+        first_half_nb[0] = sizeof(float_t);
+        for (int i = 1; i < GGML_MAX_DIMS; i++) {
+            first_half_nb[i] = first_half_nb[i - 1] * first_half_ne[i - 1];
+        }
+        aclTensor* acl_first_half_tensor = ggml_cann_create_tensor(
+            minus_one_scale_buffer, ACL_FLOAT, sizeof(float_t), first_half_ne,
+            first_half_nb, GGML_MAX_DIMS);
+        bool inplace = true;
+        float scale = -1;
+        aclnn_muls(ctx, acl_first_half_tensor, scale, nullptr, inplace);
+        ACL_CHECK(aclDestroyTensor(acl_first_half_tensor));
+    }
+
+    // TODO: n_dims < ne0
+    GGML_ASSERT(n_dims == src0->ne[0]);
+
+    // input * scale
+    ggml_cann_pool_alloc roll_mul_scale_allocator(ctx.pool(),
+                                                  ggml_nbytes(src0));
+    void* input_roll_mul_scale_buffer = roll_mul_scale_allocator.get();
+    size_t input_nb[GGML_MAX_DIMS];
+    input_nb[0] = ggml_type_size(src0->type);
+    for (int i = 1; i < GGML_MAX_DIMS; i++) {
+        input_nb[i] = input_nb[i - 1] * src0->ne[i - 1];
+    }
+    aclTensor* acl_input_roll_mul_scale_tensor = ggml_cann_create_tensor(
+        input_roll_mul_scale_buffer, ggml_cann_type_mapping(src0->type),
+        ggml_type_size(src0->type), src0->ne, input_nb, GGML_MAX_DIMS);
+    aclTensor* acl_input_roll_reshape_tensor = ggml_cann_create_tensor(
+        input_roll_buffer, ggml_cann_type_mapping(src0->type),
+        ggml_type_size(src0->type), src0->ne, input_nb, GGML_MAX_DIMS);
+
+    aclnn_mul(ctx, acl_input_roll_reshape_tensor, acl_minus_one_tensor,
+              acl_input_roll_mul_scale_tensor);
+
+    // output
+    aclTensor* acl_src0 = ggml_cann_create_tensor(src0);
+    aclTensor* acl_dst = ggml_cann_create_tensor(dst);
+    void* output_fp32_buffer;
+    if (src0->type == GGML_TYPE_F32) {
+        aclnn_inplace_mul(ctx, acl_src0, acl_cos_reshape_tensor);
+        aclnn_inplace_mul(ctx, acl_input_roll_mul_scale_tensor,
+                          acl_sin_reshape_tensor);
+        aclnn_add(ctx, acl_src0, acl_input_roll_mul_scale_tensor, acl_dst);
+        // TODO: ne0 != n_dims in mode2
+    } else if (src0->type == GGML_TYPE_F16) {
+        size_t input_fp32_nb[GGML_MAX_DIMS];
+        input_fp32_nb[0] = sizeof(float_t);
+        for (int i = 1; i < GGML_MAX_DIMS; i++) {
+            input_fp32_nb[i] = input_fp32_nb[i - 1] * dst->ne[i - 1];
+        }
+        ggml_cann_pool_alloc fp32_allocator1(
+            ctx.pool(), ggml_nelements(dst) * sizeof(float_t));
+        void* input_fp32_buffer1 = fp32_allocator1.get();
+        aclTensor* input_fp32_tensor1 = ggml_cann_create_tensor(
+            input_fp32_buffer1, ACL_FLOAT, sizeof(float_t), dst->ne,
+            input_fp32_nb, GGML_MAX_DIMS);
+        ggml_cann_pool_alloc fp32_allocator2(
+            ctx.pool(), ggml_nelements(dst) * sizeof(float_t));
+        void* input_fp32_buffer2 = fp32_allocator2.get();
+        aclTensor* input_fp32_tensor2 = ggml_cann_create_tensor(
+            input_fp32_buffer2, ACL_FLOAT, sizeof(float_t), dst->ne,
+            input_fp32_nb, GGML_MAX_DIMS);
+
+        ggml_cann_pool_alloc fp32_allocator(
+            ctx.pool(), ggml_nelements(dst) * sizeof(float_t));
+        output_fp32_buffer = fp32_allocator.get();
+        aclTensor* output_fp32_tensor = ggml_cann_create_tensor(
+            output_fp32_buffer, ACL_FLOAT, sizeof(float_t), dst->ne,
+            input_fp32_nb, GGML_MAX_DIMS);
+        aclnn_mul(ctx, acl_src0, acl_cos_reshape_tensor, input_fp32_tensor1);
+        aclnn_mul(ctx, acl_input_roll_mul_scale_tensor, acl_sin_reshape_tensor,
+                  input_fp32_tensor2);
+        aclnn_add(ctx, input_fp32_tensor1, input_fp32_tensor2,
+                  output_fp32_tensor);
+        aclnn_cast(ctx, output_fp32_tensor, acl_dst, ACL_FLOAT16);
+
+        ACL_CHECK(aclDestroyTensor(input_fp32_tensor1));
+        ACL_CHECK(aclDestroyTensor(input_fp32_tensor2));
+        ACL_CHECK(aclDestroyTensor(output_fp32_tensor));
+    }
+
+    ACL_CHECK(aclDestroyTensor(acl_sin_reshape_tensor));
+    ACL_CHECK(aclDestroyTensor(acl_cos_reshape_tensor));
+    ACL_CHECK(aclDestroyTensor(acl_minus_one_tensor));
+    ACL_CHECK(aclDestroyTensor(acl_input_roll_mul_scale_tensor));
+    ACL_CHECK(aclDestroyTensor(acl_input_roll_reshape_tensor));
+    ACL_CHECK(aclDestroyTensor(acl_src0));
+    ACL_CHECK(aclDestroyTensor(acl_dst));
+}
diff --git a/src/ggml-cann/aclnn_ops.h b/src/ggml-cann/aclnn_ops.h
new file mode 100644 (file)
index 0000000..680129c
--- /dev/null
@@ -0,0 +1,592 @@
+#ifndef CANN_ACLNN_OPS
+#define CANN_ACLNN_OPS
+
+/**
+ * @file    acl_tensor
+ * @brief   This file contains related functions of ggml_tensor and acl_tensor.
+ *          Contains conversion from ggml_tensor to acl_tensor, broadcast and other
+ *          functions.
+ * @author  hipudding <huafengchun@gmail.com>
+ * @author  wangshuai09 <391746016@qq.com>
+ * @date    July 15, 2024
+ *
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+ * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
+ * IN THE SOFTWARE.
+ */
+
+#include <aclnnop/aclnn_add.h>
+#include <aclnnop/aclnn_arange.h>
+#include <aclnnop/aclnn_argsort.h>
+#include <aclnnop/aclnn_cat.h>
+#include <aclnnop/aclnn_clamp.h>
+#include <aclnnop/aclnn_div.h>
+#include <aclnnop/aclnn_gelu.h>
+#include <aclnnop/aclnn_hardsigmoid.h>
+#include <aclnnop/aclnn_hardswish.h>
+#include <aclnnop/aclnn_leaky_relu.h>
+#include <aclnnop/aclnn_mul.h>
+#include <aclnnop/aclnn_relu.h>
+#include <aclnnop/aclnn_silu.h>
+#include <aclnnop/aclnn_tanh.h>
+#include "acl_tensor.h"
+#include "common.h"
+
+/**
+ * @brief   Repeats a ggml tensor along each dimension to match the dimensions
+ *          of another tensor.
+ *
+ * @details This function repeats the elements of a source ggml tensor along
+ *          each dimension to create a destination tensor with the specified
+ *          dimensions. The operation is performed using the ACL backend and
+ *          executed asynchronously on the device.
+ *
+ * @param   ctx The CANN context used for operations.
+ * @param   dst The ggml tensor representing the destination, which op is
+ *              GGML_OP_REPEAT and specifies the desired dimensions.
+ */
+void ggml_cann_repeat(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
+/**
+ * @brief   Adds two ggml tensors using the CANN backend.
+ *
+ * @details This function performs an element-wise addition of two tensors. In
+ *          case the tensors do not have the same shape, one or both tensors
+ *          will be broadcasted to match the shape of the other before the
+ *          addition is performed.The formula for the operation is given by:
+ *          \f[
+ *              \text{dst} = \text{acl_src0} + \alpha \cdot \text{acl_src1}
+ *          \f]
+ *
+ * @param ctx The CANN context used for operations.
+ * @param dst The ggml tensor representing the destination, result of the
+ *            addition is stored at dst->data, and dst->op is `GGML_OP_ADD`
+ */
+void ggml_cann_add(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
+/**
+ * @brief   Applies the Leaky ReLU activation function to a tensor using the CANN
+ *          backend.
+ *
+ * @details This function computes the Leaky ReLU activation for each element of
+ *          the input tensor. The Leaky ReLU function allows a small gradient
+ *          when the unit is not active (i.e., when the input is negative). The
+ *          Leaky ReLU function is defined as:
+ *          \f[
+ *              \text{dst} = \max(0, src) + \text{negativeSlope} \cdot \min(0,
+ *               src)
+ *          \f]
+ *          `negativeSlope` is in dst->params.
+ *
+ * @param ctx The CANN context used for operations.
+ * @param dst The destination tensor where the result of the Leaky ReLU
+ *            activation is stored, which op is `GGML_OP_LEAKY_RELU`
+ */
+void ggml_cann_leaky_relu(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
+/**
+ * @brief    Concatenates multiple tensors along a specified dimension using the
+ *           CANN backend.
+ *
+ * @param ctx        The CANN context used for operations.
+ * @param tensorList A pointer to the list of tensors to be concatenated.
+ * @param dst        The destination tensor where the result of the
+ *                   concatenation is stored. dst->op is `GGML_OP_CONCAT`.
+ * @param concat_dim The dimension along which the tensors are concatenated.
+ *
+ * @attention tensorList length should be 2 and the dimension using for concat
+ *            default to 1.
+ */
+void ggml_cann_concat(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
+/**
+ * @brief   Generates a sequence of evenly spaced values within a specified
+ *          interval for a ggml tensor using the CANN backend.
+ *
+ * @details This function creates a sequence of numbers over a specified i
+ *          nterval, starting from `start`, ending before `stop`, and
+ *          incrementing by `step`. The sequence is stored in the destination
+ *          tensor `dst`.
+ *
+ * @param ctx The CANN context used for operations.
+ * @param dst The destination tensor where the generated sequence will be stored.
+ *            `start`, 'stop' and 'step' are in dst->op_params and dst->op is
+ *            `GGML_OP_ARANGE`.
+ */
+void ggml_cann_arange(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
+/**
+ * @brief   Computes the square of the elements of a ggml tensor using the CANN
+ *          backend.
+ * @details The function sets the second source tensor of the destination
+ *          tensor `dst` to be equal to the first source tensor. This is
+ *          effectively squaring the elements since the multiplication becomes
+ *          `element * element`.
+ * @param ctx The CANN context used for operations.
+ * @param dst The destination tensor where the squared values will be stored,
+ *            which dst->op is `GGML_OP_SQR`.
+ */
+void ggml_cann_sqr(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
+/**
+ * @brief   Applies a clamp operation to the elements of a ggml tensor using the
+ *          CANN backend.
+ *
+ * @details This function clamps the elements of the input tensor `src` to a
+ *          specified range defined by `min` and `max` values. The result is
+ *          stored in the destination tensor `dst`. The operation is defined as:
+ *          \f[
+ *              y = \max(\min(x, max\_value), min\_value)
+ *           \f]
+ *          where `x` is an element of the input tensor, and `y` is the
+ *          corresponding element in the output tensor.
+ * @param ctx The CANN context used for operations.
+ * @param dst The destination tensor where the clamped values will be stored.
+ *            dst->op is `GGML_OP_CLAMP`, `min` and `max` value is in dst->params.
+ */
+void ggml_cann_clamp(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
+/**
+ * @brief   Scales the elements of a ggml tensor by a constant factor using the
+ *          CANN backend.
+ *
+ * @details This function multiplies each element of the input tensor `src` by
+ *          a scaling factor `scale`, storing the result in the destination
+ *          tensor `dst`. The operation is defined as:
+ *          \f[
+ *             dst = src \times scale
+ *          \f]
+ *
+ * @param ctx The CANN context used for operations.
+ * @param dst The destination tensor where the scaled values will be stored.
+ *            dst->op is `GGML_OP_SCALE` and `scale` value is in dst->params.
+ */
+void ggml_cann_scale(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
+/**
+ * @brief   Sorts the elements of a ggml tensor and returns the indices that
+ *          would sort the tensor using the CANN backend.
+ *
+ * @details This function performs an argsort operation on the input tensor
+ *          `src`. It sorts the elements of `src` in either ascending or
+ *          descending order, depending on the `GGML_SORT_ORDER_DESC`,
+ *          and returns the indices that would sort the original tensor.
+ *
+ * @param ctx The CANN context used for operations.
+ * @param dst The destination tensor where the sorted indices will be stored.
+ *            dst->op is `GGML_OP_ARGSORT`.
+ */
+void ggml_cann_argsort(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
+/**
+ * @brief   Computes the Layer Normalization for a ggml tensor using the CANN
+ *          backend.
+ *
+ * @details This function applies the Layer Normalization operation on the
+ *          input tensor `src` and stores the result in the destination tensor
+ *          `dst`. Layer Normalization normalizes the features at each sample in
+ *          a mini-batch independently. It is commonly used in neural networks
+ *          to normalize the activations of a layer by adjusting and scaling
+ *          the outputs.
+ *          The operation is defined as:
+ *          \f[
+ *              \text { out }=\frac{x-\mathrm{E}[x]}{\sqrt{\text{Var}[x]+eps}}
+ *          \f]
+ *          `Var` defaults dst->ne[0]. `eps` is in dst->params.
+ *
+ * @param ctx The CANN context used for operations.
+ * @param dst The destination tensor where the normalized values will be stored.
+ * @attention `Var` defaults to dst->ne[0].
+ */
+void ggml_cann_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
+/**
+ * @brief  Computes the Group Normalization for a ggml tensor using the CANN
+ *         backend.
+ *
+ * @brief  This function applies the Group Normalization operation on the input
+ *         tensor `src` and stores the result in the destination tensor `dst`.
+ *         Group Normalization divides the channels into groups and normalizes
+ *         the features within each group across spatial locations.
+ *         It is commonly used in convolutional neural networks to improve
+ *         training stability and performance.
+ *         The operation is defined as:
+ *         \f[
+ *             \text { out }=\frac{x-\mathrm{E}[x]}{\sqrt{\text{Var}[x]+eps}}
+ *         \f]
+ *
+ * @param ctx The CANN context used for operations.
+ * @param dst The destination tensor where the normalized values will be stored.
+ *            `n_groups` is in dst->params, which split C channel to `n_groups`.
+ *            dst->op is `GGML_OP_GROUP_NORM`.
+ *
+ * @attention eps defaults to 1e-6f.
+ */
+void ggml_cann_group_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
+/**
+ * @brief   Computes the accumulation of tensors using the CANN backend.
+ *
+ * @details This function performs an accumulation operation on two tensors.
+ *          Depending on the `inplace` flag, it either updates the destination
+ *          tensor `dst` in place by adding `alpha * src1` to it, or it creates
+ *          a new tensor as the result of `src0 + alpha * src1` and stores it in
+ *          `dst`.
+ *          The operation is defined as:
+ *          \f[
+ *               dst = src0 + alpha \times src1
+ *          \f]
+ *          if `inplace` is `true`, `src0` is equal to 'dst'.
+ * @param ctx The CANN context used for operations.
+ * @param dst The destination tensor where the accumulated values will be stored.
+ *            `inplace` is in dst->params, and dst->op is `GGML_OP_ACC`.
+ */
+void ggml_cann_acc(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
+/**
+ * @brief   Computes the sum of elements along the last dimension of a ggml tensor
+ *          using the CANN backend.
+ *
+ * @details This function performs a reduction sum operation along the last
+ *          dimension of the input tensor `src`. The result of the sum is stored
+ *          in the destination tensor `dst`.
+ *
+ * @param ctx The CANN context used for operations.
+ * @param dst The destination tensor where the reduced values will be stored。
+ *            dst->op is `GGML_OP_SUM_ROWS`.
+ *
+ * @attention `reduce_dims` defaults to 3, which means the last dimension.
+ */
+void ggml_cann_sum_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
+/**
+ * @brief   Upsamples a ggml tensor using nearest neighbor interpolation using
+ *          the CANN backend.
+ *
+ * @details This function performs upsampling of the input tensor `src` using
+ *          nearest neighbor interpolation. The upsampling is applied to the
+ *          height and width dimensions (last two dimensions) of the tensor. The
+ *          result is stored in the destination tensor `dst`, which must have
+ *          the appropriate dimensions for the upsampled output.
+ *
+ * @param ctx The CANN context used for operations.
+ * @param dst The destination tensor where the upsampled values will be stored.
+ *            dst->op is `GGML_OP_UPSCALE`.
+ */
+void ggml_cann_upsample_nearest2d(ggml_backend_cann_context& ctx,
+                                  ggml_tensor* dst);
+
+/**
+ * @brief   Pads a ggml tensor to match the dimensions of the destination tensor
+ *          using the CANN backend.
+ *
+ * @details This function pads the input tensor `src` so that it matches the
+ *          dimensions of the destination tensor `dst`. The amount of padding
+ *          is calculated based on the difference in sizes between `src` and
+ *          `dst` along each dimension. The padded tensor is stored in `dst`.
+ *
+ * @param ctx The CANN context used for operations.
+ * @param dst The destination tensor, which specifies the target dimensions for
+ *            padding. dst->op is `GGML_OP_PAD`.
+ */
+void ggml_cann_pad(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
+/**
+ * @brief   Executes a 2D pooling operation on a ggml tensor using the CANN
+ *          backend.
+ *
+ * @details This function dispatches the execution of a 2D pooling operation on
+ *          the input tensor `dst`. The type of pooling (average or max) is
+ *          determined by the `op` parameter, which is read from the operation
+ *          parameters of `dst`. The function supports average pooling
+ *          (`GGML_OP_POOL_AVG`) and max pooling (`GGML_OP_POOL_MAX`). If an
+ *          invalid operation is encountered, the function asserts a failure.
+ *
+ * @param ctx The CANN context used for operations.
+ * @param dst The destination tensor on which the pooling operation is to be
+ *            performed. dst->op is `GGML_OP_POOL_2D`.
+ */
+void ggml_cann_pool2d(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
+/**
+ * @brief   Duplicates a ggml tensor using the CANN backend.
+ *
+ * @details This function duplicates the contents of the source tensor `src` to
+ *          the destination tensor `dst`. The function supports various tensor
+ *          types and configurations, including handling of extra data, type
+ *          conversions, and special cases for contiguous and non-contiguous
+ *          tensors.
+ *
+ * @param ctx The CANN context used for operations.
+ * @param dst The destination tensor where the duplicated data will be stored.
+ *            dst->op is `GGML_OP_DUP`
+ *
+ * @attention Only support Fp16/FP32. Not support when src and dst have
+ *            different shape and dst is no-contiguous.
+ * @note:     This func need to simplify.
+ */
+void ggml_cann_dup(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
+/**
+ * @brief   Computes the Root Mean Square (RMS) normalization of a ggml tensor
+ *          using the CANN backend.
+ *
+ * @details This function applies RMS normalization to the input tensor `src`
+ *          and stores the result in the destination tensor `dst`. RMS
+ *          normalization involves computing the root mean square of the input
+ *          tensor along a specified dimension and then dividing each element of
+ *          the tensor by this value, adjusted by a small epsilon value to
+ *          prevent division by zero.
+ *          The operation is defined as:
+ *          \f[
+ *               \text{RmsNorm}\left(x_i\right)=\frac{x_i}{\text{Rms}(\mathbf{x})} g_i,
+ *               \quad \text { where } \text{Rms}(\mathbf{x})=\sqrt{\frac{1}{n} \sum_{i=1}^n x_i^2+e p s}
+ *          \f]
+ *          `eps` is in dst->op_params.
+ * @param ctx The CANN context used for operations.
+ * @param dst The destination tensor where the normalized values will be stored.
+ *            dst->op is `GGML_OP_RMS_NORM`.
+ */
+void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
+/**
+ * @brief   Applies a diagonal mask to the tensor with a specified value.
+ *
+ * @details This function creates a mask tensor filled with ones, then applies
+ *          an upper triangular and lower triangular operation to it based on
+ *          the number of past elements specified. Afterward, it adds the masked
+ *          tensor to the destination tensor in-place.
+ *
+ * @param ctx The backend CANN context used for operations.
+ * @param dst The destination tensor where the result will be stored. dst->op is
+ *            `GGML_OP_DIAG_MASK`
+ * @param value The value to use for masking.
+ */
+void ggml_cann_diag_mask(ggml_backend_cann_context& ctx, ggml_tensor* dst, float value);
+
+/**
+ * @brief   Performs an image-to-column transformation on the input tensor.
+ *
+ * @details This function takes an input tensor and applies an image-to-column
+ *          operation, converting spatial dimensions into column-like
+ *          structures suitable for convolutional operations. It supports both
+ *          half-precision (F16) and single-precision (F32) floating-point data
+ *          types.
+ *
+ * @param ctx The backend CANN context for executing operations.
+ * @param dst The destination tensor that stores the result of the operation.
+ *            dst->op is `GGML_OP_IM2COL`.
+ */
+void ggml_cann_im2col(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
+/**
+ * @brief   Computes time step embeddings using sine and cosine functions.
+ *
+ * @details This function calculates time step embeddings by applying sine and
+ *          cosine transformations to a given input tensor, which is typically
+ *          used in temporal models like diffusion models or transformers to
+ *          encode time information effectively.
+ *
+ * @param ctx The backend CANN context for executing operations.
+ * @param dst The destination tensor where the result of the embedding operation
+ *            will be stored. dst->op is `GGML_OP_TIMESTEP_EMBEDDING`.
+ */
+void ggml_cann_timestep_embedding(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
+// @see ggml_cann_dup.
+void ggml_cann_cpy(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
+/**
+ * @brief   Computes the softmax activation with optional masking.
+ *
+ * @details This function computes the softmax activation over the input tensor,
+ *          optionally applying a mask and scaling factor. It supports both FP16
+ *          and FP32 data types and can handle masking by broadcasting the mask
+ *          across rows if necessary.
+ *          The function performs the following steps:
+ *          1. Multiplies the input tensor by a scale factor.
+ *          2. Optionally casts the mask tensor to FP32 if it is in FP16 format.
+ *          3. Broadcasts the mask tensor if its dimensions do not match the
+ *             input tensor's dimensions.
+ *          4. Adds the mask to the scaled input tensor.
+ *          5. Applies the softmax activation function along the specified
+ *             dimension.
+ *
+ * @param ctx The backend CANN context for executing operations.
+ * @param dst The destination tensor where the result will be stored. dst->op is
+ *            `GGML_OP_SOFTMAX`.
+ */
+void ggml_cann_softmax(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
+/**
+ * @brief   Extracts specific rows from a tensor based on indices.
+ *
+ * @details This function retrieves rows from a source tensor src0 according to
+ *          the indices provided in another tensor src1 and stores the result in
+ *          a destination tensor (\p dst). It supports different data types
+ *          including F32, F16, Q4_0, and Q8_0.
+ *
+ * @param ctx The backend CANN context for executing operations.
+ * @param dst The destination tensor where the extracted rows will be stored.
+ *            dst->op is `GGML_OP_GET_ROWS`.
+ */
+void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
+/**
+ * @brief   Executes matrix multiplication for the given tensor.
+ *
+ * @details This function performs matrix multiplication on the source tensors
+ *          associated with the destination tensor. It supports matrix
+ *          multiplication F32, F16, and Q8_0.
+ *
+ * @param ctx The backend CANN context for executing operations.
+ * @param dst The destination tensor for storing the result of the matrix
+ *            multiplication. dst->op is `GGML_OP_MUL_MAT`.
+ */
+void ggml_cann_mul_mat(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
+/**
+ * @brief Applies Rotary Positional Embedding (RoPE) to the input tensor.
+ *
+ * @details This function implements the RoPE mechanism, which is a method to
+ *          encode positional information into sequence data, particularly
+ *          useful in transformer models. It supports both F32 and F16 data
+ *          types.
+ *
+ * @param ctx The backend CANN context for executing operations.
+ * @param dst The destination tensor where the RoPE-transformed data will be
+ *            stored. dst->op is `GGML_OP_ROPE`.
+ *
+ * @note The function currently does not support cases where the n_dims is less
+ *       than the input tensor's first dimension.
+ * @note The function currently does not support cases where the freq_factors is
+ *       not NULL.
+ * @note The function currently does not support cases where the ext_factor is
+ *       not equal 0.
+ * @note The function currently does not support cases where the freq_scale is
+ *       not equal 1.
+ */
+void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst);
+
+template <aclnnStatus getWorkspaceSize(const aclTensor*, const aclTensor*,
+                                       aclTensor*, uint64_t*, aclOpExecutor**),
+          aclnnStatus execute(void*, uint64_t, aclOpExecutor*, aclrtStream)>
+void ggml_cann_mul_div(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+    ggml_tensor* src0 = dst->src[0];
+    ggml_tensor* src1 = dst->src[1];
+    GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst));
+
+    aclTensor* acl_src0;
+    aclTensor* acl_src1;
+    aclTensor* acl_dst;
+
+    // Need bcast
+    if (!ggml_are_same_shape(src0, src1) && ggml_cann_need_bcast(src0, src1)) {
+        BCAST_SHAPE(src0, src1)
+        acl_src0 = ggml_cann_create_tensor(src0, BCAST_PARAM(src0));
+        acl_src1 = ggml_cann_create_tensor(src1, BCAST_PARAM(src1));
+        acl_dst = ggml_cann_create_tensor(dst, BCAST_PARAM(src0));
+    } else {
+        acl_src0 = ggml_cann_create_tensor(src0);
+        acl_src1 = ggml_cann_create_tensor(src1);
+        acl_dst = ggml_cann_create_tensor(dst);
+    }
+
+    uint64_t workspaceSize = 0;
+    aclOpExecutor* executor;
+    void* workspaceAddr = nullptr;
+
+    ACL_CHECK(getWorkspaceSize(acl_src0, acl_src1, acl_dst, &workspaceSize,
+                               &executor));
+    if (workspaceSize > 0) {
+        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+        workspaceAddr = workspace_allocator.get();
+    }
+
+    aclrtStream main_stream = ctx.stream();
+    ACL_CHECK(execute(workspaceAddr, workspaceSize, executor, main_stream));
+
+    ACL_CHECK(aclDestroyTensor(acl_src0));
+    ACL_CHECK(aclDestroyTensor(acl_src1));
+    ACL_CHECK(aclDestroyTensor(acl_dst));
+}
+
+// Activation functions template.
+template <aclnnStatus getWorkspaceSize(const aclTensor*, aclTensor*, uint64_t*,
+                                       aclOpExecutor**),
+          aclnnStatus execute(void*, uint64_t, aclOpExecutor*,
+                              const aclrtStream)>
+void ggml_cann_activation(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+    ggml_tensor* src = dst->src[0];
+
+    GGML_ASSERT(src->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    aclTensor* acl_src = ggml_cann_create_tensor(src);
+    aclTensor* acl_dst = ggml_cann_create_tensor(dst);
+
+    uint64_t workspaceSize = 0;
+    aclOpExecutor* executor;
+    void* workspaceAddr = nullptr;
+
+    ACL_CHECK(getWorkspaceSize(acl_src, acl_dst, &workspaceSize, &executor));
+    if (workspaceSize > 0) {
+        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+        workspaceAddr = workspace_allocator.get();
+    }
+
+    aclrtStream main_stream = ctx.stream();
+    ACL_CHECK(execute(workspaceAddr, workspaceSize, executor, main_stream));
+
+    ACL_CHECK(aclDestroyTensor(acl_src));
+    ACL_CHECK(aclDestroyTensor(acl_dst));
+}
+
+// Activation functions template for const aclTensors.
+template <aclnnStatus getWorkspaceSize(const aclTensor*, const aclTensor*,
+                                       uint64_t*, aclOpExecutor**),
+          aclnnStatus execute(void*, uint64_t, aclOpExecutor*,
+                              const aclrtStream)>
+void ggml_cann_activation(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
+    ggml_tensor* src = dst->src[0];
+
+    GGML_ASSERT(src->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+    aclTensor* acl_src = ggml_cann_create_tensor(src);
+    aclTensor* acl_dst = ggml_cann_create_tensor(dst);
+
+    uint64_t workspaceSize = 0;
+    aclOpExecutor* executor;
+    void* workspaceAddr = nullptr;
+
+    ACL_CHECK(getWorkspaceSize(acl_src, acl_dst, &workspaceSize, &executor));
+    if (workspaceSize > 0) {
+        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+        workspaceAddr = workspace_allocator.get();
+    }
+
+    aclrtStream main_stream = ctx.stream();
+    ACL_CHECK(execute(workspaceAddr, workspaceSize, executor, main_stream));
+
+    ACL_CHECK(aclDestroyTensor(acl_src));
+    ACL_CHECK(aclDestroyTensor(acl_dst));
+}
+
+#endif  // CANN_ACLNN_OPS
diff --git a/src/ggml-cann/common.h b/src/ggml-cann/common.h
new file mode 100644 (file)
index 0000000..e6a5701
--- /dev/null
@@ -0,0 +1,282 @@
+/*
+ * Copyright (c) 2023-2024 The ggml authors
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in
+ * all copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+ * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
+ * IN THE SOFTWARE.
+ */
+
+#ifndef CANN_COMMON_H
+#define CANN_COMMON_H
+
+#include <acl/acl.h>
+
+#include <cstdio>
+#include <iostream>
+#include <map>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "../include/ggml-cann.h"
+#include "../include/ggml.h"
+
+#define MATRIX_ROW_PADDING 512
+#define GGML_CANN_MAX_STREAMS 8
+
+/**
+ * @brief Handles CANN-related errors by printing an error message and
+ *        terminating the program.
+ * @param stmt The statement that caused the error.
+ * @param func The function in which the error occurred.
+ * @param file The file in which the error occurred.
+ * @param line The line number at which the error occurred.
+ * @param msg The error message.
+ */
+[[noreturn]] void ggml_cann_error(const char* stmt, const char* func,
+                                  const char* file, int line, const char* msg);
+
+/**
+ * @brief Checks the result of a CANN function call and invokes the error
+ *        handler if the call fails.
+ * @param stmt The CANN function call to check.
+ * @param success The success code that indicates the call was successful.
+ * @param error_fn The function to call to retrieve the error message.
+ */
+#define ACL_CHECK_GEN(stmt, success, error_fn)                                \
+    do {                                                                      \
+        int err_code = (stmt);                                                \
+        if (err_code != (success)) {                                          \
+            ggml_cann_error(#stmt, __func__, __FILE__, __LINE__, error_fn()); \
+        }                                                                     \
+    } while (0);
+
+#define ACL_CHECK(stmt) ACL_CHECK_GEN(stmt, 0, aclGetRecentErrMsg)
+
+/**
+ * @brief Contains information about CANN devices.
+ */
+struct ggml_cann_device_info {
+    /**
+     * @brief Number of CANN devices available.
+     */
+    int32_t device_count;
+
+    /**
+     * @brief Information about a single CANN device.
+     */
+    struct cann_device_info {
+        int cc;                 /**< Compute capability.                   */
+        size_t smpb;            /**< Maximum shared memory per block.      */
+        bool vmm;               /**< Virtual memory support.               */
+        size_t vmm_granularity; /**< Granularity of virtual memory.        */
+        size_t total_vram;      /**< Total video RAM available on the device. */
+    };
+
+    cann_device_info devices[GGML_CANN_MAX_DEVICES] =
+        {}; /**< Array of CANN device information. */
+};
+
+const ggml_cann_device_info& ggml_cann_info();
+
+void ggml_cann_set_device(int32_t device);
+int32_t ggml_cann_get_device();
+
+/**
+ * @brief Abstract base class for memory pools used by CANN.
+ */
+struct ggml_cann_pool {
+    /**
+     * @brief Virtual destructor for the memory pool.
+     */
+    virtual ~ggml_cann_pool() = default;
+
+    /**
+     * @brief Allocates memory from the pool.
+     *
+     * @param size         The size of the memory block to allocate.
+     * @param actual_size  Pointer to a variable where the actual allocated size
+     *                     will be stored.
+     * @return             Pointer to the allocated memory block.
+     */
+    virtual void* alloc(size_t size, size_t* actual_size) = 0;
+
+    /**
+     * @brief Frees a previously allocated memory block.
+     *
+     * @param ptr   Pointer to the memory block to free.
+     * @param size  Size of the memory block to free.
+     * @note Note that all CANN opertors are running async. Make sure memory is
+     *       still avaiable before this operator finished.
+     */
+    virtual void free(void* ptr, size_t size) = 0;
+};
+
+/**
+ * @brief RAII wrapper for managing memory allocations from a CANN memory pool.
+ */
+struct ggml_cann_pool_alloc {
+    ggml_cann_pool* pool = nullptr; /**< Pointer to the memory pool. */
+    void* ptr = nullptr;    /**< Pointer to the allocated memory block. */
+    size_t actual_size = 0; /**< Actual size of the allocated memory block. */
+
+    /**
+     * @brief Default constructor.
+     */
+    ggml_cann_pool_alloc() = default;
+
+    /**
+     * @brief Constructor that initializes the memory pool.
+     * @param pool Reference to the memory pool.
+     */
+    explicit ggml_cann_pool_alloc(ggml_cann_pool& pool) : pool(&pool) {}
+
+    /**
+     * @brief Constructor that initializes the memory pool and allocates memory.
+     * @param pool Reference to the memory pool.
+     * @param size Size of the memory block to allocate.
+     */
+    ggml_cann_pool_alloc(ggml_cann_pool& pool, size_t size) : pool(&pool) {
+        alloc(size);
+    }
+
+    /**
+     * @brief Destructor that frees the allocated memory block.
+     */
+    ~ggml_cann_pool_alloc() {
+        if (ptr != nullptr) {
+            pool->free(ptr, actual_size);
+        }
+    }
+
+    /**
+     * @brief Allocates memory from the pool.
+     * @param size Size of the memory block to allocate.
+     * @return Pointer to the allocated memory block.
+     */
+    void* alloc(size_t size) {
+        GGML_ASSERT(pool != nullptr);
+        GGML_ASSERT(ptr == nullptr);
+        ptr = pool->alloc(size, &this->actual_size);
+        return ptr;
+    }
+
+    /**
+     * @brief Allocates memory from a specific memory pool.
+     * @param pool Reference to the memory pool.
+     * @param size Size of the memory block to allocate.
+     * @return Pointer to the allocated memory block.
+     */
+    void* alloc(ggml_cann_pool& pool, size_t size) {
+        this->pool = &pool;
+        return alloc(size);
+    }
+
+    /**
+     * @brief Gets the pointer to the allocated memory block.
+     * @return Pointer to the allocated memory block.
+     */
+    void* get() { return ptr; }
+
+    // Deleted copy constructor
+    ggml_cann_pool_alloc(const ggml_cann_pool_alloc&) = delete;
+
+    // Deleted move constructor
+    ggml_cann_pool_alloc(ggml_cann_pool_alloc&&) = delete;
+
+    // Deleted copy assignment operator
+    ggml_cann_pool_alloc& operator=(const ggml_cann_pool_alloc&) = delete;
+
+    // Deleted move assignment operator
+    ggml_cann_pool_alloc& operator=(ggml_cann_pool_alloc&&) = delete;
+};
+
+/**
+ * @brief Context for managing CANN backend operations.
+ */
+struct ggml_backend_cann_context {
+    int32_t device;                  /**< Device ID. */
+    std::string name;                /**< Name of the device. */
+    aclrtEvent copy_event = nullptr; /**< Event for managing copy operations. */
+
+    aclrtStream streams[GGML_CANN_MAX_STREAMS] = {
+        {nullptr}}; /**< Array of streams for the device. */
+
+    /**
+     * @brief Constructor for initializing the context with a given device.
+     * @param device Device ID.
+     */
+    explicit ggml_backend_cann_context(int device)
+        : device(device), name("CANN" + std::to_string(device)) {}
+
+    /**
+     * @brief Destructor for cleaning up resources.
+     */
+    ~ggml_backend_cann_context() {
+        if (copy_event != nullptr) {
+            ACL_CHECK(aclrtDestroyEvent(copy_event));
+        }
+        for (int i = 0; i < GGML_CANN_MAX_STREAMS; ++i) {
+            if (streams[i] != nullptr) {
+                ACL_CHECK(aclrtDestroyStream(streams[i]));
+            }
+        }
+    }
+
+    /**
+     * @brief Get or create a stream for a given index.
+     * @param stream Index of the stream.
+     * @return The stream corresponding to the given index.
+     */
+    aclrtStream stream(int stream) {
+        if (streams[stream] == nullptr) {
+            ggml_cann_set_device(device);
+            ACL_CHECK(aclrtCreateStream(&streams[stream]));
+        }
+        return streams[stream];
+    }
+
+    /**
+     * @brief Get or create the default stream (index 0).
+     * @return The default stream.
+     */
+    aclrtStream stream() { return stream(0); }
+
+    // TODO: each stream should have a memory pool.
+    std::unique_ptr<ggml_cann_pool>
+        mem_pool; /**< Memory pool for the device. */
+
+    /**
+     * @brief Create a new memory pool for a given device.
+     * @param device Device ID.
+     * @return A unique pointer to the new memory pool.
+     */
+    static std::unique_ptr<ggml_cann_pool> new_pool_for_device(int device);
+
+    /**
+     * @brief Get or create the memory pool for the context.
+     * @return Reference to the memory pool.
+     */
+    ggml_cann_pool& pool() {
+        if (mem_pool == nullptr) {
+            mem_pool = new_pool_for_device(device);
+        }
+        return *mem_pool;
+    }
+};
+
+#endif  // CANN_COMMON_H
diff --git a/src/ggml-cann/kernels/CMakeLists.txt b/src/ggml-cann/kernels/CMakeLists.txt
new file mode 100644 (file)
index 0000000..5b4fef9
--- /dev/null
@@ -0,0 +1,33 @@
+if (NOT SOC_TYPE)
+    set (SOC_TYPE "Ascend910B3")
+endif()
+
+file(GLOB SRC_FILES
+    get_row_f32.cpp
+    get_row_f16.cpp
+    get_row_q4_0.cpp
+    get_row_q8_0.cpp
+    quantize_f32_q8_0.cpp
+    quantize_f16_q8_0.cpp
+    quantize_float_to_q4_0.cpp
+    dup.cpp
+)
+
+string(TOLOWER ${SOC_TYPE} SOC_VERSION)
+set(ASCEND_CANN_PACKAGE_PATH ${CANN_INSTALL_DIR})
+set(RUN_MODE "npu" CACHE STRING "run mode: npu/sim")
+
+if(EXISTS ${ASCEND_CANN_PACKAGE_PATH}/compiler/tikcpp/ascendc_kernel_cmake)
+    set(ASCENDC_CMAKE_DIR ${ASCEND_CANN_PACKAGE_PATH}/compiler/tikcpp/ascendc_kernel_cmake)
+elseif(EXISTS ${ASCEND_CANN_PACKAGE_PATH}/ascendc_devkit/tikcpp/samples/cmake)
+    set(ASCENDC_CMAKE_DIR ${ASCEND_CANN_PACKAGE_PATH}/ascendc_devkit/tikcpp/samples/cmake)
+else()
+    message(FATAL_ERROR "ascendc_kernel_cmake does not exist, please check whether the compiler package is installed.")
+endif()
+include(${ASCENDC_CMAKE_DIR}/ascendc.cmake)
+
+ascendc_library(ascendc_kernels STATIC
+    ${SRC_FILES}
+)
+
+# ascendc_compile_definitions(ascendc_kernels PRIVATE -DASCENDC_DUMP)
diff --git a/src/ggml-cann/kernels/ascendc_kernels.h b/src/ggml-cann/kernels/ascendc_kernels.h
new file mode 100644 (file)
index 0000000..7e15320
--- /dev/null
@@ -0,0 +1,19 @@
+#ifndef ASCENDC_KERNELS_H
+#define ASCENDC_KERNELS_H
+
+#include "aclrtlaunch_ascendc_get_row_f32.h"
+#include "aclrtlaunch_ascendc_get_row_f16.h"
+#include "aclrtlaunch_ascendc_get_row_q8_0.h"
+#include "aclrtlaunch_ascendc_get_row_q4_0.h"
+
+#include "aclrtlaunch_ascendc_quantize_f32_q8_0.h"
+#include "aclrtlaunch_ascendc_quantize_f16_q8_0.h"
+#include "aclrtlaunch_ascendc_quantize_f16_to_q4_0.h"
+#include "aclrtlaunch_ascendc_quantize_f32_to_q4_0.h"
+
+#include "aclrtlaunch_ascendc_dup_by_rows_fp16.h"
+#include "aclrtlaunch_ascendc_dup_by_rows_fp32.h"
+#include "aclrtlaunch_ascendc_dup_by_rows_fp32_to_fp16.h"
+#include "aclrtlaunch_ascendc_dup_by_rows_fp16_to_fp32.h"
+
+#endif  // ASCENDC_KERNELS_H
diff --git a/src/ggml-cann/kernels/dup.cpp b/src/ggml-cann/kernels/dup.cpp
new file mode 100644 (file)
index 0000000..e2c6511
--- /dev/null
@@ -0,0 +1,223 @@
+#include "kernel_operator.h"
+
+#include <cmath>
+
+using namespace AscendC;
+
+#define BUFFER_NUM 2
+
+template <typename SRC_T, typename DST_T>
+class DupByRows {
+   public:
+    __aicore__ inline DupByRows() {}
+    __aicore__ inline void init(GM_ADDR src, GM_ADDR dst, int64_t *input_ne_ub,
+                                size_t *input_nb_ub) {
+        /* Dup by rows when src is contigous on first dimension and dst is
+        contiguous, each kernel process one row.
+        */
+
+        // Input has four dims.
+        int64_t op_block_num = GetBlockNum();
+        int64_t op_block_idx = GetBlockIdx();
+
+        // param
+        num_rows = input_ne_ub[1] * input_ne_ub[2] * input_ne_ub[3];
+        num_elem = input_ne_ub[0];
+
+        // index for (ne[1], ne[2], ne[3]): (idx_ne1, idx_ne2, idx_ne3)
+        idx_ne3 = op_block_idx / (input_ne_ub[1] * input_ne_ub[2]);
+        idx_ne2 = (op_block_idx - idx_ne3 * (input_ne_ub[1] * input_ne_ub[2]))
+                  / (input_ne_ub[1]);
+        idx_ne1 = op_block_idx - idx_ne3 * (input_ne_ub[1] * input_ne_ub[2])
+                - idx_ne2 * input_ne_ub[1];
+
+        // src may not contiguous in dim [1,2,3], so stride decited by ne&nb
+        src_stride = input_nb_ub[3] * idx_ne3 + input_nb_ub[2] * idx_ne2
+                     + input_nb_ub[1] * idx_ne1;
+
+        // dst is contiguous
+        dst_stride = op_block_idx * (input_ne_ub[0] * sizeof(DST_T));
+
+        src_gm.SetGlobalBuffer(reinterpret_cast<__gm__ SRC_T *>(src +
+                                                                src_stride));
+        dst_gm.SetGlobalBuffer(reinterpret_cast<__gm__ DST_T *>(dst +
+                                                                dst_stride));
+
+        pipe.InitBuffer(src_queue, BUFFER_NUM, (sizeof(SRC_T) * num_elem +
+                                                32 - 1) / 32 * 32);
+        pipe.InitBuffer(dst_queue, BUFFER_NUM, (sizeof(DST_T) * num_elem +
+                                                32 - 1) / 32 * 32);
+    }
+
+    __aicore__ inline void copy_in() {
+        LocalTensor<SRC_T> src_local = src_queue.AllocTensor<SRC_T>();
+
+        DataCopyExtParams dataCopyParams;
+        dataCopyParams.blockCount = 1;
+        dataCopyParams.blockLen = num_elem * sizeof(SRC_T);
+        DataCopyPadExtParams<SRC_T> padParams;
+        DataCopyPad(src_local, src_gm, dataCopyParams, padParams);
+
+        src_queue.EnQue(src_local);
+    }
+
+    __aicore__ inline void copy_out() {
+        LocalTensor<DST_T> dst_local = dst_queue.DeQue<DST_T>();
+
+        DataCopyExtParams dataCopyParams;
+        dataCopyParams.blockCount = 1;
+        dataCopyParams.blockLen = num_elem * sizeof(DST_T);
+        DataCopyPad(dst_gm, dst_local, dataCopyParams);
+
+        dst_queue.FreeTensor(dst_local);
+    }
+
+    __aicore__ inline void dup() {
+        // main process, copy one row data from src to dst.
+        copy_in();
+
+        LocalTensor<SRC_T> src_local = src_queue.DeQue<SRC_T>();
+        LocalTensor<DST_T> dst_local = dst_queue.AllocTensor<DST_T>();
+
+        int32_t BLOCK_NUM = 32 / sizeof(DST_T);
+        DataCopy(dst_local, src_local, (num_elem + BLOCK_NUM - 1)
+                                        / BLOCK_NUM * BLOCK_NUM);
+        dst_queue.EnQue<DST_T>(dst_local);
+
+        src_queue.FreeTensor(src_local);
+        copy_out();
+    }
+
+    __aicore__ inline void dup_with_cast() {
+        // main process, copy one row data from src to dst.
+        // cast dtype from src to dst.
+        copy_in();
+
+        LocalTensor<SRC_T> src_local = src_queue.DeQue<SRC_T>();
+        LocalTensor<DST_T> dst_local = dst_queue.AllocTensor<DST_T>();
+
+        Cast(dst_local, src_local, RoundMode::CAST_NONE, num_elem);
+        dst_queue.EnQue<DST_T>(dst_local);
+
+        src_queue.FreeTensor(src_local);
+        copy_out();
+    }
+
+   private:
+
+    TPipe pipe;
+    GlobalTensor<SRC_T> src_gm;
+    GlobalTensor<DST_T> dst_gm;
+
+    int64_t num_rows;
+    int64_t num_elem;
+    int64_t idx_ne3;
+    int64_t idx_ne2;
+    int64_t idx_ne1;
+    int64_t src_stride;
+    int64_t dst_stride;
+
+    TQue<QuePosition::VECIN, BUFFER_NUM> src_queue;
+    TQue<QuePosition::VECOUT, BUFFER_NUM> dst_queue;
+};
+
+template <typename T>
+__aicore__ inline void copy_to_ub(GM_ADDR gm, T *ub, size_t size) {
+    auto gm_ptr = (__gm__ uint8_t *)gm;
+    auto ub_ptr = (uint8_t *)(ub);
+    for (int32_t i = 0; i < size; ++i, ++ub_ptr, ++gm_ptr) {
+        *ub_ptr = *gm_ptr;
+    }
+}
+
+extern "C" __global__ __aicore__ void ascendc_dup_by_rows_fp16(
+                                                        GM_ADDR src_gm,
+                                                        GM_ADDR dst_gm,
+                                                        GM_ADDR input_ne_gm,
+                                                        GM_ADDR input_nb_gm,
+                                                        GM_ADDR output_ne_gm,
+                                                        GM_ADDR output_nb_gm) {
+
+    int64_t input_ne_ub[4];
+    size_t input_nb_ub[4];
+    int64_t output_ne_ub[4];
+    size_t output_nb_ub[4];
+
+    copy_to_ub(input_ne_gm, input_ne_ub, 32);
+    copy_to_ub(input_nb_gm, input_nb_ub, 32);
+    copy_to_ub(output_ne_gm, output_ne_ub, 32);
+    copy_to_ub(output_nb_gm, output_nb_ub, 32);
+
+    DupByRows<half, half> op;
+    op.init(src_gm, dst_gm, input_ne_ub, input_nb_ub);
+    op.dup();
+}
+
+extern "C" __global__ __aicore__ void ascendc_dup_by_rows_fp32(
+                                                        GM_ADDR src_gm,
+                                                        GM_ADDR dst_gm,
+                                                        GM_ADDR input_ne_gm,
+                                                        GM_ADDR input_nb_gm,
+                                                        GM_ADDR output_ne_gm,
+                                                        GM_ADDR output_nb_gm) {
+    int64_t input_ne_ub[4];
+    size_t input_nb_ub[4];
+    int64_t output_ne_ub[4];
+    size_t output_nb_ub[4];
+
+    copy_to_ub(input_ne_gm, input_ne_ub, 32);
+    copy_to_ub(input_nb_gm, input_nb_ub, 32);
+    copy_to_ub(output_ne_gm, output_ne_ub, 32);
+    copy_to_ub(output_nb_gm, output_nb_ub, 32);
+
+    DupByRows<float_t, float_t> op;
+    op.init(src_gm, dst_gm, input_ne_ub, input_nb_ub);
+    op.dup();
+}
+
+extern "C" __global__ __aicore__ void ascendc_dup_by_rows_fp32_to_fp16(
+                                                        GM_ADDR src_gm,
+                                                        GM_ADDR dst_gm,
+                                                        GM_ADDR input_ne_gm,
+                                                        GM_ADDR input_nb_gm,
+                                                        GM_ADDR output_ne_gm,
+                                                        GM_ADDR output_nb_gm) {
+
+    int64_t input_ne_ub[4];
+    size_t input_nb_ub[4];
+    int64_t output_ne_ub[4];
+    size_t output_nb_ub[4];
+
+    copy_to_ub(input_ne_gm, input_ne_ub, 32);
+    copy_to_ub(input_nb_gm, input_nb_ub, 32);
+    copy_to_ub(output_ne_gm, output_ne_ub, 32);
+    copy_to_ub(output_nb_gm, output_nb_ub, 32);
+
+    DupByRows<float_t, half> op;
+    op.init(src_gm, dst_gm, input_ne_ub, input_nb_ub);
+    op.dup_with_cast();
+}
+
+extern "C" __global__ __aicore__ void ascendc_dup_by_rows_fp16_to_fp32(
+                                                        GM_ADDR src_gm,
+                                                        GM_ADDR dst_gm,
+                                                        GM_ADDR input_ne_gm,
+                                                        GM_ADDR input_nb_gm,
+                                                        GM_ADDR output_ne_gm,
+                                                        GM_ADDR output_nb_gm) {
+
+    // copy params from gm to ub.
+    int64_t input_ne_ub[4];
+    size_t input_nb_ub[4];
+    int64_t output_ne_ub[4];
+    size_t output_nb_ub[4];
+
+    copy_to_ub(input_ne_gm, input_ne_ub, 32);
+    copy_to_ub(input_nb_gm, input_nb_ub, 32);
+    copy_to_ub(output_ne_gm, output_ne_ub, 32);
+    copy_to_ub(output_nb_gm, output_nb_ub, 32);
+
+    DupByRows<half, float_t> op;
+    op.init(src_gm, dst_gm, input_ne_ub, input_nb_ub);
+    op.dup_with_cast();
+}
diff --git a/src/ggml-cann/kernels/get_row_f16.cpp b/src/ggml-cann/kernels/get_row_f16.cpp
new file mode 100644 (file)
index 0000000..c704b5b
--- /dev/null
@@ -0,0 +1,186 @@
+#include "kernel_operator.h"
+
+// optimize me. Use template to avoid copy code.
+using namespace AscendC;
+
+#define BUFFER_NUM 2
+
+class GET_ROW_F16 {
+   public:
+    __aicore__ inline GET_ROW_F16() {}
+    __aicore__ inline void init(GM_ADDR input, GM_ADDR indices, GM_ADDR output,
+                                int64_t *input_ne_ub, size_t *input_nb_ub,
+                                int64_t *indices_ne_ub, size_t *indices_nb_ub,
+                                int64_t *output_ne_ub, size_t *output_nb_ub) {
+        // TODO, use template for F16/f32
+        int64_t op_block_num = GetBlockNum();
+        int64_t op_block_idx = GetBlockIdx();
+
+        for (int i = 0; i < 4; i++) {
+            input_ne[i] = input_ne_ub[i];
+            input_stride[i] = input_nb_ub[i] / input_nb_ub[0];
+
+            indices_ne[i] = indices_ne_ub[i];
+            indices_stride[i] = indices_nb_ub[i] / indices_nb_ub[0];
+
+            output_ne[i] = output_ne_ub[i];
+            output_stride[i] = output_nb_ub[i] / output_nb_ub[0];
+        }
+
+        // Indices has two dims. n_elements = all rows should get.
+        // dr, all rows should this thread get.
+        uint64_t n_elements =
+            indices_ne[0] * indices_ne[1] * indices_ne[2] * indices_ne[3];
+        dr = n_elements / op_block_num;
+
+        uint64_t tails = n_elements % op_block_num;
+        if (op_block_idx < tails) {
+            dr += 1;
+            ir = dr * op_block_idx;
+        } else {
+            ir = dr * op_block_idx + tails;
+        }
+
+        input_gm.SetGlobalBuffer((__gm__ half *)input);
+        indices_gm.SetGlobalBuffer((__gm__ int32_t *)indices);
+        output_gm.SetGlobalBuffer((__gm__ float *)output);
+
+        uint64_t input_local_buffer_size = ((input_ne[0] * sizeof(half) + 31)
+                                             & ~31);
+        uint64_t output_local_buffer_size = ((input_ne[0] * sizeof(float) + 31)
+                                              & ~31);
+
+        local_buffer_elems = input_local_buffer_size / sizeof(half);
+
+        // TODO, consider long row that can't put in UB.
+        // All data should asign to 32. It's ok because all data is align to 32.
+        pipe.InitBuffer(input_queue, BUFFER_NUM, input_local_buffer_size);
+        pipe.InitBuffer(output_queue, BUFFER_NUM, output_local_buffer_size);
+    }
+
+    __aicore__ inline void copy_in(uint32_t offset, size_t len) {
+        LocalTensor<half> input_local = input_queue.AllocTensor<half>();
+        size_t tail = len % 32;
+        len = len & ~31;
+        DataCopy(input_local, input_gm[offset], len);
+        if(tail != 0) {
+            DataCopyExtParams dataCopyParams;
+            dataCopyParams.blockCount = 1;
+            dataCopyParams.blockLen = tail * sizeof(half);
+            DataCopyPadExtParams<half> padParams;
+            DataCopyPad(input_local[len], input_gm[offset + len],
+                        dataCopyParams, padParams);
+        }
+        input_queue.EnQue(input_local);
+    }
+
+    __aicore__ inline void copy_out(uint32_t offset, size_t len) {
+        LocalTensor<float> output_local = output_queue.DeQue<float>();
+        size_t tail = len % 32;
+        len = len & ~31;
+        DataCopy(output_gm[offset], output_local, len);
+        if(tail != 0) {
+            DataCopyExtParams dataCopyParams;
+            dataCopyParams.blockCount = 1;
+            dataCopyParams.blockLen = tail * sizeof(float);
+            DataCopyPad(output_gm[offset + len], output_local[len],
+                        dataCopyParams);
+        }
+        output_queue.FreeTensor(output_local);
+    }
+
+    __aicore__ inline void calculate_row(int64_t idx) {
+        const int64_t indices_ne2_idx = idx / (indices_ne[0] * indices_ne[1]);
+        const int64_t indices_ne1_idx =
+            (idx - indices_ne2_idx * indices_ne[0] * indices_ne[1]) /
+            indices_ne[0];
+        const int64_t indices_ne0_idx =
+            (idx - indices_ne2_idx * indices_ne[0] * indices_ne[1] -
+             indices_ne1_idx * indices_ne[0]);
+
+        const int64_t indices_offset = indices_ne0_idx * indices_stride[0] +
+                                       indices_ne1_idx * indices_stride[1] +
+                                       indices_ne2_idx * indices_stride[2];
+        const int32_t selected_row_idx = indices_gm.GetValue(indices_offset);
+
+        const int64_t input_offset = selected_row_idx * input_stride[1] +
+                                     indices_ne1_idx * input_stride[2] +
+                                     indices_ne2_idx * input_stride[3];
+
+        const int64_t output_offset = indices_ne0_idx * output_stride[1] +
+                                      indices_ne1_idx * output_stride[2] +
+                                      indices_ne2_idx * output_stride[3];
+
+        copy_in(input_offset, input_ne[0]);
+        LocalTensor<half> input_local = input_queue.DeQue<half>();
+        LocalTensor<float> output_local = output_queue.AllocTensor<float>();
+
+        Cast(output_local, input_local, RoundMode::CAST_NONE,
+             local_buffer_elems);
+        output_queue.EnQue(output_local);
+        copy_out(output_offset, input_ne[0]);
+
+        input_queue.FreeTensor(input_local);
+    }
+
+    __aicore__ inline void calculate() {
+        for (int64_t i = ir; i < ir + dr; i++) {
+            calculate_row(i);
+        }
+    }
+
+   private:
+    int64_t input_ne[4];
+    size_t input_stride[4];
+
+    int64_t indices_ne[4];
+    size_t indices_stride[4];
+
+    int64_t output_ne[4];
+    size_t output_stride[4];
+
+    size_t local_buffer_elems;
+
+    int64_t ir;
+    int64_t dr;
+
+    TPipe pipe;
+    GlobalTensor<half> input_gm;
+    GlobalTensor<int32_t> indices_gm;
+    GlobalTensor<float> output_gm;
+    TQue<QuePosition::VECIN, BUFFER_NUM> input_queue;
+    TQue<QuePosition::VECOUT, BUFFER_NUM> output_queue;
+};
+
+template <typename T>
+__aicore__ inline void copy_to_ub(GM_ADDR gm, T *ub, size_t size) {
+    auto gm_ptr = (__gm__ uint8_t *)gm;
+    auto ub_ptr = (uint8_t *)(ub);
+    for (int32_t i = 0; i < size; ++i, ++ub_ptr, ++gm_ptr) {
+        *ub_ptr = *gm_ptr;
+    }
+}
+
+extern "C" __global__ __aicore__ void ascendc_get_row_f16(
+    GM_ADDR input_gm, GM_ADDR indices_gm, GM_ADDR output_gm,
+    GM_ADDR input_ne_gm, GM_ADDR input_nb_gm, GM_ADDR indices_ne_gm,
+    GM_ADDR indices_nb_gm, GM_ADDR output_ne_gm, GM_ADDR output_nb_gm) {
+    int64_t input_ne_ub[4];
+    size_t input_nb_ub[4];
+    int64_t indices_ne_ub[4];
+    size_t indices_nb_ub[4];
+    int64_t output_ne_ub[4];
+    size_t output_nb_ub[4];
+
+    copy_to_ub(input_ne_gm, input_ne_ub, 32);
+    copy_to_ub(input_nb_gm, input_nb_ub, 32);
+    copy_to_ub(indices_ne_gm, indices_ne_ub, 32);
+    copy_to_ub(indices_nb_gm, indices_nb_ub, 32);
+    copy_to_ub(output_ne_gm, output_ne_ub, 32);
+    copy_to_ub(output_nb_gm, output_nb_ub, 32);
+
+    GET_ROW_F16 op;
+    op.init(input_gm, indices_gm, output_gm, input_ne_ub, input_nb_ub,
+            indices_ne_ub, indices_nb_ub, output_ne_ub, output_nb_ub);
+    op.calculate();
+}
diff --git a/src/ggml-cann/kernels/get_row_f32.cpp b/src/ggml-cann/kernels/get_row_f32.cpp
new file mode 100644 (file)
index 0000000..9db080a
--- /dev/null
@@ -0,0 +1,180 @@
+#include "kernel_operator.h"
+
+// optimize me. Use template to avoid copy code.
+using namespace AscendC;
+
+#define BUFFER_NUM 2
+
+class GET_ROW_F32 {
+   public:
+    __aicore__ inline GET_ROW_F32() {}
+    __aicore__ inline void init(GM_ADDR input, GM_ADDR indices, GM_ADDR output,
+                                int64_t *input_ne_ub, size_t *input_nb_ub,
+                                int64_t *indices_ne_ub, size_t *indices_nb_ub,
+                                int64_t *output_ne_ub, size_t *output_nb_ub) {
+        int64_t op_block_num = GetBlockNum();
+        int64_t op_block_idx = GetBlockIdx();
+
+        for (int i = 0; i < 4; i++) {
+            input_ne[i] = input_ne_ub[i];
+            input_stride[i] = input_nb_ub[i] / input_nb_ub[0];
+
+            indices_ne[i] = indices_ne_ub[i];
+            indices_stride[i] = indices_nb_ub[i] / indices_nb_ub[0];
+
+            output_ne[i] = output_ne_ub[i];
+            output_stride[i] = output_nb_ub[i] / output_nb_ub[0];
+        }
+
+        // Indices has two dims. n_elements = all rows should get.
+        // dr, all rows should this thread get.
+        uint64_t n_elements =
+            indices_ne[0] * indices_ne[1] * indices_ne[2] * indices_ne[3];
+        dr = n_elements / op_block_num;
+
+        uint64_t tails = n_elements % op_block_num;
+        if (op_block_idx < tails) {
+            dr += 1;
+            ir = dr * op_block_idx;
+        } else {
+            ir = dr * op_block_idx + tails;
+        }
+
+        input_gm.SetGlobalBuffer((__gm__ float *)input);
+        indices_gm.SetGlobalBuffer((__gm__ int32_t *)indices);
+        output_gm.SetGlobalBuffer((__gm__ float *)output);
+
+        uint64_t local_buffer_size = ((input_ne[0] * sizeof(float) + 31) & ~31);
+        local_buffer_elems = local_buffer_size / sizeof(float);
+
+        // TODO, consider long row that can't put in UB.
+        // All data should asign to 32. It's ok because all data is align to 32.
+        pipe.InitBuffer(input_queue, BUFFER_NUM, local_buffer_size);
+        pipe.InitBuffer(output_queue, BUFFER_NUM, local_buffer_size);
+    }
+
+    __aicore__ inline void copy_in(uint32_t offset, size_t len) {
+        LocalTensor<float> input_local = input_queue.AllocTensor<float>();
+        size_t tail = len % 32;
+        len = len & ~31;
+        DataCopy(input_local, input_gm[offset], len);
+        if(tail != 0) {
+            DataCopyExtParams dataCopyParams;
+            dataCopyParams.blockCount = 1;
+            dataCopyParams.blockLen = tail * sizeof(float);
+            DataCopyPadExtParams<float> padParams;
+            DataCopyPad(input_local[len], input_gm[offset + len],
+                        dataCopyParams, padParams);
+        }
+        input_queue.EnQue(input_local);
+    }
+
+    __aicore__ inline void copy_out(uint32_t offset, size_t len) {
+        LocalTensor<float> output_local = output_queue.DeQue<float>();
+        size_t tail = len % 32;
+        len = len & ~31;
+        DataCopy(output_gm[offset], output_local, len);
+        if(tail != 0) {
+            DataCopyExtParams dataCopyParams;
+            dataCopyParams.blockCount = 1;
+            dataCopyParams.blockLen = tail * sizeof(float);
+            DataCopyPad(output_gm[offset + len], output_local[len],
+                        dataCopyParams);
+        }
+        output_queue.FreeTensor(output_local);
+    }
+
+    __aicore__ inline void calculate_row(int64_t idx) {
+        const int64_t indices_ne2_idx = idx / (indices_ne[0] * indices_ne[1]);
+        const int64_t indices_ne1_idx =
+            (idx - indices_ne2_idx * indices_ne[0] * indices_ne[1]) /
+            indices_ne[0];
+        const int64_t indices_ne0_idx =
+            (idx - indices_ne2_idx * indices_ne[0] * indices_ne[1] -
+             indices_ne1_idx * indices_ne[0]);
+
+        const int64_t indices_offset = indices_ne0_idx * indices_stride[0] +
+                                       indices_ne1_idx * indices_stride[1] +
+                                       indices_ne2_idx * indices_stride[2];
+        const int32_t selected_row_idx = indices_gm.GetValue(indices_offset);
+
+        const int64_t input_offset = selected_row_idx * input_stride[1] +
+                                     indices_ne1_idx * input_stride[2] +
+                                     indices_ne2_idx * input_stride[3];
+
+        const int64_t output_offset = indices_ne0_idx * output_stride[1] +
+                                      indices_ne1_idx * output_stride[2] +
+                                      indices_ne2_idx * output_stride[3];
+
+        copy_in(input_offset, input_ne[0]);
+        LocalTensor<float> input_local = input_queue.DeQue<float>();
+        LocalTensor<float> output_local = output_queue.AllocTensor<float>();
+
+        DataCopy(output_local, input_local, local_buffer_elems);
+        output_queue.EnQue(output_local);
+        copy_out(output_offset, input_ne[0]);
+
+        input_queue.FreeTensor(input_local);
+    }
+
+    __aicore__ inline void calculate() {
+        for (int64_t i = ir; i < ir + dr; i++) {
+            calculate_row(i);
+        }
+    }
+
+   private:
+    int64_t input_ne[4];
+    size_t input_stride[4];
+
+    int64_t indices_ne[4];
+    size_t indices_stride[4];
+
+    int64_t output_ne[4];
+    size_t output_stride[4];
+
+    size_t local_buffer_elems;
+
+    int64_t ir;
+    int64_t dr;
+
+    TPipe pipe;
+    GlobalTensor<float> input_gm;
+    GlobalTensor<int32_t> indices_gm;
+    GlobalTensor<float> output_gm;
+    TQue<QuePosition::VECIN, BUFFER_NUM> input_queue;
+    TQue<QuePosition::VECOUT, BUFFER_NUM> output_queue;
+};
+
+template <typename T>
+__aicore__ inline void copy_to_ub(GM_ADDR gm, T *ub, size_t size) {
+    auto gm_ptr = (__gm__ uint8_t *)gm;
+    auto ub_ptr = (uint8_t *)(ub);
+    for (int32_t i = 0; i < size; ++i, ++ub_ptr, ++gm_ptr) {
+        *ub_ptr = *gm_ptr;
+    }
+}
+
+extern "C" __global__ __aicore__ void ascendc_get_row_f32(
+    GM_ADDR input_gm, GM_ADDR indices_gm, GM_ADDR output_gm,
+    GM_ADDR input_ne_gm, GM_ADDR input_nb_gm, GM_ADDR indices_ne_gm,
+    GM_ADDR indices_nb_gm, GM_ADDR output_ne_gm, GM_ADDR output_nb_gm) {
+    int64_t input_ne_ub[4];
+    size_t input_nb_ub[4];
+    int64_t indices_ne_ub[4];
+    size_t indices_nb_ub[4];
+    int64_t output_ne_ub[4];
+    size_t output_nb_ub[4];
+
+    copy_to_ub(input_ne_gm, input_ne_ub, 32);
+    copy_to_ub(input_nb_gm, input_nb_ub, 32);
+    copy_to_ub(indices_ne_gm, indices_ne_ub, 32);
+    copy_to_ub(indices_nb_gm, indices_nb_ub, 32);
+    copy_to_ub(output_ne_gm, output_ne_ub, 32);
+    copy_to_ub(output_nb_gm, output_nb_ub, 32);
+
+    GET_ROW_F32 op;
+    op.init(input_gm, indices_gm, output_gm, input_ne_ub, input_nb_ub,
+            indices_ne_ub, indices_nb_ub, output_ne_ub, output_nb_ub);
+    op.calculate();
+}
diff --git a/src/ggml-cann/kernels/get_row_q4_0.cpp b/src/ggml-cann/kernels/get_row_q4_0.cpp
new file mode 100644 (file)
index 0000000..a80bfee
--- /dev/null
@@ -0,0 +1,193 @@
+#include "kernel_operator.h"
+
+// optimize me. Use template to avoid copy code.
+using namespace AscendC;
+
+#define BUFFER_NUM 2
+
+#define QK4_0 32
+
+class GET_ROW_Q4_0 {
+   public:
+    __aicore__ inline GET_ROW_Q4_0() {}
+    __aicore__ inline void init(GM_ADDR input, GM_ADDR indices, GM_ADDR output,
+                                int64_t *input_ne_ub, int64_t *indices_ne_ub,
+                                size_t *indices_nb_ub, int64_t *output_ne_ub,
+                                size_t *output_nb_ub) {
+        int64_t op_block_num = GetBlockNum();
+        int64_t op_block_idx = GetBlockIdx();
+
+        for (int i = 0; i < 4; i++) {
+            input_ne[i] = input_ne_ub[i];
+            indices_ne[i] = indices_ne_ub[i];
+            indices_stride[i] = indices_nb_ub[i] / indices_nb_ub[0];
+            scale_ne[i] = input_ne_ub[i];
+            output_ne[i] = output_ne_ub[i];
+            output_stride[i] = output_nb_ub[i] / output_nb_ub[0];
+        }
+
+        // one scale for a group.
+        scale_ne[0] /= QK4_0;
+
+        input_stride[0] = 1;
+        scale_stride[0] = 1;
+        output_stride[0] = 1;
+        for (int i = 1; i < 4; i++) {
+            input_stride[i] = input_stride[i - 1] * input_ne[i - 1];
+            scale_stride[i] = scale_stride[i - 1] * scale_ne[i - 1];
+        }
+
+        group_size_in_row = input_ne[0] / QK4_0;
+        int64_t scale_offset = input_ne[0] * input_ne[1] * input_ne[2] *
+                               input_ne[3] / 2;
+
+        // Indices has two dims. n_elements = all rows should get.
+        // dr, all rows should this thread get.
+        uint64_t n_elements =
+            indices_ne[0] * indices_ne[1] * indices_ne[2] * indices_ne[3];
+        dr = n_elements / op_block_num;
+
+        uint64_t tails = n_elements % op_block_num;
+        if (op_block_idx < tails) {
+            dr += 1;
+            ir = dr * op_block_idx;
+        } else {
+            ir = dr * op_block_idx + tails;
+        }
+
+        input_gm.SetGlobalBuffer((__gm__ int4b_t *)input);
+        scale_gm.SetGlobalBuffer((__gm__ half *)(input + scale_offset));
+        indices_gm.SetGlobalBuffer((__gm__ int32_t *)indices);
+        output_gm.SetGlobalBuffer((__gm__ float *)output);
+
+        pipe.InitBuffer(input_queue, BUFFER_NUM, QK4_0 * sizeof(int4b_t));
+        pipe.InitBuffer(cast_queue, BUFFER_NUM, QK4_0 * sizeof(half));
+        pipe.InitBuffer(output_queue, BUFFER_NUM, QK4_0 * sizeof(float));
+    }
+
+    __aicore__ inline void copy_in(uint32_t offset) {
+        LocalTensor<int4b_t> input_local = input_queue.AllocTensor<int4b_t>();
+        // 32 * sizeof(int4b_t) = 16, which is not aligned to 32, why no error?
+        DataCopy(input_local, input_gm[offset], QK4_0);
+        input_queue.EnQue(input_local);
+    }
+
+    __aicore__ inline void copy_out(uint32_t offset) {
+        LocalTensor<float> output_local = output_queue.DeQue<float>();
+        DataCopy(output_gm[offset], output_local, QK4_0);
+        output_queue.FreeTensor(output_local);
+    }
+
+    __aicore__ inline void calculate_group(int64_t idx, int64_t group) {
+        const int64_t indices_ne2_idx = idx / (indices_ne[0] * indices_ne[1]);
+        const int64_t indices_ne1_idx =
+            (idx - indices_ne2_idx * indices_ne[0] * indices_ne[1]) /
+            indices_ne[0];
+        const int64_t indices_ne0_idx =
+            (idx - indices_ne2_idx * indices_ne[0] * indices_ne[1] -
+             indices_ne1_idx * indices_ne[0]);
+
+        const int64_t indices_offset = indices_ne0_idx * indices_stride[0] +
+                                       indices_ne1_idx * indices_stride[1] +
+                                       indices_ne2_idx * indices_stride[2];
+        const int32_t selected_row_idx = indices_gm.GetValue(indices_offset);
+
+        const int64_t input_offset = selected_row_idx * input_stride[1] +
+                                     indices_ne1_idx * input_stride[2] +
+                                     indices_ne2_idx * input_stride[3] +
+                                     group * QK4_0;
+        const int64_t scale_offset = selected_row_idx * scale_stride[1] +
+                                     indices_ne1_idx * scale_stride[2] +
+                                     indices_ne2_idx * scale_stride[3] + group;
+        const int64_t output_offset = indices_ne0_idx * output_stride[1] +
+                                      indices_ne1_idx * output_stride[2] +
+                                      indices_ne2_idx * output_stride[3] +
+                                      group * QK4_0;
+
+        copy_in(input_offset);
+        LocalTensor<int4b_t> input_local = input_queue.DeQue<int4b_t>();
+        LocalTensor<half> cast_local = cast_queue.AllocTensor<half>();
+        LocalTensor<float> output_local = output_queue.AllocTensor<float>();
+
+        // TODO: cast more data to speed up.
+        Cast(cast_local, input_local, RoundMode::CAST_NONE, QK4_0);
+        Cast(output_local, cast_local, RoundMode::CAST_NONE, QK4_0);
+
+        // Only mul need compile by group.
+        half scale = scale_gm.GetValue(scale_offset);
+
+        Muls(output_local, output_local, (float)scale, QK4_0);
+
+        input_queue.FreeTensor(input_local);
+        cast_queue.FreeTensor(cast_local);
+        output_queue.EnQue(output_local);
+
+        copy_out(output_offset);
+    }
+
+    __aicore__ inline void calculate() {
+        for (int64_t i = ir; i < ir + dr; i++) {
+            for (int64_t j = 0; j < group_size_in_row; j++) {
+                calculate_group(i, j);
+            }
+        }
+    }
+
+   private:
+    int64_t input_ne[4];
+    size_t input_stride[4];
+
+    int64_t scale_ne[4];
+    size_t scale_stride[4];
+
+    int64_t indices_ne[4];
+    size_t indices_stride[4];
+
+    int64_t output_ne[4];
+    size_t output_stride[4];
+
+    int64_t ir;
+    int64_t dr;
+
+    int64_t group_size_in_row;
+
+    TPipe pipe;
+    GlobalTensor<int4b_t> input_gm;
+    GlobalTensor<half> scale_gm;
+    GlobalTensor<int32_t> indices_gm;
+    GlobalTensor<float> output_gm;
+    TQue<QuePosition::VECIN, BUFFER_NUM> input_queue;
+    TQue<QuePosition::VECOUT, BUFFER_NUM> output_queue;
+    TQue<QuePosition::VECIN, BUFFER_NUM> cast_queue;
+};
+
+template <typename T>
+__aicore__ inline void copy_to_ub(GM_ADDR gm, T *ub, size_t size) {
+    auto gm_ptr = (__gm__ uint8_t *)gm;
+    auto ub_ptr = (uint8_t *)(ub);
+    for (int32_t i = 0; i < size; ++i, ++ub_ptr, ++gm_ptr) {
+        *ub_ptr = *gm_ptr;
+    }
+}
+
+extern "C" __global__ __aicore__ void ascendc_get_row_q4_0(
+    GM_ADDR input_gm, GM_ADDR indices_gm, GM_ADDR output_gm,
+    GM_ADDR input_ne_gm, GM_ADDR indices_ne_gm, GM_ADDR indices_nb_gm,
+    GM_ADDR output_ne_gm, GM_ADDR output_nb_gm) {
+    int64_t input_ne_ub[4];
+    int64_t indices_ne_ub[4];
+    size_t indices_nb_ub[4];
+    int64_t output_ne_ub[4];
+    size_t output_nb_ub[4];
+
+    copy_to_ub(input_ne_gm, input_ne_ub, 32);
+    copy_to_ub(indices_ne_gm, indices_ne_ub, 32);
+    copy_to_ub(indices_nb_gm, indices_nb_ub, 32);
+    copy_to_ub(output_ne_gm, output_ne_ub, 32);
+    copy_to_ub(output_nb_gm, output_nb_ub, 32);
+
+    GET_ROW_Q4_0 op;
+    op.init(input_gm, indices_gm, output_gm, input_ne_ub, indices_ne_ub,
+            indices_nb_ub, output_ne_ub, output_nb_ub);
+    op.calculate();
+}
diff --git a/src/ggml-cann/kernels/get_row_q8_0.cpp b/src/ggml-cann/kernels/get_row_q8_0.cpp
new file mode 100644 (file)
index 0000000..ba9ab3c
--- /dev/null
@@ -0,0 +1,191 @@
+#include "kernel_operator.h"
+
+// optimize me. Use template to avoid copy code.
+using namespace AscendC;
+
+#define BUFFER_NUM 2
+
+#define QK8_0 32
+
+class GET_ROW_Q8_0 {
+   public:
+    __aicore__ inline GET_ROW_Q8_0() {}
+    __aicore__ inline void init(GM_ADDR input, GM_ADDR indices, GM_ADDR output,
+                                int64_t *input_ne_ub, int64_t *indices_ne_ub,
+                                size_t *indices_nb_ub, int64_t *output_ne_ub,
+                                size_t *output_nb_ub) {
+        int64_t op_block_num = GetBlockNum();
+        int64_t op_block_idx = GetBlockIdx();
+
+        for (int i = 0; i < 4; i++) {
+            input_ne[i] = input_ne_ub[i];
+            indices_ne[i] = indices_ne_ub[i];
+            indices_stride[i] = indices_nb_ub[i] / indices_nb_ub[0];
+            scale_ne[i] = input_ne_ub[i];
+            output_ne[i] = output_ne_ub[i];
+            output_stride[i] = output_nb_ub[i] / output_nb_ub[0];
+        }
+
+        // one scale for a group.
+        scale_ne[0] /= QK8_0;
+
+        input_stride[0] = 1;
+        scale_stride[0] = 1;
+        output_stride[0] = 1;
+        for (int i = 1; i < 4; i++) {
+            input_stride[i] = input_stride[i - 1] * input_ne[i - 1];
+            scale_stride[i] = scale_stride[i - 1] * scale_ne[i - 1];
+        }
+
+        group_size_in_row = input_ne[0] / QK8_0;
+        int64_t scale_offset = input_ne[0] * input_ne[1] * input_ne[2] *
+                               input_ne[3] * sizeof(int8_t);
+
+        // Indices has two dims. n_elements = all rows should get.
+        // dr, all rows should this thread get.
+        uint64_t n_elements =
+            indices_ne[0] * indices_ne[1] * indices_ne[2] * indices_ne[3];
+        dr = n_elements / op_block_num;
+
+        uint64_t tails = n_elements % op_block_num;
+        if (op_block_idx < tails) {
+            dr += 1;
+            ir = dr * op_block_idx;
+        } else {
+            ir = dr * op_block_idx + tails;
+        }
+
+        input_gm.SetGlobalBuffer((__gm__ int8_t *)input);
+        scale_gm.SetGlobalBuffer((__gm__ half *)(input + scale_offset));
+        indices_gm.SetGlobalBuffer((__gm__ int32_t *)indices);
+        output_gm.SetGlobalBuffer((__gm__ float *)output);
+
+        pipe.InitBuffer(input_queue, BUFFER_NUM, QK8_0 * sizeof(int8_t));
+        pipe.InitBuffer(cast_queue, BUFFER_NUM, QK8_0 * sizeof(half));
+        pipe.InitBuffer(output_queue, BUFFER_NUM, QK8_0 * sizeof(float));
+    }
+
+    __aicore__ inline void copy_in(uint32_t offset) {
+        LocalTensor<int8_t> input_local = input_queue.AllocTensor<int8_t>();
+        DataCopy(input_local, input_gm[offset], QK8_0);
+        input_queue.EnQue(input_local);
+    }
+
+    __aicore__ inline void copy_out(uint32_t offset) {
+        LocalTensor<float> output_local = output_queue.DeQue<float>();
+        DataCopy(output_gm[offset], output_local, QK8_0);
+        output_queue.FreeTensor(output_local);
+    }
+
+    __aicore__ inline void calculate_group(int64_t idx, int64_t group) {
+        const int64_t indices_ne2_idx = idx / (indices_ne[0] * indices_ne[1]);
+        const int64_t indices_ne1_idx =
+            (idx - indices_ne2_idx * indices_ne[0] * indices_ne[1]) /
+            indices_ne[0];
+        const int64_t indices_ne0_idx =
+            (idx - indices_ne2_idx * indices_ne[0] * indices_ne[1] -
+             indices_ne1_idx * indices_ne[0]);
+
+        const int64_t indices_offset = indices_ne0_idx * indices_stride[0] +
+                                       indices_ne1_idx * indices_stride[1] +
+                                       indices_ne2_idx * indices_stride[2];
+        const int32_t selected_row_idx = indices_gm.GetValue(indices_offset);
+
+        const int64_t input_offset = selected_row_idx * input_stride[1] +
+                                     indices_ne1_idx * input_stride[2] +
+                                     indices_ne2_idx * input_stride[3] +
+                                     group * QK8_0;
+        const int64_t scale_offset = selected_row_idx * scale_stride[1] +
+                                     indices_ne1_idx * scale_stride[2] +
+                                     indices_ne2_idx * scale_stride[3] + group;
+        const int64_t output_offset = indices_ne0_idx * output_stride[1] +
+                                      indices_ne1_idx * output_stride[2] +
+                                      indices_ne2_idx * output_stride[3] +
+                                      group * QK8_0;
+
+        copy_in(input_offset);
+        LocalTensor<int8_t> input_local = input_queue.DeQue<int8_t>();
+        LocalTensor<half> cast_local = cast_queue.AllocTensor<half>();
+        LocalTensor<float> output_local = output_queue.AllocTensor<float>();
+
+        // TODO: cast more data to speed up.
+        Cast(cast_local, input_local, RoundMode::CAST_NONE, QK8_0);
+        Cast(output_local, cast_local, RoundMode::CAST_NONE, QK8_0);
+
+        // Only mul need compile by group.
+        half scale = scale_gm.GetValue(scale_offset);
+        Muls(output_local, output_local, (float)scale, QK8_0);
+
+        input_queue.FreeTensor(input_local);
+        cast_queue.FreeTensor(cast_local);
+        output_queue.EnQue(output_local);
+
+        copy_out(output_offset);
+    }
+
+    __aicore__ inline void calculate() {
+        for (int64_t i = ir; i < ir + dr; i++) {
+            for (int64_t j = 0; j < group_size_in_row; j++) {
+                calculate_group(i, j);
+            }
+        }
+    }
+
+   private:
+    int64_t input_ne[4];
+    size_t input_stride[4];
+
+    int64_t scale_ne[4];
+    size_t scale_stride[4];
+
+    int64_t indices_ne[4];
+    size_t indices_stride[4];
+
+    int64_t output_ne[4];
+    size_t output_stride[4];
+
+    int64_t ir;
+    int64_t dr;
+
+    int64_t group_size_in_row;
+
+    TPipe pipe;
+    GlobalTensor<int8_t> input_gm;
+    GlobalTensor<half> scale_gm;
+    GlobalTensor<int32_t> indices_gm;
+    GlobalTensor<float> output_gm;
+    TQue<QuePosition::VECIN, BUFFER_NUM> input_queue;
+    TQue<QuePosition::VECOUT, BUFFER_NUM> output_queue;
+    TQue<QuePosition::VECIN, BUFFER_NUM> cast_queue;
+};
+
+template <typename T>
+__aicore__ inline void copy_to_ub(GM_ADDR gm, T *ub, size_t size) {
+    auto gm_ptr = (__gm__ uint8_t *)gm;
+    auto ub_ptr = (uint8_t *)(ub);
+    for (int32_t i = 0; i < size; ++i, ++ub_ptr, ++gm_ptr) {
+        *ub_ptr = *gm_ptr;
+    }
+}
+
+extern "C" __global__ __aicore__ void ascendc_get_row_q8_0(
+    GM_ADDR input_gm, GM_ADDR indices_gm, GM_ADDR output_gm,
+    GM_ADDR input_ne_gm, GM_ADDR indices_ne_gm, GM_ADDR indices_nb_gm,
+    GM_ADDR output_ne_gm, GM_ADDR output_nb_gm) {
+    int64_t input_ne_ub[4];
+    int64_t indices_ne_ub[4];
+    size_t indices_nb_ub[4];
+    int64_t output_ne_ub[4];
+    size_t output_nb_ub[4];
+
+    copy_to_ub(input_ne_gm, input_ne_ub, 32);
+    copy_to_ub(indices_ne_gm, indices_ne_ub, 32);
+    copy_to_ub(indices_nb_gm, indices_nb_ub, 32);
+    copy_to_ub(output_ne_gm, output_ne_ub, 32);
+    copy_to_ub(output_nb_gm, output_nb_ub, 32);
+
+    GET_ROW_Q8_0 op;
+    op.init(input_gm, indices_gm, output_gm, input_ne_ub, indices_ne_ub,
+            indices_nb_ub, output_ne_ub, output_nb_ub);
+    op.calculate();
+}
diff --git a/src/ggml-cann/kernels/quantize_f16_q8_0.cpp b/src/ggml-cann/kernels/quantize_f16_q8_0.cpp
new file mode 100644 (file)
index 0000000..8423b3f
--- /dev/null
@@ -0,0 +1,208 @@
+#include "kernel_operator.h"
+
+using namespace AscendC;
+
+#define BUFFER_NUM 2
+#define QK8_0 32
+
+class QUANTIZE_F16_Q8_0 {
+   public:
+    __aicore__ inline QUANTIZE_F16_Q8_0() {}
+    __aicore__ inline void init(GM_ADDR input, GM_ADDR output,
+                                int64_t *input_ne_ub, size_t *input_nb_ub,
+                                int64_t *output_ne_ub) {
+        int64_t op_block_num = GetBlockNum();
+        int64_t op_block_idx = GetBlockIdx();
+
+        for (int i = 0; i < 4; i++) {
+            input_ne[i] = input_ne_ub[i];
+            input_stride[i] = input_nb_ub[i] / input_nb_ub[0];
+
+            output_ne[i] = output_ne_ub[i];
+        }
+
+        output_stride[0] = 1;
+        for (int i = 1; i < 4; i++) {
+            output_stride[i] = output_stride[i - 1] * output_ne[i - 1];
+        }
+
+        scale_ne = input_ne;
+        scale_stride[0] = 1;
+        scale_stride[1] = input_ne[0] / QK8_0;
+        for (int i = 2; i < 4; i++) {
+            scale_stride[i] = scale_stride[i - 1] * scale_ne[i - 1];
+        }
+
+        // split input tensor by rows.
+        uint64_t nr = input_ne[1] * input_ne[2] * input_ne[3];
+        dr = nr / op_block_num;
+
+        uint64_t tails = nr % op_block_num;
+        if (op_block_idx < tails) {
+            dr += 1;
+            ir = dr * op_block_idx;
+        } else {
+            ir = dr * op_block_idx + tails;
+        }
+
+        group_size_in_row = scale_stride[1];
+        int64_t output_size = output_ne[0] * output_ne[1] * output_ne[2] *
+                              output_ne[3] * sizeof(uint8_t);
+
+        input_gm.SetGlobalBuffer((__gm__ half *)input);
+        output_gm.SetGlobalBuffer((__gm__ int8_t *)output);
+        scale_gm.SetGlobalBuffer((__gm__ half *)(output + output_size + ir *
+                                                 group_size_in_row *
+                                                 sizeof(half)));
+
+        pipe.InitBuffer(input_queue, BUFFER_NUM, QK8_0 * sizeof(half));
+        pipe.InitBuffer(output_queue, BUFFER_NUM, QK8_0 * sizeof(int8_t));
+        pipe.InitBuffer(work_queue, 1, 32);
+        pipe.InitBuffer(max_queue, 1, 32);
+        pipe.InitBuffer(abs_queue, 1, QK8_0 * sizeof(float));
+        pipe.InitBuffer(scale_queue, 1, 32);
+        pipe.InitBuffer(cast_queue ,1 ,QK8_0 * sizeof(float));
+    }
+
+    __aicore__ inline void copy_in(uint32_t offset) {
+        LocalTensor<half> input_local = input_queue.AllocTensor<half>();
+        DataCopy(input_local, input_gm[offset], QK8_0);
+        input_queue.EnQue(input_local);
+    }
+
+    __aicore__ inline void copy_out(uint32_t offset) {
+        LocalTensor<int8_t> output_local = output_queue.DeQue<int8_t>();
+        DataCopy(output_gm[offset], output_local, QK8_0);
+        output_queue.FreeTensor(output_local);
+    }
+
+    __aicore__ inline half calculate_group(int64_t row, int64_t group) {
+        const int64_t i3 = row / (input_ne[1] * input_ne[2]);
+        const int64_t i2 = (row - i3 * input_ne[1] * input_ne[2]) / input_ne[1];
+        const int64_t i1 =
+            row - i3 * input_ne[1] * input_ne[2] - i2 * input_ne[1];
+
+        const int64_t input_offset = i1 * input_stride[1] +
+                                     i2 * input_stride[2] +
+                                     i3 * input_stride[3] + QK8_0 * group;
+
+        const int64_t output_offset = i1 * output_stride[1] +
+                                      i2 * output_stride[2] +
+                                      i3 * output_stride[3] + QK8_0 * group;
+
+        copy_in(input_offset);
+        LocalTensor<half> input_local = input_queue.DeQue<half>();
+        LocalTensor<int8_t> output_local = output_queue.AllocTensor<int8_t>();
+        LocalTensor<float> work_local = work_queue.AllocTensor<float>();
+        LocalTensor<float> abs_local = abs_queue.AllocTensor<float>();
+        LocalTensor<float> max_local = max_queue.AllocTensor<float>();
+        LocalTensor<float> cast_local = cast_queue.AllocTensor<float>();
+
+        Cast(cast_local, input_local, RoundMode::CAST_NONE, QK8_0);
+        Abs(abs_local, cast_local, QK8_0);
+        ReduceMax(max_local, abs_local, work_local, QK8_0);
+
+        pipe_barrier(PIPE_ALL);
+        float d = max_local.GetValue(0);
+        d = d / ((1 << 7) - 1);
+        if (d != 0) {
+            Muls(cast_local, cast_local, 1.0f / d, QK8_0);
+        }
+
+        Cast(cast_local, cast_local, RoundMode::CAST_ROUND, QK8_0);
+        Cast(input_local, cast_local, RoundMode::CAST_ROUND, QK8_0);
+        Cast(output_local, input_local, RoundMode::CAST_ROUND, QK8_0);
+        output_queue.EnQue(output_local);
+        copy_out(output_offset);
+
+        input_queue.FreeTensor(input_local);
+        work_queue.FreeTensor(work_local);
+        abs_queue.FreeTensor(abs_local);
+        max_queue.FreeTensor(max_local);
+        cast_queue.FreeTensor(cast_local);
+        return (half)d;
+    }
+
+    __aicore__ inline void calculate() {
+        LocalTensor<half> scale_local = scale_queue.AllocTensor<half>();
+        uint32_t scale_local_offset = 0;
+        uint32_t scale_global_offset = 0;
+        for (int64_t i = ir; i < ir + dr; i++) {
+            for (int64_t j = 0; j < group_size_in_row; j++) {
+                half scale = calculate_group(i, j);
+                scale_local.SetValue(scale_local_offset++, scale);
+                if (scale_local_offset == 16) {
+                    scale_local_offset = 0;
+                    // TODO: OPTIMIZE ME
+                    pipe_barrier(PIPE_ALL);
+                    DataCopy(scale_gm[scale_global_offset], scale_local, 16);
+                    pipe_barrier(PIPE_ALL);
+                    scale_global_offset += 16;
+                }
+            }
+        }
+
+        if (scale_local_offset != 0) {
+            pipe_barrier(PIPE_ALL);
+            DataCopyExtParams dataCopyParams;
+            dataCopyParams.blockCount = 1;
+            dataCopyParams.blockLen = scale_local_offset * sizeof(half);
+            DataCopyPad(scale_gm[scale_global_offset], scale_local,
+                        dataCopyParams);
+            pipe_barrier(PIPE_ALL);
+        }
+    }
+
+   private:
+    int64_t input_ne[4];
+    size_t input_stride[4];
+
+    int64_t *scale_ne;
+    size_t scale_stride[4];
+
+    int64_t output_ne[4];
+    size_t output_stride[4];
+
+    int64_t group_size_in_row;
+
+    int64_t ir;
+    int64_t dr;
+
+    TPipe pipe;
+    GlobalTensor<half> input_gm;
+    GlobalTensor<half> scale_gm;
+    GlobalTensor<int8_t> output_gm;
+    TQue<QuePosition::VECIN, BUFFER_NUM> input_queue;
+    TQue<QuePosition::VECOUT, BUFFER_NUM> output_queue;
+    TQue<QuePosition::VECIN, 1> work_queue;
+    TQue<QuePosition::VECOUT, 1> max_queue;
+    TQue<QuePosition::VECIN, 1> abs_queue;
+    TQue<QuePosition::VECOUT, 1> scale_queue;
+    TQue<QuePosition::VECOUT, 1> cast_queue;
+
+};
+
+template <typename T>
+__aicore__ inline void copy_to_ub(GM_ADDR gm, T *ub, size_t size) {
+    auto gm_ptr = (__gm__ uint8_t *)gm;
+    auto ub_ptr = (uint8_t *)(ub);
+    for (int32_t i = 0; i < size; ++i, ++ub_ptr, ++gm_ptr) {
+        *ub_ptr = *gm_ptr;
+    }
+}
+
+extern "C" __global__ __aicore__ void ascendc_quantize_f16_q8_0(
+    GM_ADDR input_gm, GM_ADDR output_gm, GM_ADDR input_ne_gm,
+    GM_ADDR input_nb_gm, GM_ADDR output_ne_gm) {
+    int64_t input_ne_ub[4];
+    size_t input_nb_ub[4];
+    int64_t output_ne_ub[4];
+
+    copy_to_ub(input_ne_gm, input_ne_ub, 32);
+    copy_to_ub(input_nb_gm, input_nb_ub, 32);
+    copy_to_ub(output_ne_gm, output_ne_ub, 32);
+
+    QUANTIZE_F16_Q8_0 op;
+    op.init(input_gm, output_gm, input_ne_ub, input_nb_ub, output_ne_ub);
+    op.calculate();
+}
diff --git a/src/ggml-cann/kernels/quantize_f32_q8_0.cpp b/src/ggml-cann/kernels/quantize_f32_q8_0.cpp
new file mode 100644 (file)
index 0000000..b7c5750
--- /dev/null
@@ -0,0 +1,206 @@
+#include "kernel_operator.h"
+
+using namespace AscendC;
+
+#define BUFFER_NUM 2
+#define QK8_0 32
+
+class QUANTIZE_F32_Q8_0 {
+   public:
+    __aicore__ inline QUANTIZE_F32_Q8_0() {}
+    __aicore__ inline void init(GM_ADDR input, GM_ADDR output,
+                                int64_t *input_ne_ub, size_t *input_nb_ub,
+                                int64_t *output_ne_ub) {
+        int64_t op_block_num = GetBlockNum();
+        int64_t op_block_idx = GetBlockIdx();
+
+        for (int i = 0; i < 4; i++) {
+            input_ne[i] = input_ne_ub[i];
+            input_stride[i] = input_nb_ub[i] / input_nb_ub[0];
+
+            output_ne[i] = output_ne_ub[i];
+        }
+
+        output_stride[0] = 1;
+        for (int i = 1; i < 4; i++) {
+            output_stride[i] = output_stride[i - 1] * output_ne[i - 1];
+        }
+
+        scale_ne = input_ne;
+        scale_stride[0] = 1;
+        scale_stride[1] = input_ne[0] / QK8_0;
+        for (int i = 2; i < 4; i++) {
+            scale_stride[i] = scale_stride[i - 1] * scale_ne[i - 1];
+        }
+
+        // split input tensor by rows.
+        uint64_t nr = input_ne[1] * input_ne[2] * input_ne[3];
+        dr = nr / op_block_num;
+
+        uint64_t tails = nr % op_block_num;
+        if (op_block_idx < tails) {
+            dr += 1;
+            ir = dr * op_block_idx;
+        } else {
+            ir = dr * op_block_idx + tails;
+        }
+
+        group_size_in_row = scale_stride[1];
+        int64_t output_size = output_ne[0] * output_ne[1] * output_ne[2] *
+                              output_ne[3] * sizeof(uint8_t);
+
+        input_gm.SetGlobalBuffer((__gm__ float *)input);
+        output_gm.SetGlobalBuffer((__gm__ int8_t *)output);
+        scale_gm.SetGlobalBuffer((__gm__ half *)(output + output_size +
+                                                 ir * group_size_in_row *
+                                                 sizeof(half)));
+
+        pipe.InitBuffer(input_queue, BUFFER_NUM, QK8_0 * sizeof(float));
+        pipe.InitBuffer(output_queue, BUFFER_NUM, QK8_0 * sizeof(int8_t));
+        pipe.InitBuffer(work_queue, 1, 32);
+        pipe.InitBuffer(max_queue, 1, 32);
+        pipe.InitBuffer(abs_queue, 1, QK8_0 * sizeof(float));
+        pipe.InitBuffer(cast_queue, 1, QK8_0 * sizeof(half));
+        pipe.InitBuffer(scale_queue, 1, 32);
+    }
+
+    __aicore__ inline void copy_in(uint32_t offset) {
+        LocalTensor<float> input_local = input_queue.AllocTensor<float>();
+        DataCopy(input_local, input_gm[offset], QK8_0);
+        input_queue.EnQue(input_local);
+    }
+
+    __aicore__ inline void copy_out(uint32_t offset) {
+        LocalTensor<int8_t> output_local = output_queue.DeQue<int8_t>();
+        DataCopy(output_gm[offset], output_local, QK8_0);
+        output_queue.FreeTensor(output_local);
+    }
+
+    __aicore__ inline half calculate_group(int64_t row, int64_t group) {
+        const int64_t i3 = row / (input_ne[1] * input_ne[2]);
+        const int64_t i2 = (row - i3 * input_ne[1] * input_ne[2]) / input_ne[1];
+        const int64_t i1 =
+            row - i3 * input_ne[1] * input_ne[2] - i2 * input_ne[1];
+
+        const int64_t input_offset = i1 * input_stride[1] +
+                                     i2 * input_stride[2] +
+                                     i3 * input_stride[3] + QK8_0 * group;
+
+        const int64_t output_offset = i1 * output_stride[1] +
+                                      i2 * output_stride[2] +
+                                      i3 * output_stride[3] + QK8_0 * group;
+
+        copy_in(input_offset);
+        LocalTensor<float> input_local = input_queue.DeQue<float>();
+        LocalTensor<int8_t> output_local = output_queue.AllocTensor<int8_t>();
+        LocalTensor<float> work_local = work_queue.AllocTensor<float>();
+        LocalTensor<float> abs_local = abs_queue.AllocTensor<float>();
+        LocalTensor<float> max_local = max_queue.AllocTensor<float>();
+        LocalTensor<half> cast_local = cast_queue.AllocTensor<half>();
+
+        Abs(abs_local, input_local, QK8_0);
+        ReduceMax(max_local, abs_local, work_local, QK8_0);
+        pipe_barrier(PIPE_ALL);
+        float d = max_local.GetValue(0);
+        d = d / ((1 << 7) - 1);
+        if (d != 0) {
+            Muls(input_local, input_local, 1.0f / d, QK8_0);
+        }
+
+        Cast(input_local, input_local, RoundMode::CAST_ROUND, QK8_0);
+        Cast(cast_local, input_local, RoundMode::CAST_ROUND, QK8_0);
+        Cast(output_local, cast_local, RoundMode::CAST_ROUND, QK8_0);
+        output_queue.EnQue(output_local);
+        copy_out(output_offset);
+
+        input_queue.FreeTensor(input_local);
+        work_queue.FreeTensor(work_local);
+        abs_queue.FreeTensor(abs_local);
+        max_queue.FreeTensor(max_local);
+        cast_queue.FreeTensor(cast_local);
+
+        return (half)d;
+    }
+
+    __aicore__ inline void calculate() {
+        LocalTensor<half> scale_local = scale_queue.AllocTensor<half>();
+        uint32_t scale_local_offset = 0;
+        uint32_t scale_global_offset = 0;
+        for (int64_t i = ir; i < ir + dr; i++) {
+            for (int64_t j = 0; j < group_size_in_row; j++) {
+                half scale = calculate_group(i, j);
+                scale_local.SetValue(scale_local_offset++, scale);
+                if (scale_local_offset == 16) {
+                    scale_local_offset = 0;
+                    // TODO: OPTIMIZE ME
+                    pipe_barrier(PIPE_ALL);
+                    DataCopy(scale_gm[scale_global_offset], scale_local, 16);
+                    pipe_barrier(PIPE_ALL);
+                    scale_global_offset += 16;
+                }
+            }
+        }
+
+        if (scale_local_offset != 0) {
+            pipe_barrier(PIPE_ALL);
+            DataCopyExtParams dataCopyParams;
+            dataCopyParams.blockCount = 1;
+            dataCopyParams.blockLen = scale_local_offset * sizeof(half);
+            DataCopyPad(scale_gm[scale_global_offset], scale_local,
+                        dataCopyParams);
+            pipe_barrier(PIPE_ALL);
+        }
+    }
+
+   private:
+    int64_t input_ne[4];
+    size_t input_stride[4];
+
+    int64_t *scale_ne;
+    size_t scale_stride[4];
+
+    int64_t output_ne[4];
+    size_t output_stride[4];
+
+    int64_t group_size_in_row;
+
+    int64_t ir;
+    int64_t dr;
+
+    TPipe pipe;
+    GlobalTensor<float> input_gm;
+    GlobalTensor<half> scale_gm;
+    GlobalTensor<int8_t> output_gm;
+    TQue<QuePosition::VECIN, BUFFER_NUM> input_queue;
+    TQue<QuePosition::VECOUT, BUFFER_NUM> output_queue;
+    TQue<QuePosition::VECIN, 1> work_queue;
+    TQue<QuePosition::VECOUT, 1> max_queue;
+    TQue<QuePosition::VECIN, 1> abs_queue;
+    TQue<QuePosition::VECIN, 1> cast_queue;
+    TQue<QuePosition::VECOUT, 1> scale_queue;
+};
+
+template <typename T>
+__aicore__ inline void copy_to_ub(GM_ADDR gm, T *ub, size_t size) {
+    auto gm_ptr = (__gm__ uint8_t *)gm;
+    auto ub_ptr = (uint8_t *)(ub);
+    for (int32_t i = 0; i < size; ++i, ++ub_ptr, ++gm_ptr) {
+        *ub_ptr = *gm_ptr;
+    }
+}
+
+extern "C" __global__ __aicore__ void ascendc_quantize_f32_q8_0(
+    GM_ADDR input_gm, GM_ADDR output_gm, GM_ADDR input_ne_gm,
+    GM_ADDR input_nb_gm, GM_ADDR output_ne_gm) {
+    int64_t input_ne_ub[4];
+    size_t input_nb_ub[4];
+    int64_t output_ne_ub[4];
+
+    copy_to_ub(input_ne_gm, input_ne_ub, 32);
+    copy_to_ub(input_nb_gm, input_nb_ub, 32);
+    copy_to_ub(output_ne_gm, output_ne_ub, 32);
+
+    QUANTIZE_F32_Q8_0 op;
+    op.init(input_gm, output_gm, input_ne_ub, input_nb_ub, output_ne_ub);
+    op.calculate();
+}
diff --git a/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp b/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp
new file mode 100644 (file)
index 0000000..9c8c86b
--- /dev/null
@@ -0,0 +1,278 @@
+#include "kernel_operator.h"
+
+using namespace AscendC;
+
+#define BUFFER_NUM 2
+#define Group_Size 32
+
+template <typename SRC_T>
+class QUANTIZE_FLOAT_TO_Q4_0 {
+   public:
+    __aicore__ inline QUANTIZE_FLOAT_TO_Q4_0() {}
+    __aicore__ inline void init(GM_ADDR input, GM_ADDR output,
+                                int64_t *input_ne_ub, size_t *input_nb_ub,
+                                int64_t *output_ne_ub) {
+        // TODO: fix test_case CPY(type_src=f16,type_dst=q4_0,ne=[256,4,4,4],
+        //                         permute=[0,0,0,0]):
+        // [CPY] NMSE = 0.000008343 > 0.000001000 FAIL
+        int64_t op_block_num = GetBlockNum();
+        int64_t op_block_idx = GetBlockIdx();
+
+        // input stride of data elements
+        for (int i = 0; i < 4; i++) {
+            input_ne[i] = input_ne_ub[i];
+            input_stride[i] = input_nb_ub[i] / input_nb_ub[0];
+            output_ne[i] = output_ne_ub[i];
+        }
+
+        // output stride of data elements
+        output_stride[0] = 1;
+        for (int i = 1; i < 4; i++) {
+            output_stride[i] = output_stride[i - 1] * output_ne[i - 1];
+        }
+
+        // scale saved one by one after data:. [group1_scale, group2_scale, ...]
+        scale_ne = input_ne;
+        scale_stride[0] = 1;
+        scale_stride[1] = input_ne[0] / Group_Size;
+        for (int i = 2; i < 4; i++) {
+            scale_stride[i] = scale_stride[i - 1] * scale_ne[i - 1];
+        }
+
+        // split input tensor by rows.
+        uint64_t nr = input_ne[1] * input_ne[2] * input_ne[3];
+        dr = nr / op_block_num;
+
+        uint64_t tails = nr % op_block_num;
+        if (op_block_idx < tails) {
+            dr += 1;
+            ir = dr * op_block_idx;
+        } else {
+            ir = dr * op_block_idx + tails;
+        }
+
+        group_size_in_row = scale_stride[1];
+        int64_t scale_offset = output_ne[0] * output_ne[1] * output_ne[2] *
+                              output_ne[3] * sizeof(uint8_t) / 2;
+
+        input_gm.SetGlobalBuffer((__gm__ SRC_T *)input);
+        output_gm.SetGlobalBuffer((__gm__ int8_t *)output);
+        scale_gm.SetGlobalBuffer((__gm__ half *)(output + scale_offset + ir *
+                                                 group_size_in_row *
+                                                 sizeof(half)));
+
+        pipe.InitBuffer(input_queue, BUFFER_NUM, Group_Size * sizeof(SRC_T));
+        pipe.InitBuffer(output_queue, BUFFER_NUM,
+                            Group_Size * sizeof(int8_t) / 2);
+        pipe.InitBuffer(cast_queue , 1, Group_Size * sizeof(float));
+        pipe.InitBuffer(work_queue, 1, Group_Size * sizeof(float));
+        pipe.InitBuffer(max_queue, 1, Group_Size * sizeof(float));
+        pipe.InitBuffer(min_queue, 1, Group_Size * sizeof(float));
+        pipe.InitBuffer(scale_queue, 1, Group_Size / 2 * sizeof(half));
+        pipe.InitBuffer(int8_queue, 1, Group_Size * sizeof(int8_t));
+        pipe.InitBuffer(half_queue, 1, Group_Size * sizeof(half));
+    }
+
+    __aicore__ inline void copy_in(uint32_t offset) {
+        LocalTensor<SRC_T> input_local = input_queue.AllocTensor<SRC_T>();
+        DataCopy(input_local, input_gm[offset], Group_Size);
+        input_queue.EnQue(input_local);
+    }
+
+    __aicore__ inline void copy_out(uint32_t offset) {
+        // reinterpretcast Group_Size(32) * int4b_t to Group_Size / 2 * int8_t,
+        // and using DataCopyPad to avoid 32 bits align.
+        LocalTensor<int4b_t> output_local = output_queue.DeQue<int4b_t>();
+        LocalTensor<int8_t> output_int8_local =
+                                    output_local.ReinterpretCast<int8_t>();
+
+        DataCopyExtParams dataCopyParams;
+        dataCopyParams.blockCount = 1;
+        dataCopyParams.blockLen = Group_Size / 2  * sizeof(int8_t);
+        DataCopyPad(output_gm[offset], output_int8_local, dataCopyParams);
+
+        output_queue.FreeTensor(output_local);
+    }
+
+    __aicore__ inline void input_to_cast(LocalTensor<float> cast_local,
+                                         LocalTensor<float> input_local) {
+        DataCopy(cast_local, input_local, Group_Size);
+    }
+
+    __aicore__ inline void input_to_cast(LocalTensor<float> cast_local,
+                                         LocalTensor<half> input_local) {
+        Cast(cast_local, input_local, RoundMode::CAST_NONE, Group_Size);
+    }
+
+    __aicore__ inline half calculate_group(int64_t row, int64_t group) {
+        const int64_t i3 = row / (input_ne[1] * input_ne[2]);
+        const int64_t i2 = (row - i3 * input_ne[1] * input_ne[2]) / input_ne[1];
+        const int64_t i1 =
+            row - i3 * input_ne[1] * input_ne[2] - i2 * input_ne[1];
+
+        const int64_t input_offset = i1 * input_stride[1] +
+                                     i2 * input_stride[2] +
+                                     i3 * input_stride[3] + Group_Size * group;
+
+        // output_offset is stride for output_gm which datatype is int8_t and
+        // divided by 2 is needed for int4b_t.
+        const int64_t output_offset = (i1 * output_stride[1] +
+                                       i2 * output_stride[2] +
+                                       i3 * output_stride[3] +
+                                       Group_Size * group) / 2;
+        copy_in(input_offset);
+
+        LocalTensor<SRC_T> input_local = input_queue.DeQue<SRC_T>();
+        LocalTensor<int4b_t> output_local = output_queue.AllocTensor<int4b_t>();
+        LocalTensor<float> cast_local = cast_queue.AllocTensor<float>();
+        LocalTensor<float> work_local = work_queue.AllocTensor<float>();
+        LocalTensor<float> max_local = max_queue.AllocTensor<float>();
+        LocalTensor<float> min_local = min_queue.AllocTensor<float>();
+        LocalTensor<int8_t> int8_local = int8_queue.AllocTensor<int8_t>();
+        LocalTensor<half> half_local = half_queue.AllocTensor<half>();
+
+        input_to_cast(cast_local, input_local);
+
+        ReduceMax(max_local, cast_local, work_local, Group_Size);
+        ReduceMin(min_local, cast_local, work_local, Group_Size);
+        const float max_value = max_local.GetValue(0);
+        const float min_value = min_local.GetValue(0);
+        float d = max_value;
+        if (min_value < 0 && (-1 * min_value) > max_value) {
+            d = min_value;
+        }
+
+        d = d / (-8);
+        if (d != 0) {
+            Muls(cast_local, cast_local, 1.0f / d, Group_Size);
+        }
+
+        // range: [-8,8] -> [0.5,16.5] -> [0,16] -> [0,15] -> [-8,7]
+        float scalar = 8.5f;
+        Adds(cast_local, cast_local, scalar, Group_Size);
+        Cast(cast_local, cast_local, RoundMode::CAST_FLOOR, Group_Size);
+        scalar = 15.0f;
+        Mins(cast_local, cast_local, scalar, Group_Size);
+        scalar = -8.0f;
+        Adds(cast_local, cast_local, scalar, Group_Size);
+
+        // float->half->int4b
+        Cast(half_local, cast_local, RoundMode::CAST_NONE, Group_Size);
+        Cast(output_local, half_local, RoundMode::CAST_NONE, Group_Size);
+
+        output_queue.EnQue(output_local);
+        copy_out(output_offset);
+
+        input_queue.FreeTensor(input_local);
+        work_queue.FreeTensor(work_local);
+        max_queue.FreeTensor(max_local);
+        min_queue.FreeTensor(min_local);
+        int8_queue.FreeTensor(int8_local);
+        half_queue.FreeTensor(half_local);
+        cast_queue.FreeTensor(cast_local);
+        return (half)d;
+    }
+
+    __aicore__ inline void calculate() {
+        LocalTensor<half> scale_local = scale_queue.AllocTensor<half>();
+        uint32_t scale_local_offset = 0;
+        uint32_t scale_global_offset = 0;
+        for (int64_t i = ir; i < ir + dr; i++) {
+            for (int64_t j = 0; j < group_size_in_row; j++) {
+                half scale = calculate_group(i, j);
+                scale_local.SetValue(scale_local_offset++, scale);
+                // Copy Group_Size/2 length data each time.
+                if (scale_local_offset == Group_Size / 2) {
+                    scale_local_offset = 0;
+                    // TODO: OPTIMIZE ME
+                    pipe_barrier(PIPE_ALL);
+                    DataCopy(scale_gm[scale_global_offset], scale_local,
+                                      Group_Size / 2);
+                    pipe_barrier(PIPE_ALL);
+                    scale_global_offset += Group_Size / 2;
+                }
+            }
+        }
+
+        if (scale_local_offset != 0) {
+            pipe_barrier(PIPE_ALL);
+            DataCopyExtParams dataCopyParams;
+            dataCopyParams.blockCount = 1;
+            dataCopyParams.blockLen = scale_local_offset * sizeof(half);
+            DataCopyPad(scale_gm[scale_global_offset], scale_local,
+                        dataCopyParams);
+            pipe_barrier(PIPE_ALL);
+        }
+        scale_queue.FreeTensor(scale_local);
+    }
+
+   private:
+    int64_t input_ne[4];
+    size_t input_stride[4];
+
+    int64_t *scale_ne;
+    size_t scale_stride[4];
+
+    int64_t output_ne[4];
+    size_t output_stride[4];
+
+    int64_t group_size_in_row;
+
+    int64_t ir;
+    int64_t dr;
+
+    TPipe pipe;
+    GlobalTensor<SRC_T> input_gm;
+    GlobalTensor<half> scale_gm;
+    GlobalTensor<int8_t> output_gm;
+    TQue<QuePosition::VECIN, BUFFER_NUM> input_queue;
+    TQue<QuePosition::VECOUT, BUFFER_NUM> output_queue;
+    TQue<QuePosition::VECIN, BUFFER_NUM> work_queue;
+    TQue<QuePosition::VECOUT, BUFFER_NUM> max_queue;
+    TQue<QuePosition::VECOUT, BUFFER_NUM> min_queue;
+    TQue<QuePosition::VECOUT, BUFFER_NUM> scale_queue;
+    TQue<QuePosition::VECOUT, BUFFER_NUM> cast_queue;
+    TQue<QuePosition::VECOUT, BUFFER_NUM> int8_queue;
+    TQue<QuePosition::VECOUT, BUFFER_NUM> half_queue;
+};
+
+template <typename T>
+__aicore__ inline void copy_to_ub(GM_ADDR gm, T *ub, size_t size) {
+    auto gm_ptr = (__gm__ uint8_t *)gm;
+    auto ub_ptr = (uint8_t *)(ub);
+    for (int32_t i = 0; i < size; ++i, ++ub_ptr, ++gm_ptr) {
+        *ub_ptr = *gm_ptr;
+    }
+}
+
+extern "C" __global__ __aicore__ void ascendc_quantize_f16_to_q4_0(
+    GM_ADDR input_gm, GM_ADDR output_gm, GM_ADDR input_ne_gm,
+    GM_ADDR input_nb_gm, GM_ADDR output_ne_gm) {
+    int64_t input_ne_ub[4];
+    size_t input_nb_ub[4];
+    int64_t output_ne_ub[4];
+
+    copy_to_ub(input_ne_gm, input_ne_ub, 32);
+    copy_to_ub(input_nb_gm, input_nb_ub, 32);
+    copy_to_ub(output_ne_gm, output_ne_ub, 32);
+
+    QUANTIZE_FLOAT_TO_Q4_0<half> op;
+    op.init(input_gm, output_gm, input_ne_ub, input_nb_ub, output_ne_ub);
+    op.calculate();
+}
+
+extern "C" __global__ __aicore__ void ascendc_quantize_f32_to_q4_0(
+    GM_ADDR input_gm, GM_ADDR output_gm, GM_ADDR input_ne_gm,
+    GM_ADDR input_nb_gm, GM_ADDR output_ne_gm) {
+    int64_t input_ne_ub[4];
+    size_t input_nb_ub[4];
+    int64_t output_ne_ub[4];
+
+    copy_to_ub(input_ne_gm, input_ne_ub, 32);
+    copy_to_ub(input_nb_gm, input_nb_ub, 32);
+    copy_to_ub(output_ne_gm, output_ne_ub, 32);
+
+    QUANTIZE_FLOAT_TO_Q4_0<float> op;
+    op.init(input_gm, output_gm, input_ne_ub, input_nb_ub, output_ne_ub);
+    op.calculate();
+}