]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CANN: Optimize ggml_cann_set_device (#15935)
authorChenguang Li <redacted>
Wed, 17 Sep 2025 06:33:08 +0000 (14:33 +0800)
committerGitHub <redacted>
Wed, 17 Sep 2025 06:33:08 +0000 (14:33 +0800)
* CANN: Fix ggml_cann_set_device to avoid redundant device switches

- Added a check to skip aclrtSetDevice if the current device is already set.
- Prevents unnecessary context switches while keeping thread/device consistency.

* CANN: add device default id

ggml/src/ggml-cann/common.h
ggml/src/ggml-cann/ggml-cann.cpp

index c5fce8dc91f51c0316861c0a775b5ea50ff0b2b4..b707b843593c7305eaa5be3c918ff82706687cd2 100755 (executable)
@@ -526,7 +526,10 @@ struct ggml_backend_cann_context {
      */
     aclrtStream stream(int stream) {
         if (streams[stream] == nullptr) {
-            ggml_cann_set_device(device);
+            // If the device is not set here, destroying the stream later may cause a mismatch
+            // between the thread contexts where the stream was created and destroyed.
+            // However, I printed the device_id, thread_id, and stream, and they are all consistent.
+            ACL_CHECK(aclrtSetDevice(device));
             ACL_CHECK(aclrtCreateStream(&streams[stream]));
         }
         return streams[stream];
index 19a18a281dfcb827f8aa95289017f929df268fd3..56d82b4af3413c395bfd0d50b1a096bc94d91ca3 100755 (executable)
  * @param device The device ID to set.
  */
 void ggml_cann_set_device(const int32_t device) {
-    // TODO: uncomment these lines after empty context has fixed.
-    // int current_device;
-    // ACL_CHECK(aclrtGetDevice(&current_device));
+    int current_device = -1;
+    aclrtGetDevice(&current_device);
 
-    // if (device == current_device) {
-    //   return;
-    // }
+    if (device == current_device) {
+      return;
+    }
     ACL_CHECK(aclrtSetDevice(device));
 }
 
@@ -1729,6 +1728,7 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
             ggml_cann_get_rows(ctx, dst);
             break;
         case GGML_OP_SET_ROWS:
+            std::cout << "lcg GGML_OP_SET_ROWS"<< std::endl;
             ggml_cann_set_rows(ctx, dst);
             break;
         case GGML_OP_DUP: