]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
Add oneDNN primitive support (llama/9091)
authorluoyu-intel <redacted>
Thu, 22 Aug 2024 04:50:10 +0000 (12:50 +0800)
committerGeorgi Gerganov <redacted>
Tue, 27 Aug 2024 19:01:14 +0000 (22:01 +0300)
* add onednn

* add sycl_f16

* add dnnl stream

* add engine map

* use dnnl for intel only

* use fp16fp16fp16

* update doc

src/CMakeLists.txt
src/ggml-sycl.cpp
src/ggml-sycl/common.hpp
src/ggml-sycl/gemm.hpp [new file with mode: 0644]

index 425a2589502eba4874ddc69ebb5bf63e9240b279..3d5cb785f60d6d3d986164ddf26dd1fe73059887 100644 (file)
@@ -549,6 +549,13 @@ if (GGML_SYCL)
     file(GLOB   GGML_SOURCES_SYCL "ggml-sycl/*.cpp")
     list(APPEND GGML_SOURCES_SYCL "ggml-sycl.cpp")
 
+    find_package(DNNL)
+    message("-- DNNL found:"${DNNL_FOUND})
+    if (GGML_SYCL_TARGET STREQUAL "INTEL")
+        add_compile_definitions(GGML_SYCL_DNNL=${DNNL_FOUND})
+    else()
+        add_compile_definitions(GGML_SYCL_DNNL=0)
+    endif()
     if (WIN32)
         find_package(IntelSYCL REQUIRED)
         find_package(MKL REQUIRED)
@@ -561,6 +568,9 @@ if (GGML_SYCL)
             set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} -fsycl pthread m dl onemkl)
         endif()
     endif()
+    if (${DNNL_FOUND} AND GGML_SYCL_TARGET STREQUAL "INTEL")
+        list(APPEND GGML_EXTRA_LIBS DNNL::dnnl)
+    endif()
 endif()
 
 if (GGML_RPC)
index 94cd4b11052e1d0454400dc84aa1641d64d87a78..0d884f89a4e7bac03c179d14f2cb28b3712f3a53 100644 (file)
@@ -38,6 +38,7 @@
 
 #include "ggml-sycl/backend.hpp"
 #include "ggml-sycl/presets.hpp"
+#include "ggml-sycl/gemm.hpp"
 
 bool   ggml_sycl_loaded(void);
 void   ggml_sycl_free_data(struct ggml_tensor * tensor);
@@ -2482,6 +2483,7 @@ inline void ggml_sycl_op_mul_mat_sycl(
 
         const sycl::half alpha_f16 = 1.0f;
         const sycl::half beta_f16 = 0.0f;
+#if !GGML_SYCL_DNNL
         SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
             *stream, oneapi::mkl::transpose::trans,
             oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
@@ -2491,6 +2493,13 @@ inline void ggml_sycl_op_mul_mat_sycl(
             dpct::library_data_t::real_half)));
         const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
         to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
+#else
+        auto dnnl_stream = ctx.stream_dnnl(stream);
+        DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
+            src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(), dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>());
+        const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
+        to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
+#endif
     }
     else {
         // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat_sycl - fp32 path\n");
@@ -2513,13 +2522,18 @@ inline void ggml_sycl_op_mul_mat_sycl(
 
         const float alpha = 1.0f;
         const float beta = 0.0f;
-
+#if !GGML_SYCL_DNNL
         SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
             *stream, oneapi::mkl::transpose::trans,
             oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
             dpct::get_value(&alpha, *stream), src0_ddf_i, ne00,
             src1_ddf1_i, ne10, dpct::get_value(&beta, *stream),
             dst_dd_i, ldc)));
+#else
+        auto dnnl_stream = ctx.stream_dnnl(stream);
+         DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i, DnnlGemmWrapper::to_dt<float>(),
+            src0_ddf_i, DnnlGemmWrapper::to_dt<float>(), dst_dd_i, DnnlGemmWrapper::to_dt<float>());
+#endif
     }
     (void) dst;
     (void) src1_ddq_i;
index 78cd682ad4057019ddfde901236af1e0b4e1019c..05947ccb746f26d6ab8f51962e4e8871e3da0efc 100644 (file)
 #include "dpct/helper.hpp"
 #include "ggml-sycl.h"
 #include "presets.hpp"
+#if GGML_SYCL_DNNL
+#include "dnnl.hpp"
+#include "dnnl_sycl.hpp"
+#endif
 
 #define GGML_COMMON_DECL_SYCL
 #define GGML_COMMON_IMPL_SYCL
@@ -277,6 +281,52 @@ struct ggml_backend_sycl_context {
         return stream(device, 0);
     }
 
+#if GGML_SYCL_DNNL
+    dnnl::engine make_engine(sycl::queue* q) {
+        // Get the device associated with the queue
+        sycl::device dev = q->get_device();
+        // Get the context associated with the queue
+        sycl::context ctx = q->get_context();
+        const dnnl::engine eng = dnnl::sycl_interop::make_engine(dev, ctx);
+        return eng;
+    }
+
+    std::unordered_map<sycl::queue*, dnnl::stream> stream_map;
+    std::unordered_map<sycl::queue*, dnnl::engine> engine_map;
+    dnnl::stream stream_dnnl(int device, int _stream) {
+        auto q = stream(device, _stream);
+        return stream_dnnl(q);
+    }
+    dnnl::engine engine_dnnl(sycl::queue* qptr) {
+        auto it = engine_map.find(qptr);
+        if (it == engine_map.end()) {
+            auto eng = make_engine(qptr);
+            engine_map[qptr] = eng;
+            return eng;
+        }
+        else
+        {
+            return it->second;
+        }
+    }
+    dnnl::stream stream_dnnl(sycl::queue* qptr) {
+        auto it = stream_map.find(qptr);
+        if (it == stream_map.end()) {
+            auto eng = engine_dnnl(qptr);
+            auto stream = dnnl::sycl_interop::make_stream(eng, *qptr);
+            stream_map[qptr] = stream;
+            return stream;
+        }
+        else
+        {
+            return it->second;
+        }
+    }
+    dnnl::stream stream_dnnl() {
+        return stream_dnnl(device, 0);
+    }
+#endif
+
     // pool
     std::unique_ptr<ggml_sycl_pool> pools[GGML_SYCL_MAX_DEVICES];
 
diff --git a/src/ggml-sycl/gemm.hpp b/src/ggml-sycl/gemm.hpp
new file mode 100644 (file)
index 0000000..2ad9b36
--- /dev/null
@@ -0,0 +1,101 @@
+//
+// MIT license
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+#ifndef GGML_SYCL_GEMM_HPP
+#define GGML_SYCL_GEMM_HPP
+
+#include <fstream>
+#include <iostream>
+
+#include "ggml-sycl.h"
+
+#if GGML_SYCL_DNNL
+
+#include "dnnl.hpp"
+#include "dnnl_sycl.hpp"
+
+class DnnlGemmWrapper {
+public:
+    using dt = dnnl::memory::data_type;
+    using tag = dnnl::memory::format_tag;
+
+    template<typename T>
+    static constexpr dt to_dt() {
+        if constexpr (std::is_same_v<T, float>) return dt::f32;
+        else if constexpr (std::is_same_v<T, sycl::half>) return dt::f16;
+        else static_assert(0);
+    }
+
+    static inline void row_gemm(sycl::queue& q, bool a_trans,
+        bool b_trans, int m, int n, int k,
+        const void* a, dt at, const void* b, dt bt, void* c, dt ct)
+    {
+        // Get the device associated with the queue
+        sycl::device dev = q.get_device();
+        // Get the context associated with the queue
+        sycl::context ctx = q.get_context();
+        const dnnl::engine eng = dnnl::sycl_interop::make_engine(dev, ctx);
+        const dnnl::stream stream = dnnl::sycl_interop::make_stream(eng, q);
+        dnnl::memory::dims a_dims = { m, k };
+        dnnl::memory::dims b_dims = { k, n };
+        dnnl::memory::dims c_dims = { m, n };
+        const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
+        const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
+        const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
+        auto a_mem = dnnl::memory(a_in_md, eng, (void*)a);
+        auto b_mem = dnnl::memory(b_in_md, eng, (void*)b);
+        auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md);
+        auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c);
+
+        // Create the primitive.
+        auto matmul_prim = dnnl::matmul(matmul_pd);
+        // Primitive arguments.
+        std::unordered_map<int, dnnl::memory> matmul_args;
+        matmul_args.insert({ DNNL_ARG_SRC, a_mem });
+        matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem });
+        matmul_args.insert({ DNNL_ARG_DST, c_mem });
+
+        matmul_prim.execute(stream, matmul_args);
+    }
+
+
+    static inline void row_gemm(const dnnl::stream& stream, bool a_trans,
+        bool b_trans, int m, int n, int k,
+        const void* a, dt at, const void* b, dt bt, void* c, dt ct)
+    {
+        auto const eng = stream.get_engine();
+        dnnl::memory::dims a_dims = { m, k };
+        dnnl::memory::dims b_dims = { k, n };
+        dnnl::memory::dims c_dims = { m, n };
+        const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
+        const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
+        const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
+        auto a_mem = dnnl::memory(a_in_md, eng, (void*)a);
+        auto b_mem = dnnl::memory(b_in_md, eng, (void*)b);
+        auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md);
+        auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c);
+
+        // Create the primitive.
+        auto matmul_prim = dnnl::matmul(matmul_pd);
+        // Primitive arguments.
+        std::unordered_map<int, dnnl::memory> matmul_args;
+        matmul_args.insert({ DNNL_ARG_SRC, a_mem });
+        matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem });
+        matmul_args.insert({ DNNL_ARG_DST, c_mem });
+
+        matmul_prim.execute(stream, matmul_args);
+    }
+};
+
+#endif
+
+#endif // GGML_SYCL_GEMM_HPP