]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
use the correct SYCL context for host USM allocations (llama/7777)
authorBen Ashbaugh <redacted>
Mon, 10 Jun 2024 09:21:31 +0000 (02:21 -0700)
committerGeorgi Gerganov <redacted>
Sat, 15 Jun 2024 19:05:47 +0000 (22:05 +0300)
Signed-off-by: Ben Ashbaugh <redacted>
src/ggml-sycl.cpp

index 0a645b2e1db8023d8e8ad47fb50dc92919a40d86..42fc0df203537b03eb32c63a6025199450c81691 100644 (file)
@@ -13089,10 +13089,12 @@ void *ggml_sycl_host_malloc(size_t size) try {
         return nullptr;
     }
 
+    ggml_sycl_set_device(g_main_device);
+    dpct::queue_ptr main_stream = g_syclStreams[g_main_device][0];
+
     void * ptr = nullptr;
-    //allow to use dpct::get_in_order_queue() for host malloc
     dpct::err0 err = CHECK_TRY_ERROR(
-        ptr = (void *)sycl::malloc_host(size, dpct::get_in_order_queue()));
+        ptr = (void *)sycl::malloc_host(size, *main_stream));
 
     if (err != 0) {
         // clear the error
@@ -13113,8 +13115,9 @@ catch (sycl::exception const &exc) {
 }
 
 void ggml_sycl_host_free(void *ptr) try {
-    //allow to use dpct::get_in_order_queue() for host malloc
-    SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(ptr, dpct::get_in_order_queue())));
+    ggml_sycl_set_device(g_main_device);
+    dpct::queue_ptr main_stream = g_syclStreams[g_main_device][0];
+    SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(ptr, *main_stream)));
 }
 catch (sycl::exception const &exc) {
   std::cerr << exc.what() << "Exception caught at file:" << __FILE__