]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
CANN: Refactor ND to NZ workspace to be per-device (llama/15763)
authorChenguang Li <redacted>
Thu, 4 Sep 2025 12:20:14 +0000 (20:20 +0800)
committerGeorgi Gerganov <redacted>
Fri, 5 Sep 2025 09:54:14 +0000 (12:54 +0300)
* CANN:Refactor ND to NZ workspace to be per-device in Ascend backend

- Replaced the previous single global ND→NZ workspace with a per-device
  cache using unordered_map keyed by device ID.
- Functions `release_nz_workspace`, `relloc_nz_workspace`, and
  `get_nz_workspace` now manage workspace independently for each device,
  preventing memory conflicts in multi-device / pipeline parallel scenarios.
- This change fixes potential precision issues caused by workspace
  overwrites when multiple devices perform ND→NZ conversions concurrently.

Co-authored-by: hipudding <redacted>
* refactor

Signed-off-by: noemotiovon <redacted>
* rename

Signed-off-by: noemotiovon <redacted>
* fix review comments

Signed-off-by: noemotiovon <redacted>
---------

Signed-off-by: noemotiovon <redacted>
Co-authored-by: hipudding <redacted>
src/ggml-cann/ggml-cann.cpp

index 1aa2913a61788d6f1f87df789795db6521c923c7..756ad8dfad7af33feb5ef3e13c7df00c952e2337 100755 (executable)
@@ -1116,30 +1116,65 @@ static enum ggml_status ggml_backend_cann_buffer_init_tensor(
     return GGML_STATUS_SUCCESS;
 }
 
-// ND to NZ Workspace Cache Management. Thread-safety: Not guaranteed
-namespace {
-    void* g_nz_workspace = nullptr;
-    size_t g_nz_workspace_allocated = 0;
-
-    void release_nz_workspace() {
-        if (g_nz_workspace) {
-            aclrtFree(g_nz_workspace);
-            g_nz_workspace = nullptr;
-            g_nz_workspace_allocated = 0;
+/**
+ * @brief Workspace for caching NZ buffers per device.
+ *
+ * This struct manages a device buffer used in NZ computations. It supports
+ * allocation, reallocation, and clearing of cached memory. The struct is
+ * designed to be used with a global array, one per device.
+ */
+struct ggml_cann_nz_workspace {
+    void*  ptr;       // Pointer to allocated device buffer
+    size_t allocated; // Size of currently allocated buffer in bytes
+
+    /**
+     * @brief Constructor. Initializes the workspace with no allocated memory.
+     */
+    ggml_cann_nz_workspace() : ptr(nullptr), allocated(0) {}
+
+    /**
+     * @brief Free cached memory and reset the workspace.
+     *
+     * If a buffer has been allocated, this function releases it using
+     * aclrtFree and resets internal state.
+     */
+    void clear() {
+        if (ptr) {
+            ACL_CHECK(aclrtFree(ptr));
+            ptr = nullptr;
+            allocated = 0;
         }
     }
 
-    void relloc_nz_workspace(size_t new_size) {
-        if (new_size > g_nz_workspace_allocated) {
-        if (g_nz_workspace) {
-            aclrtFree(g_nz_workspace);
-            g_nz_workspace = nullptr;
+    /**
+     * @brief Allocate or reallocate the workspace buffer.
+     *
+     * If the requested size is larger than the currently allocated size,
+     * the old buffer will be freed and a new buffer of the requested size
+     * will be allocated on the device.
+     *
+     * @param new_size Size in bytes to allocate for the workspace.
+     */
+    void realloc(size_t new_size) {
+        if (new_size > allocated) {
+            clear();
+            ACL_CHECK(aclrtMalloc(&ptr, new_size, ACL_MEM_MALLOC_HUGE_FIRST));
+            allocated = new_size;
         }
-        ACL_CHECK(aclrtMalloc(&g_nz_workspace, new_size, ACL_MEM_MALLOC_HUGE_FIRST));
-        g_nz_workspace_allocated = new_size;
-    }
     }
-}
+
+    /**
+     * @brief Get the device buffer pointer.
+     *
+     * @return Pointer to the allocated buffer, or nullptr if not allocated.
+     */
+    void* get() const { return ptr; }
+};
+
+/**
+ * @brief Global array of NZ workspaces, one per device.
+ */
+static ggml_cann_nz_workspace g_nz_workspaces[GGML_CANN_MAX_DEVICES];
 
 /**
  * @brief Convert tensor weights to NZ format using Ascend CANN API.
@@ -1149,13 +1184,13 @@ namespace {
  * improve performance on certain hardware.
  *
  * @param tensor Pointer to the input ggml_tensor containing the weights.
- * @param data Pointer to the raw data buffer for the tensor weights.
  * @param offset Byte offset within the tensor data buffer where weights start.
+ * @param device device id.
  *
  * @note The workspace buffer used in this function is managed globally and reused
  *       across calls. This reduces overhead from repeated memory allocation and deallocation.
  */
-static void weight_format_to_nz(ggml_tensor *tensor, size_t offset) {
+static void weight_format_to_nz(ggml_tensor *tensor, size_t offset, int device) {
     aclTensor* weightTransposed = ggml_cann_create_tensor(tensor, tensor->ne,
                                     tensor->nb, 2, ACL_FORMAT_ND, offset);
     uint64_t workspaceSize = 0;
@@ -1165,7 +1200,9 @@ static void weight_format_to_nz(ggml_tensor *tensor, size_t offset) {
     ACL_CHECK(aclnnTransMatmulWeightGetWorkspaceSize(weightTransposed,
                                                     &workspaceSize, &executor));
     // Avoid frequent malloc/free of the workspace.
-    relloc_nz_workspace(workspaceSize);
+    g_nz_workspaces[device].realloc(workspaceSize);
+
+    void* g_nz_workspace = g_nz_workspaces[device].get();
 
     ACL_CHECK(aclnnTransMatmulWeight(g_nz_workspace, workspaceSize, executor, nullptr));
     ACL_CHECK(aclDestroyTensor(weightTransposed));
@@ -1203,7 +1240,7 @@ static void ggml_backend_cann_buffer_set_tensor(
         if (weight_to_nz && is_matmul_weight((const ggml_tensor*)tensor)) {
             GGML_ASSERT(tensor->ne[2] == 1);
             GGML_ASSERT(tensor->ne[3] == 1);
-            weight_format_to_nz(tensor, offset);
+            weight_format_to_nz(tensor, offset, ctx->device);
         }
     } else {
         void *transform_buffer = malloc(size);
@@ -2262,7 +2299,7 @@ static enum ggml_status ggml_backend_cann_graph_compute(
     ggml_backend_cann_context* cann_ctx =
         (ggml_backend_cann_context*)backend->context;
     ggml_cann_set_device(cann_ctx->device);
-    release_nz_workspace();
+    g_nz_workspaces[cann_ctx->device].clear();
 
 #ifdef USE_ACL_GRAPH
     bool use_cann_graph = true;