]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
CANN: Improve device ID handling and aclnnArange checks (llama/16752)
authorChenguang Li <redacted>
Tue, 28 Oct 2025 02:54:53 +0000 (10:54 +0800)
committerGeorgi Gerganov <redacted>
Sat, 1 Nov 2025 07:41:35 +0000 (09:41 +0200)
* cann: improve device ID handling and aclnnArange checks

- Stop relying on CANN's internal device ID retrieval; use a global variable instead.
- Enforce stricter dimension validation in aclnnArange for better compatibility across CANN versions.

* cann: use thread local var

src/ggml-cann/aclnn_ops.cpp
src/ggml-cann/ggml-cann.cpp

index f030ea0136a958a2010b08a28ab1769fc2f0d17b..5df6dc96a3b2e9c013c1cb8faf0503ebb039c027 100644 (file)
@@ -2234,7 +2234,7 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx,
                               ACL_MEM_MALLOC_HUGE_FIRST));
 
         acl_theta_scale_tensor = ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float),
-                                                         theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
+                                                         theta_scale_ne, theta_scale_nb, 1);
 
         float start      = 0;
         float step       = 1;
@@ -2251,7 +2251,7 @@ static void aclnn_cache_init(ggml_backend_cann_context & ctx,
             yarn_ramp_allocator.alloc(theta_scale_length * sizeof(float));
             void * yarn_ramp_buffer = yarn_ramp_allocator.get();
             acl_yarn_ramp_tensor   = ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float), theta_scale_ne,
-                                                             theta_scale_nb, GGML_MAX_DIMS);
+                                                             theta_scale_nb, 1);
             float       zero_value = 0, one_value = 1;
             float       denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]);
             aclScalar * low              = aclCreateScalar(&corr_dims[0], aclDataType::ACL_FLOAT);
index 8bd5449f1f75fd91389fd2a1784c2fdcae20dff3..51345742ee59eb2d5a5e14a1ab74d59b362b4e0d 100644 (file)
     GGML_ABORT("CANN error");
 }
 
+// Thread-local variable to record the current device of this thread.
+thread_local int g_current_cann_device = -1;
+
 /**
- * @brief Sets the device to be used by CANN.
+ * @brief Set the CANN device to be used.
  *
- * @param device The device ID to set.
+ * @param device The target device ID to set.
  */
 void ggml_cann_set_device(const int32_t device) {
-    int current_device = -1;
-    aclrtGetDevice(&current_device);
+    // int current_device = -1;
+    // Note: In some CANN versions, if no device has been set yet,
+    //       aclrtGetDevice(&current_device) may return 0 by default.
+    // aclrtGetDevice(&current_device);
 
-    if (device == current_device) {
+    // If the current device is already the target one, no need to switch.
+    if (device == g_current_cann_device) {
         return;
     }
+
+    // Switch to the new device.
     ACL_CHECK(aclrtSetDevice(device));
+
+    // Update the global device record.
+    g_current_cann_device = device;
 }
 
 /**