]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
sycl: cleanup oneDNN related code (llama/12097)
authorSvetlozar Georgiev <redacted>
Fri, 21 Mar 2025 02:15:56 +0000 (02:15 +0000)
committerGeorgi Gerganov <redacted>
Thu, 27 Mar 2025 09:06:03 +0000 (11:06 +0200)
ggml/src/ggml-sycl/CMakeLists.txt
ggml/src/ggml-sycl/common.hpp
ggml/src/ggml-sycl/gemm.hpp
ggml/src/ggml-sycl/ggml-sycl.cpp

index 271413ca414bfaa362ae565ac6be91f107b94c07..f713fbe46e0125ed1cfe63300e74060ab356fe55 100644 (file)
@@ -23,6 +23,38 @@ ggml_add_backend_library(ggml-sycl
                          ../../include/ggml-sycl.h
                         )
 
+find_package(DNNL)
+set(GGML_SYCL_DNNL 0)
+if(DNNL_FOUND)
+    if (DEFINED ENV{ONEAPI_ROOT} AND NOT DEFINED DNNL_GPU_VENDOR)
+        # Assuming oneDNN packaged with oneapi release is used which
+        # supports only intel target
+        set(DNNL_GPU_VENDOR "INTEL")
+        if(NOT "${GGML_SYCL_TARGET}" STREQUAL "INTEL")
+            message(WARNING "oneDNN builds bundled with oneapi release only support INTEL target")
+        endif()
+    endif()
+
+    # Verify oneDNN was compiled for the same target as llama
+    if("${GGML_SYCL_TARGET}" STREQUAL "${DNNL_GPU_VENDOR}")
+        target_link_libraries(ggml-sycl PRIVATE DNNL::dnnl)
+        set(GGML_SYCL_DNNL 1)
+        get_target_property(CONFIGS DNNL::dnnl IMPORTED_CONFIGURATIONS)
+        foreach(CONFIG ${CONFIGS})
+            get_target_property(DNNL_LIB DNNL::dnnl IMPORTED_LOCATION_${CONFIG})
+            message(STATUS "Found oneDNN: ${DNNL_LIB}")
+        endforeach()
+    else()
+        message(WARNING
+            "oneDNN must be compiled for the same target as llama.cpp.
+             llama.cpp: ${GGML_SYCL_TARGET}, oneDNN: ${DNNL_GPU_VENDOR}.
+             Disabling oneDNN support.")
+    endif()
+else()
+    message(STATUS "oneDNN not found, disabling oneDNN support")
+endif()
+target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_DNNL=${GGML_SYCL_DNNL})
+
 if (GGML_SYCL_F16)
     if (GGML_SYCL_TARGET STREQUAL "AMD")
         message(WARNING "AMD target does not entirely support FP16 in the SYCL backend.")
@@ -48,18 +80,6 @@ file(GLOB   GGML_HEADERS_SYCL "*.hpp")
 file(GLOB   GGML_SOURCES_SYCL "*.cpp")
 target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL} ${GGML_SOURCES_SYCL})
 
-find_package(DNNL)
-message("-- DNNL found:" ${DNNL_FOUND})
-
-if (GGML_SYCL_TARGET STREQUAL "INTEL")
-    add_compile_definitions(GGML_SYCL_DNNL=${DNNL_FOUND})
-else()
-    add_compile_definitions(GGML_SYCL_DNNL=0)
-endif()
-
-if (${DNNL_FOUND} AND GGML_SYCL_TARGET STREQUAL "INTEL")
-    target_link_libraries(ggml-sycl PRIVATE DNNL::dnnl)
-endif()
 
 if (WIN32)
     find_package(IntelSYCL REQUIRED)
index 7cc5e14f9ab225b55802e10cbaa7eea51dae9578..27b447ce30d18036ef921b03f046527a599add9c 100644 (file)
@@ -170,7 +170,6 @@ static size_t g_scratch_offset = 0;
 int get_current_device_id();
 
 inline dpct::err0 ggml_sycl_set_device(const int device) try {
-
   int current_device_id;
   SYCL_CHECK(CHECK_TRY_ERROR(current_device_id = get_current_device_id()));
 
@@ -242,6 +241,14 @@ struct ggml_sycl_pool_alloc {
         }
     }
 
+    T * realloc(size_t size) {
+        GGML_ASSERT(pool != nullptr);
+        if (ptr)
+            pool->free(ptr, actual_size);
+        ptr = (T *) pool->alloc(size * sizeof(T), &this->actual_size);
+        return ptr;
+    }
+
     // size is in number of elements
     T * alloc(size_t size) {
         GGML_ASSERT(pool != nullptr);
@@ -371,10 +378,29 @@ struct ggml_backend_sycl_context {
     dnnl::stream stream_dnnl() {
         return stream_dnnl(device, 0);
     }
+    dnnl::memory get_scratchpad_mem(const dnnl::memory::desc & scratchpad_md,
+                                    const dnnl::engine & eng, const queue_ptr q) {
+        ggml_sycl_pool_alloc<uint8_t> * pool;
+        auto it = scratchpad_map.find(q);
+        if (it == scratchpad_map.end()) {
+            scratchpad_map[q] = std::make_unique<ggml_sycl_pool_alloc<uint8_t>>(this->pool());
+            pool = scratchpad_map[q].get();
+        } else {
+            pool = it->second.get();
+        }
+
+        size_t scratchpad_size = scratchpad_md.get_size();
+        if (scratchpad_size > pool->actual_size) {
+            pool->realloc(scratchpad_size);
+        }
+        void * mem_ptr = pool->get();
+        return dnnl::memory(scratchpad_md, eng, mem_ptr);
+    }
 #endif
 
     // pool
     std::unique_ptr<ggml_sycl_pool> pools[GGML_SYCL_MAX_DEVICES];
+    std::unordered_map<sycl::queue *, std::unique_ptr<ggml_sycl_pool_alloc<uint8_t>>> scratchpad_map;
 
     std::unique_ptr<ggml_sycl_pool> host_pools[GGML_SYCL_MAX_DEVICES];
 
index 3f0f34ad603f59fe9aa10b8905045114a1cfe249..4ebbb5b66fb47a4134d0f56380caa281ac2489ed 100644 (file)
@@ -13,9 +13,6 @@
 #ifndef GGML_SYCL_GEMM_HPP
 #define GGML_SYCL_GEMM_HPP
 
-#include <fstream>
-#include <iostream>
-
 #include "ggml-sycl.h"
 
 #if GGML_SYCL_DNNL
@@ -35,62 +32,34 @@ public:
         else static_assert(0);
     }
 
-    static inline void row_gemm(sycl::queue& q, bool a_trans,
-        bool b_trans, int m, int n, int k,
-        const void* a, dt at, const void* b, dt bt, void* c, dt ct)
-    {
-        // Get the device associated with the queue
-        sycl::device dev = q.get_device();
-        // Get the context associated with the queue
-        sycl::context ctx = q.get_context();
-        const dnnl::engine eng = dnnl::sycl_interop::make_engine(dev, ctx);
-        const dnnl::stream stream = dnnl::sycl_interop::make_stream(eng, q);
+    static inline void row_gemm(ggml_backend_sycl_context & ctx, bool a_trans, bool b_trans, int m, int n, int k,
+                                const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) {
+        auto stream = ctx.stream_dnnl(q);
+        auto eng = ctx.engine_dnnl(q);
         dnnl::memory::dims a_dims = { m, k };
         dnnl::memory::dims b_dims = { k, n };
         dnnl::memory::dims c_dims = { m, n };
         const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
         const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
-        const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
-        auto a_mem = dnnl::memory(a_in_md, eng, const_cast<void*>(a));
-        auto b_mem = dnnl::memory(b_in_md, eng, const_cast<void*>(b));
-        auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md);
-        auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c);
+        const auto c_md    = dnnl::memory::desc(c_dims, ct, tag::ab);
 
-        // Create the primitive.
-        auto matmul_prim = dnnl::matmul(matmul_pd);
-        // Primitive arguments.
-        std::unordered_map<int, dnnl::memory> matmul_args;
-        matmul_args.insert({ DNNL_ARG_SRC, a_mem });
-        matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem });
-        matmul_args.insert({ DNNL_ARG_DST, c_mem });
+        dnnl::primitive_attr primitive_attr;
+        primitive_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
 
-        matmul_prim.execute(stream, matmul_args);
-    }
-
-
-    static inline void row_gemm(const dnnl::stream& stream, bool a_trans,
-        bool b_trans, int m, int n, int k,
-        const void* a, dt at, const void* b, dt bt, void* c, dt ct)
-    {
-        auto const eng = stream.get_engine();
-        dnnl::memory::dims a_dims = { m, k };
-        dnnl::memory::dims b_dims = { k, n };
-        dnnl::memory::dims c_dims = { m, n };
-        const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
-        const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
-        const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
         auto a_mem = dnnl::memory(a_in_md, eng, const_cast<void*>(a));
         auto b_mem = dnnl::memory(b_in_md, eng, const_cast<void*>(b));
-        auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md);
+        auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md, primitive_attr);
         auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c);
 
-        // Create the primitive.
+        auto scratchpad_md = matmul_pd.scratchpad_desc();
+        auto scratchpad_mem = ctx.get_scratchpad_mem(scratchpad_md, eng, q);
         auto matmul_prim = dnnl::matmul(matmul_pd);
-        // Primitive arguments.
+
         std::unordered_map<int, dnnl::memory> matmul_args;
         matmul_args.insert({ DNNL_ARG_SRC, a_mem });
         matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem });
         matmul_args.insert({ DNNL_ARG_DST, c_mem });
+        matmul_args.insert({ DNNL_ARG_SCRATCHPAD, scratchpad_mem });
 
         matmul_prim.execute(stream, matmul_args);
     }
index 360e3f166c218176a22d44024c121939307ae097..f4b68333e059b473b426681c158d478768419e3d 100644 (file)
@@ -2058,9 +2058,9 @@ inline void ggml_sycl_op_mul_mat_sycl(
         const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
         to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
 #else
-        auto dnnl_stream = ctx.stream_dnnl(stream);
-        DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
-            src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(), dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>());
+        DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ptr,
+                                  DnnlGemmWrapper::to_dt<sycl::half>(), src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
+                                  dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>(), stream);
         const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
         to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
 #endif
@@ -2099,9 +2099,9 @@ inline void ggml_sycl_op_mul_mat_sycl(
             dst_dd_i, ldc)));
 #    endif
 #else
-        auto dnnl_stream = ctx.stream_dnnl(stream);
-         DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i, DnnlGemmWrapper::to_dt<float>(),
-            src0_ddf_i, DnnlGemmWrapper::to_dt<float>(), dst_dd_i, DnnlGemmWrapper::to_dt<float>());
+        DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i,
+                                  DnnlGemmWrapper::to_dt<float>(), src0_ddf_i, DnnlGemmWrapper::to_dt<float>(),
+                                  dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
 #endif
     }
     GGML_UNUSED(dst);