]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
cann: Fix Multi-NPU execution error (llama/8710)
authorwangshuai09 <redacted>
Sat, 27 Jul 2024 08:36:44 +0000 (16:36 +0800)
committerGeorgi Gerganov <redacted>
Sat, 27 Jul 2024 15:26:12 +0000 (18:26 +0300)
* cann: fix multi-npu exec error

* cann: update comment  for ggml_backend_cann_supports_buft

ggml/src/ggml-cann.cpp

index ad5feea05c8ce863846e12044d757254fb5f32f2..461febcc03a8969b29f6a6d6e669efb2763ad6b3 100644 (file)
@@ -1559,23 +1559,18 @@ GGML_CALL static bool ggml_backend_cann_cpy_tensor_async(
             return false;
         }
 
+        // need open both directions for memcpyasync between devices.
+        ggml_cann_set_device(cann_ctx_dst->device);
+        ACL_CHECK(aclrtDeviceEnablePeerAccess(cann_ctx_src->device, 0));
         ggml_cann_set_device(cann_ctx_src->device);
         ACL_CHECK(aclrtDeviceEnablePeerAccess(cann_ctx_dst->device, 0));
+
         ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size,
                                    ACL_MEMCPY_DEVICE_TO_DEVICE,
-                                   cann_ctx_dst->stream()));
-
-        // record event on src stream
-        if (!cann_ctx_src->copy_event) {
-            ACL_CHECK(aclrtCreateEvent(&cann_ctx_src->copy_event));
-        }
-
-        ACL_CHECK(
-            aclrtRecordEvent(cann_ctx_src->copy_event, cann_ctx_src->stream()));
+                                   cann_ctx_src->stream()));
 
-        // wait on dst stream for the copy to complete
-        ACL_CHECK(aclrtStreamWaitEvent(cann_ctx_dst->stream(),
-                                       cann_ctx_src->copy_event));
+        //TODO: workaround for Event didn`t work here.
+        aclrtSynchronizeStream(cann_ctx_src->stream());
     } else {
         // src and dst are on the same backend
         ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size,
@@ -1763,8 +1758,8 @@ static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft) {
  *
  * This function determines whether the CANN backend supports the given backend
  * buffer type by comparing the device context of the backend and buffer type.
- * It returns true if the device associated with the buffer type matches the
- * device associated with the backend.
+ * It returns true if the devices are same between the backend context and
+ * buffer type context.
  *
  * @param backend Pointer to the CANN backend.
  * @param buft Pointer to the backend buffer type to check.
@@ -1773,9 +1768,14 @@ static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft) {
  */
 GGML_CALL static bool ggml_backend_cann_supports_buft(
     ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
-    return buft->iface.get_name == ggml_backend_cann_buffer_type_name;
-
-    GGML_UNUSED(backend);
+    if (ggml_backend_buft_is_cann(buft)) {
+        ggml_backend_cann_context * cann_ctx =
+                        (ggml_backend_cann_context *)backend->context;
+        ggml_backend_cann_buffer_type_context * buft_ctx =
+                        (ggml_backend_cann_buffer_type_context *)buft->context;
+        return buft_ctx->device == cann_ctx->device;
+    }
+    return false;
 }
 
 /**