]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
CANN: Improve the Inferencing Performance for Ascend NPU Device (llama/10454)
authorShanshan Shen <redacted>
Tue, 26 Nov 2024 10:08:37 +0000 (18:08 +0800)
committerGeorgi Gerganov <redacted>
Sun, 8 Dec 2024 18:14:35 +0000 (20:14 +0200)
* improve inferencing performance for ascend npu.

Co-authored-by: Frank Mai <redacted>
* some modification after review

* some modifications after review

* restore some modifications

* restore some modifications

---------

Co-authored-by: shanshan shen <redacted>
Co-authored-by: Frank Mai <redacted>
ggml/src/ggml-cann/aclnn_ops.cpp
ggml/src/ggml-cann/common.h
ggml/src/ggml-cann/ggml-cann.cpp

index 6113b59f45c2a247ec28241b6c9c8ddf17e0c6bf..d7472ee3a55c71d26591571a572ed75b0a5aa519 100644 (file)
@@ -33,6 +33,8 @@
 #include <aclnnop/aclnn_group_norm.h>
 #include <aclnnop/aclnn_index_fill_tensor.h>
 #include <aclnnop/aclnn_layer_norm.h>
+#include <aclnnop/aclnn_mm.h>
+#include <aclnnop/aclnn_batch_matmul.h>
 #include <aclnnop/aclnn_matmul.h>
 #include <aclnnop/aclnn_max_pool.h>
 #include <aclnnop/aclnn_permute.h>
@@ -2423,7 +2425,6 @@ static void aclnn_mat_mul(ggml_backend_cann_context& ctx, aclTensor* acl_input,
                           aclTensor* acl_weight, aclTensor* acl_dst) {
     int8_t cube_math_type = 1;  // ALLOW_FP32_DOWN_PRECISION, when input is
                                 // fp32, atlas a2 will transpose it to HFLOAT32.
-
     uint64_t workspaceSize = 0;
     aclOpExecutor* executor;
     void* workspaceAddr = nullptr;
@@ -2441,6 +2442,80 @@ static void aclnn_mat_mul(ggml_backend_cann_context& ctx, aclTensor* acl_input,
         aclnnMatmul(workspaceAddr, workspaceSize, executor, ctx.stream()));
 }
 
+/**
+ * @brief Performs matrix multiplication of two 2D tensors.
+ *
+ * This function computes the matrix multiplication of the input tensor
+ * `acl_input` and the weight tensor `acl_weight`, and stores the result in the
+ * destination tensor `acl_dst`.
+ * The operation is defined as:
+ * \f[
+ *     \text {acl_dst}=\text {acl_input@acl_weight}
+ * \f]
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param acl_input The input tensor for the matrix multiplication.
+ * @param acl_weight The weight tensor for the matrix multiplication.
+ * @param acl_dst The destination tensor where the result of the matrix
+ * multiplication will be stored.
+ */
+static void aclnn_mat_mul_2d(ggml_backend_cann_context& ctx, aclTensor* acl_input,
+                             aclTensor* acl_weight, aclTensor* acl_dst) {
+    int8_t cube_math_type = 2;
+    uint64_t workspaceSize = 0;
+    aclOpExecutor* executor;
+    void* workspaceAddr = nullptr;
+
+    ACL_CHECK(aclnnMmGetWorkspaceSize(acl_input, acl_weight, acl_dst,
+                                      cube_math_type, &workspaceSize,
+                                      &executor));
+
+    if (workspaceSize > 0) {
+        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+        workspaceAddr = workspace_allocator.get();
+    }
+
+    ACL_CHECK(
+        aclnnMm(workspaceAddr, workspaceSize, executor, ctx.stream()));
+}
+
+/**
+ * @brief Performs matrix multiplication of two 3D tensors.
+ *
+ * This function computes the matrix multiplication of the input tensor
+ * `acl_input` and the weight tensor `acl_weight`, and stores the result in the
+ * destination tensor `acl_dst`.
+ * The operation is defined as:
+ * \f[
+ *     \text {acl_dst}=\text {acl_input@acl_weight}
+ * \f]
+ *
+ * @param ctx The context for the CANN backend operations.
+ * @param acl_input The input tensor for the matrix multiplication.
+ * @param acl_weight The weight tensor for the matrix multiplication.
+ * @param acl_dst The destination tensor where the result of the matrix
+ * multiplication will be stored.
+ */
+static void aclnn_mat_mul_3d(ggml_backend_cann_context& ctx, aclTensor* acl_input,
+                             aclTensor* acl_weight, aclTensor* acl_dst) {
+    int8_t cube_math_type = 2;
+    uint64_t workspaceSize = 0;
+    aclOpExecutor* executor;
+    void* workspaceAddr = nullptr;
+
+    ACL_CHECK(aclnnBatchMatMulGetWorkspaceSize(acl_input, acl_weight, acl_dst,
+                                               cube_math_type, &workspaceSize,
+                                               &executor));
+
+    if (workspaceSize > 0) {
+        ggml_cann_pool_alloc workspace_allocator(ctx.pool(), workspaceSize);
+        workspaceAddr = workspace_allocator.get();
+    }
+
+    ACL_CHECK(
+        aclnnBatchMatMul(workspaceAddr, workspaceSize, executor, ctx.stream()));
+}
+
 /**
  * @brief Performs matrix multiplication with floating-point precision on
  * tensors using the CANN backend.
@@ -2462,20 +2537,43 @@ static void ggml_cann_mat_mul_fp(ggml_backend_cann_context& ctx,
     // broadcast, when weight ne2 or ne3 is not 1, weight need repeat.
     BCAST_MUL_MAT_SHAPE(input, weight, dst);
 
-    // transpose weight: [1,2,3,4] -> [1,2,4,3]
-    int64_t transpose_ne[] = {bcast_weight_ne[1], bcast_weight_ne[0],
-                              bcast_weight_ne[2], bcast_weight_ne[3],
-                              bcast_weight_ne[4], bcast_weight_ne[5]};
-    size_t transpose_nb[] = {bcast_weight_nb[1], bcast_weight_nb[0],
-                             bcast_weight_nb[2], bcast_weight_nb[3],
-                             bcast_weight_nb[4], bcast_weight_nb[5]};
+    int64_t n_dims = bcast_dims;
+    if (bcast_input_ne[3] == bcast_weight_ne[3] && bcast_input_ne[3] == 1) {
+        if (bcast_input_ne[2] == 1 && bcast_weight_ne[2] == 1) {
+            n_dims = 2;
+        } else if (bcast_input_ne[2] == 1) {
+            n_dims = 3;
+        }
+    }
 
-    aclTensor* acl_weight_tensor =
-        ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, bcast_dims);
     aclTensor* acl_input_tensor =
-        ggml_cann_create_tensor(input, BCAST_MUL_MAT_PARAM(input));
-    aclTensor* acl_dst = ggml_cann_create_tensor(dst, BCAST_MUL_MAT_PARAM(dst));
-    aclnn_mat_mul(ctx, acl_input_tensor, acl_weight_tensor, acl_dst);
+        ggml_cann_create_tensor(input, bcast_input_ne, bcast_input_nb, n_dims);
+    int64_t transpose_ne[] = {
+        bcast_weight_ne[1], bcast_weight_ne[0],
+        bcast_weight_ne[2], bcast_weight_ne[3],
+        bcast_weight_ne[4], bcast_weight_ne[5]
+    };
+    size_t transpose_nb[] = {
+        bcast_weight_nb[1], bcast_weight_nb[0],
+        bcast_weight_nb[2], bcast_weight_nb[3],
+        bcast_weight_nb[4], bcast_weight_nb[5]
+    };
+    aclTensor* acl_weight_tensor =
+        ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims);
+    aclTensor* acl_dst =
+        ggml_cann_create_tensor(dst, bcast_dst_ne, bcast_dst_nb, n_dims);
+
+    switch (n_dims) {
+    case 2:
+        aclnn_mat_mul_2d(ctx, acl_input_tensor, acl_weight_tensor, acl_dst);
+        break;
+    case 3:
+        aclnn_mat_mul_3d(ctx, acl_input_tensor, acl_weight_tensor, acl_dst);
+        break;
+    default:
+        aclnn_mat_mul(ctx, acl_input_tensor, acl_weight_tensor, acl_dst);
+        break;
+    }
 
     ACL_CHECK(aclDestroyTensor(acl_weight_tensor));
     ACL_CHECK(aclDestroyTensor(acl_input_tensor));
@@ -2501,46 +2599,40 @@ static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx,
     ggml_tensor* src0 = dst->src[0];  // weight
     ggml_tensor* src1 = dst->src[1];  // input
 
-    // The shape of the weight is NCHW. Matrix multiplication uses HW dims. HC
-    // is regarded as batch. weight need transpose.
-    int64_t weight_ne[] = {src0->ne[1], src0->ne[0]};
+    // The shape of the weight is NCHW.
+    // Matrix multiplication uses HW dims.
+    // HC is regarded as batch.
+    // weight need transpose.
     float weight_elem_size;
     if (type == GGML_TYPE_Q4_0) {
         weight_elem_size = float(sizeof(uint8_t)) / 2;
-    }
-    else if (type == GGML_TYPE_Q8_0) {
+    } else if (type == GGML_TYPE_Q8_0) {
         weight_elem_size = float(sizeof(uint8_t));
-    }
-    else {
+    } else {
         GGML_ABORT("Only support Q4_0 and Q8_0 MUL_MAT");
     }
-    float weight_nb[] = {weight_elem_size * src0->ne[0], weight_elem_size};
-
-    // size of one matrix is element_size * height * width.
-    size_t weight_stride = weight_elem_size * src0->ne[0] * src0->ne[1];
+    float weight_nb[] = {src0->ne[0] * weight_elem_size, weight_elem_size};
+    size_t weight_stride = src0->ne[1] * src0->ne[0] * weight_elem_size;
     size_t weight_size = weight_stride * src0->ne[2] * src0->ne[3];
 
     // scale stored at the end of weight. Also need transpose.
-    GGML_ASSERT(QK4_0 == QK8_0);
-    int64_t scale_ne[] = {src0->ne[1], src0->ne[0] / QK8_0};
     size_t scale_elem_size = sizeof(uint16_t);
-    size_t scale_nb[] = {src0->ne[0] / QK8_0 * scale_elem_size,
-                         scale_elem_size};
-    size_t scale_stride = scale_elem_size * src0->ne[0] * src0->ne[1] / QK8_0;
+    size_t scale_nb[] = {src0->ne[0] / QK8_0 * scale_elem_size, scale_elem_size};
+    size_t scale_stride = src0->ne[1] * src0->ne[0] / QK8_0 * scale_elem_size;
     char* scale_offset = (char*)src0->data + weight_size;
 
     // input
-    void* input_buffer;
     size_t input_elem_size = sizeof(uint16_t);
     int64_t input_ne[] = {src1->ne[0], src1->ne[1]};
-    size_t input_nb[] = {input_elem_size, input_elem_size * src1->ne[0]};
-    size_t input_stride = input_elem_size * src1->ne[0] * src1->ne[1];
-
+    size_t input_nb[] = {input_elem_size,  input_ne[0] * input_elem_size};
+    size_t input_stride = input_ne[0] * input_ne[1] * input_elem_size;
     ggml_cann_pool_alloc input_alloctor(ctx.pool());
+    void* input_buffer = src1->data;
+
+    // case in
     if (src1->type != GGML_TYPE_F16) {
         aclTensor* acl_src1_tensor = ggml_cann_create_tensor(src1);
-        input_alloctor.alloc(ggml_nelements(src1) * input_elem_size);
-        input_buffer = input_alloctor.get();
+        input_buffer = input_alloctor.alloc(ggml_nelements(src1) * input_elem_size);
 
         int64_t* input_cast_ne = src1->ne;
         size_t input_cast_nb[GGML_MAX_DIMS];
@@ -2550,88 +2642,139 @@ static void ggml_cann_mul_mat_quant(ggml_backend_cann_context& ctx,
         }
 
         aclTensor* acl_input_tensor = ggml_cann_create_tensor(
-            input_buffer, ACL_FLOAT16, input_elem_size, input_cast_ne,
-            input_cast_nb, GGML_MAX_DIMS);
+            input_buffer,
+            ACL_FLOAT16,
+            input_elem_size, input_cast_ne, input_cast_nb, GGML_MAX_DIMS);
         aclnn_cast(ctx, acl_src1_tensor, acl_input_tensor, ACL_FLOAT16);
+
         ACL_CHECK(aclDestroyTensor(acl_input_tensor));
         ACL_CHECK(aclDestroyTensor(acl_src1_tensor));
-    } else {
-        input_buffer = src1->data;
     }
 
     // output
     size_t output_elem_size = sizeof(uint16_t);
-    int64_t output_ne[] = {dst->ne[0], dst->ne[1]};
-    size_t output_nb[] = {output_elem_size, output_elem_size * dst->ne[0]};
-    ggml_cann_pool_alloc output_alloctor(
-        ctx.pool(), ggml_nelements(dst) * output_elem_size);
-    void* output_buffer = output_alloctor.get();
-    size_t output_stride = output_elem_size * dst->ne[0] * dst->ne[1];
+    size_t output_nb[] = {output_elem_size, dst->ne[0] * output_elem_size};
+    ggml_cann_pool_alloc output_allocator(ctx.pool());
+    void* output_buffer = output_allocator.alloc(ggml_nelements(dst) * output_elem_size);
+    size_t output_stride = dst->ne[0] * dst->ne[1] * output_elem_size;
 
     // aclnn
+    int64_t max_elem_size = 65535;
+    int64_t split_size = (src0->ne[1] / max_elem_size) + 1;
+    ggml_cann_pool_alloc workspace_allocator(ctx.pool());
+    aclOpExecutor* executor = nullptr;
     uint64_t workspaceSize = 0;
-    aclOpExecutor* executor;
     void* workspaceAddr = nullptr;
-
     for (int64_t n1 = 0; n1 < src1->ne[3]; n1++) {
         for (int64_t c1 = 0; c1 < src1->ne[2]; c1++) {
             int64_t n0 = n1 / (src1->ne[3] / src0->ne[3]);
             int64_t c0 = c1 / (src1->ne[2] / src0->ne[2]);
 
-            int64_t batch1 = n1 * src1->ne[2] + c1;
-            int64_t batch0 = n0 * src0->ne[2] + c0;
+            int64_t batch1 = (n1 * src1->ne[2]) + c1;
+            int64_t batch0 = (n0 * src0->ne[2]) + c0;
 
             aclTensor* acl_input_tensor = ggml_cann_create_tensor(
                 (char*)input_buffer + batch1 * input_stride, ACL_FLOAT16,
                 input_elem_size, input_ne, input_nb, 2);
+
+            // first split
+            int64_t weight_ne_offset = 0;
+            int64_t weight_ne[2] = {max_elem_size > src0->ne[1] ? src0->ne[1] : max_elem_size, src0->ne[0]};
+            int64_t scale_ne_offset = 0;
+            int64_t scale_ne[2] = {weight_ne[0], weight_ne[1] / QK8_0};
+            int64_t output_ne_offset = 0;
+            int64_t output_ne[2] = {weight_ne[0], dst->ne[1]};
+
             aclTensor* acl_weight_tensor = ggml_cann_create_tensor(
                 (char*)src0->data + batch0 * weight_stride,
-                ggml_cann_type_mapping(type), weight_elem_size, weight_ne,
-                weight_nb, 2);
+                ggml_cann_type_mapping(type),
+                weight_elem_size, weight_ne, weight_nb, 2,
+                ACL_FORMAT_ND, weight_ne_offset);
             aclTensor* acl_scale_tensor = ggml_cann_create_tensor(
-                scale_offset + batch0 * scale_stride, ACL_FLOAT16,
-                scale_elem_size, scale_ne, scale_nb, 2);
+                scale_offset + batch0 * scale_stride,
+                ACL_FLOAT16,
+                scale_elem_size, scale_ne, scale_nb, 2,
+                ACL_FORMAT_ND, scale_ne_offset);
             aclTensor* acl_output_tensor = ggml_cann_create_tensor(
-                (char*)output_buffer + batch1 * output_stride, ACL_FLOAT16,
-                output_elem_size, output_ne, output_nb, 2);
+                (char*)output_buffer + batch1 * output_stride,
+                ACL_FLOAT16,
+                output_elem_size, output_ne, output_nb, 2,
+                ACL_FORMAT_ND, output_ne_offset);
 
             ACL_CHECK(aclnnWeightQuantBatchMatmulV2GetWorkspaceSize(
-                acl_input_tensor, acl_weight_tensor, acl_scale_tensor, nullptr,
-                nullptr, nullptr, nullptr, QK8_0, acl_output_tensor,
-                &workspaceSize, &executor));
-
-            if (workspaceSize > 0 && workspaceAddr == nullptr) {
-                ggml_cann_pool_alloc workspace_allocator(ctx.pool(),
-                                                         workspaceSize);
-                workspaceAddr = workspace_allocator.get();
+                acl_input_tensor, acl_weight_tensor, acl_scale_tensor,
+                nullptr, nullptr, nullptr, nullptr, QK8_0,
+                acl_output_tensor, &workspaceSize, &executor));
+            if (workspaceAddr == nullptr) {
+                workspaceAddr = workspace_allocator.alloc(workspaceSize);
             }
-
             ACL_CHECK(aclnnWeightQuantBatchMatmulV2(
                 workspaceAddr, workspaceSize, executor, ctx.stream()));
 
-            ACL_CHECK(aclDestroyTensor(acl_input_tensor));
             ACL_CHECK(aclDestroyTensor(acl_weight_tensor));
             ACL_CHECK(aclDestroyTensor(acl_scale_tensor));
             ACL_CHECK(aclDestroyTensor(acl_output_tensor));
+
+            // other splits
+            for (int64_t split = 1; split < split_size; split++) {
+                weight_ne_offset += weight_elem_size * weight_ne[0] * weight_ne[1];
+                weight_ne[0] = max_elem_size * (split + 1) > src0->ne[1] ? src0->ne[1] - (max_elem_size * split) : max_elem_size;
+                scale_ne_offset += scale_elem_size * scale_ne[0] * scale_ne[1];
+                scale_ne[0] = weight_ne[0];
+                output_ne_offset += output_elem_size * output_ne[0] * output_ne[1];
+                output_ne[0] = weight_ne[0];
+
+                acl_weight_tensor = ggml_cann_create_tensor(
+                    (char*)src0->data + batch0 * weight_stride,
+                    ggml_cann_type_mapping(type),
+                    weight_elem_size, weight_ne, weight_nb, 2,
+                    ACL_FORMAT_ND, weight_ne_offset);
+                acl_scale_tensor = ggml_cann_create_tensor(
+                    scale_offset + batch0 * scale_stride,
+                    ACL_FLOAT16,
+                    scale_elem_size, scale_ne, scale_nb, 2,
+                    ACL_FORMAT_ND, scale_ne_offset);
+                acl_output_tensor = ggml_cann_create_tensor(
+                    (char*)output_buffer + batch1 * output_stride,
+                    ACL_FLOAT16,
+                    output_elem_size, output_ne, output_nb, 2,
+                    ACL_FORMAT_ND, output_ne_offset);
+
+                ACL_CHECK(aclnnWeightQuantBatchMatmulV2GetWorkspaceSize(
+                    acl_input_tensor, acl_weight_tensor, acl_scale_tensor,
+                    nullptr, nullptr, nullptr, nullptr, QK8_0,
+                    acl_output_tensor, &workspaceSize, &executor));
+                ACL_CHECK(aclnnWeightQuantBatchMatmulV2(
+                    workspaceAddr, workspaceSize, executor, ctx.stream()));
+
+                ACL_CHECK(aclDestroyTensor(acl_weight_tensor));
+                ACL_CHECK(aclDestroyTensor(acl_scale_tensor));
+                ACL_CHECK(aclDestroyTensor(acl_output_tensor));
+            }
+
+            ACL_CHECK(aclDestroyTensor(acl_input_tensor));
         }
     }
 
     // cast out
-    int64_t* output_cast_ne = dst->ne;
-    size_t output_cast_nb[GGML_MAX_DIMS];
-    output_cast_nb[0] = sizeof(uint16_t);
-    for (int i = 1; i < GGML_MAX_DIMS; i++) {
-        output_cast_nb[i] = output_cast_nb[i - 1] * output_cast_ne[i - 1];
-    }
+    if (dst->type != GGML_TYPE_F16) {
+        int64_t* output_cast_ne = dst->ne;
+        size_t output_cast_nb[GGML_MAX_DIMS];
+        output_cast_nb[0] = sizeof(uint16_t);
+        for (int i = 1; i < GGML_MAX_DIMS; i++) {
+            output_cast_nb[i] = output_cast_nb[i - 1] * output_cast_ne[i - 1];
+        }
 
-    aclTensor* acl_output_tensor =
-        ggml_cann_create_tensor(output_buffer, ACL_FLOAT16, output_elem_size,
-                                output_cast_ne, output_cast_nb, GGML_MAX_DIMS);
-    aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst);
-    aclnn_cast(ctx, acl_output_tensor, acl_dst_tensor, ACL_FLOAT);
+        aclTensor* acl_output_tensor = ggml_cann_create_tensor(
+            output_buffer,
+            ACL_FLOAT16,
+            output_elem_size, output_cast_ne, output_cast_nb, GGML_MAX_DIMS);
+        aclTensor* acl_dst_tensor = ggml_cann_create_tensor(dst);
+        aclnn_cast(ctx, acl_output_tensor, acl_dst_tensor, ggml_cann_type_mapping(dst->type));
 
-    ACL_CHECK(aclDestroyTensor(acl_output_tensor));
-    ACL_CHECK(aclDestroyTensor(acl_dst_tensor));
+        ACL_CHECK(aclDestroyTensor(acl_output_tensor));
+        ACL_CHECK(aclDestroyTensor(acl_dst_tensor));
+    }
 }
 
 void ggml_cann_mul_mat(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
index edfa496148ff22291b802e86fe5f17c3fe590c93..5164cb74ec92e277ac24f26bf223c3cf90a558ee 100644 (file)
@@ -211,17 +211,20 @@ struct ggml_cann_pool_alloc {
 struct ggml_backend_cann_context {
     int32_t device;                  /**< Device ID. */
     std::string name;                /**< Name of the device. */
+    std::string description;         /**< Description of the device. */
     aclrtEvent copy_event = nullptr; /**< Event for managing copy operations. */
 
-    aclrtStream streams[GGML_CANN_MAX_STREAMS] = {
-        {nullptr}}; /**< Array of streams for the device. */
+    aclrtStream streams[GGML_CANN_MAX_STREAMS] = {nullptr}; /**< Array of streams for the device. */
 
     /**
      * @brief Constructor for initializing the context with a given device.
      * @param device Device ID.
      */
     explicit ggml_backend_cann_context(int device)
-        : device(device), name("CANN" + std::to_string(device)) {}
+        : device(device), name("CANN" + std::to_string(device)) {
+        ggml_cann_set_device(device);
+        description = aclrtGetSocName();
+    }
 
     /**
      * @brief Destructor for cleaning up resources.
index 2ef5b590a8ecbd811dde92a66b3189da325dba34..c7a3419c796de0d3ac1a99c958b3136a051d3d9d 100644 (file)
@@ -122,6 +122,10 @@ static ggml_cann_device_info ggml_cann_init() {
         ACL_CHECK(aclrtMemGetAllocationGranularity(
             &prop, ACL_RT_MEM_ALLOC_GRANULARITY_RECOMMENDED,
             &info.devices[id].vmm_granularity));
+
+        size_t free, total;
+        ggml_backend_cann_get_device_memory(id, &free, &total);
+        info.devices[id].total_vram = free;
     }
 
     // TODO: add more device info later.
@@ -208,6 +212,11 @@ struct ggml_cann_pool_leg : public ggml_cann_pool {
      * @return A pointer to the allocated buffer.
      */
     void* alloc(size_t size, size_t* actual_size) override {
+        const size_t alignment = 128;
+        size = GGML_PAD(size, alignment);
+        if (size == 0) {
+            size = alignment;
+        }
 #ifdef DEBUG_CANN_MALLOC
         int nnz = 0;
         size_t max_size = 0;
@@ -246,13 +255,11 @@ struct ggml_cann_pool_leg : public ggml_cann_pool {
             return ptr;
         }
         void* ptr;
-        size_t look_ahead_size = (size_t)(1.05 * size);
-        look_ahead_size = 256 * ((look_ahead_size + 255) / 256);
         ggml_cann_set_device(device);
         ACL_CHECK(
-            aclrtMalloc(&ptr, look_ahead_size, ACL_MEM_MALLOC_HUGE_FIRST));
-        *actual_size = look_ahead_size;
-        pool_size += look_ahead_size;
+            aclrtMalloc(&ptr, size, ACL_MEM_MALLOC_HUGE_FIRST));
+        *actual_size = size;
+        pool_size += size;
 #ifdef DEBUG_CANN_MALLOC
         GGML_LOG_INFO(
             "%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, "
@@ -296,7 +303,7 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool {
     /**
      * @brief The maximum size of the virtual memory pool (32 GB).
      */
-    static const size_t CANN_POOL_VMM_MAX_SIZE = 1ull << 35;  // 32 GB
+    size_t max_size;
 
     /**
      * @brief The device ID associated with this buffer pool.
@@ -341,7 +348,11 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool {
      */
     explicit ggml_cann_pool_vmm(int device)
         : device(device),
-          granularity(ggml_cann_info().devices[device].vmm_granularity) {}
+          granularity(ggml_cann_info().devices[device].vmm_granularity) {
+        auto dev = ggml_cann_info().devices[device];
+        granularity = dev.vmm_granularity;
+        max_size = dev.total_vram;
+    }
 
     /**
      * @brief Destructor to free all buffers in the virtual memory pool.
@@ -370,17 +381,19 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool {
         // round up the allocation size to the alignment to ensure that all
         // allocations are aligned for all data types
         const size_t alignment = 128;
-        size = alignment * ((size + alignment - 1) / alignment);
+        size = GGML_PAD(size, alignment);
+        if (size == 0) {
+            size = alignment;
+        }
 
         size_t avail = pool_size - pool_used;
 
         if (size > avail) {
             // round up to the next multiple of the granularity
             size_t reserve_size = size - avail;
-            reserve_size =
-                granularity * ((reserve_size + granularity - 1) / granularity);
+            reserve_size = GGML_PAD(reserve_size, granularity);
 
-            GGML_ASSERT(pool_size + reserve_size <= CANN_POOL_VMM_MAX_SIZE);
+            GGML_ASSERT(pool_size + reserve_size <= max_size);
 
             // allocate more physical memory
             aclrtPhysicalMemProp prop = {};
@@ -396,7 +409,7 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool {
             // reserve virtual address space (if not already reserved)
             if (pool_addr == 0) {
                 ACL_CHECK(aclrtReserveMemAddress(
-                    &pool_addr, CANN_POOL_VMM_MAX_SIZE, 0, NULL, 1));
+                    &pool_addr, max_size, 0, NULL, 1));
             }
 
             // map at the end of the pool
@@ -409,10 +422,11 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool {
             // add to the pool
             pool_size += reserve_size;
 
-            // GGML_LOG_INFO("cann pool[%d]: size increased to %llu MB (
-            // reserved %llu MB)\n",
-            //       device, (unsigned long long) (pool_size/1024/1024),
-            //       (unsigned long long) (reserve_size/1024/1024));
+#ifdef DEBUG_CANN_MALLOC
+             GGML_LOG_INFO("cann pool[%d]: size increased to %llu MB (reserved %llu MB)\n",
+                   device, (unsigned long long) (pool_size/1024/1024),
+                   (unsigned long long) (reserve_size/1024/1024));
+#endif
         }
 
         GGML_ASSERT(pool_addr != 0);
@@ -457,7 +471,6 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool {
  */
 std::unique_ptr<ggml_cann_pool> ggml_backend_cann_context::new_pool_for_device(
     int device) {
-    // return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_leg(device));
     return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_vmm(device));
 }
 
@@ -1130,10 +1143,10 @@ ggml_backend_cann_buffer_type(int32_t device) {
     static bool ggml_backend_cann_buffer_type_initialized = false;
 
     if (!ggml_backend_cann_buffer_type_initialized) {
-        for (int32_t i = 0; i < GGML_CANN_MAX_DEVICES; i++) {
+        for (int32_t i = 0; i < ggml_cann_info().device_count; i++) {
             ggml_backend_cann_buffer_types[i] = {
                 /* .iface    = */ ggml_backend_cann_buffer_type_interface,
-                /* .device    = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), device),
+                /* .device    = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), i),
                 /* .context  = */
                  new ggml_backend_cann_buffer_type_context{
                     i, "CANN" + std::to_string(i)},
@@ -1199,10 +1212,15 @@ static void * ggml_cann_host_malloc(size_t size) {
         return nullptr;
     }
 
+    const size_t alignment = 128;
+    size = GGML_PAD(size, alignment);
+    if (size == 0) {
+        size = alignment;
+    }
+
     void * hostPtr = nullptr;
     aclError err = aclrtMallocHost((void **) &hostPtr, size);
     if (err != ACL_SUCCESS) {
-
         GGML_LOG_WARN("%s: failed to allocate %.2f MiB of pinned memory: %s\n", __func__,
                            size / 1024.0 / 1024.0, aclGetRecentErrMsg());
         return nullptr;