#include "ggml-sycl/backend.hpp"
#include "ggml-sycl/presets.hpp"
+#include "ggml-sycl/gemm.hpp"
bool ggml_sycl_loaded(void);
void ggml_sycl_free_data(struct ggml_tensor * tensor);
const sycl::half alpha_f16 = 1.0f;
const sycl::half beta_f16 = 0.0f;
+#if !GGML_SYCL_DNNL
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
*stream, oneapi::mkl::transpose::trans,
oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
dpct::library_data_t::real_half)));
const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
+#else
+ auto dnnl_stream = ctx.stream_dnnl(stream);
+ DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
+ src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(), dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>());
+ const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
+ to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
+#endif
}
else {
// GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat_sycl - fp32 path\n");
const float alpha = 1.0f;
const float beta = 0.0f;
-
+#if !GGML_SYCL_DNNL
SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
*stream, oneapi::mkl::transpose::trans,
oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
dpct::get_value(&alpha, *stream), src0_ddf_i, ne00,
src1_ddf1_i, ne10, dpct::get_value(&beta, *stream),
dst_dd_i, ldc)));
+#else
+ auto dnnl_stream = ctx.stream_dnnl(stream);
+ DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i, DnnlGemmWrapper::to_dt<float>(),
+ src0_ddf_i, DnnlGemmWrapper::to_dt<float>(), dst_dd_i, DnnlGemmWrapper::to_dt<float>());
+#endif
}
(void) dst;
(void) src1_ddq_i;
#include "dpct/helper.hpp"
#include "ggml-sycl.h"
#include "presets.hpp"
+#if GGML_SYCL_DNNL
+#include "dnnl.hpp"
+#include "dnnl_sycl.hpp"
+#endif
#define GGML_COMMON_DECL_SYCL
#define GGML_COMMON_IMPL_SYCL
return stream(device, 0);
}
+#if GGML_SYCL_DNNL
+ dnnl::engine make_engine(sycl::queue* q) {
+ // Get the device associated with the queue
+ sycl::device dev = q->get_device();
+ // Get the context associated with the queue
+ sycl::context ctx = q->get_context();
+ const dnnl::engine eng = dnnl::sycl_interop::make_engine(dev, ctx);
+ return eng;
+ }
+
    // Lazily-populated caches keyed by raw SYCL queue pointer, so each queue
    // creates its oneDNN interop stream/engine only once and reuses it after.
    // NOTE(review): assumes the queues outlive this context — confirm lifetime.
+ std::unordered_map<sycl::queue*, dnnl::stream> stream_map;
+ std::unordered_map<sycl::queue*, dnnl::engine> engine_map;
+ dnnl::stream stream_dnnl(int device, int _stream) {
+ auto q = stream(device, _stream);
+ return stream_dnnl(q);
+ }
+ dnnl::engine engine_dnnl(sycl::queue* qptr) {
+ auto it = engine_map.find(qptr);
+ if (it == engine_map.end()) {
+ auto eng = make_engine(qptr);
+ engine_map[qptr] = eng;
+ return eng;
+ }
+ else
+ {
+ return it->second;
+ }
+ }
+ dnnl::stream stream_dnnl(sycl::queue* qptr) {
+ auto it = stream_map.find(qptr);
+ if (it == stream_map.end()) {
+ auto eng = engine_dnnl(qptr);
+ auto stream = dnnl::sycl_interop::make_stream(eng, *qptr);
+ stream_map[qptr] = stream;
+ return stream;
+ }
+ else
+ {
+ return it->second;
+ }
+ }
+ dnnl::stream stream_dnnl() {
+ return stream_dnnl(device, 0);
+ }
+#endif
+
// pool
std::unique_ptr<ggml_sycl_pool> pools[GGML_SYCL_MAX_DEVICES];
--- /dev/null
+//
+// MIT license
+// Copyright (C) 2024 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+#ifndef GGML_SYCL_GEMM_HPP
+#define GGML_SYCL_GEMM_HPP
+
+#include <fstream>
+#include <iostream>
+
+#include "ggml-sycl.h"
+
+#if GGML_SYCL_DNNL
+
+#include "dnnl.hpp"
+#include "dnnl_sycl.hpp"
+
+class DnnlGemmWrapper {
+public:
+ using dt = dnnl::memory::data_type;
+ using tag = dnnl::memory::format_tag;
+
+ template<typename T>
+ static constexpr dt to_dt() {
+ if constexpr (std::is_same_v<T, float>) return dt::f32;
+ else if constexpr (std::is_same_v<T, sycl::half>) return dt::f16;
+ else static_assert(0);
+ }
+
+ static inline void row_gemm(sycl::queue& q, bool a_trans,
+ bool b_trans, int m, int n, int k,
+ const void* a, dt at, const void* b, dt bt, void* c, dt ct)
+ {
+ // Get the device associated with the queue
+ sycl::device dev = q.get_device();
+ // Get the context associated with the queue
+ sycl::context ctx = q.get_context();
+ const dnnl::engine eng = dnnl::sycl_interop::make_engine(dev, ctx);
+ const dnnl::stream stream = dnnl::sycl_interop::make_stream(eng, q);
+ dnnl::memory::dims a_dims = { m, k };
+ dnnl::memory::dims b_dims = { k, n };
+ dnnl::memory::dims c_dims = { m, n };
+ const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
+ const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
+ const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
+ auto a_mem = dnnl::memory(a_in_md, eng, (void*)a);
+ auto b_mem = dnnl::memory(b_in_md, eng, (void*)b);
+ auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md);
+ auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c);
+
+ // Create the primitive.
+ auto matmul_prim = dnnl::matmul(matmul_pd);
+ // Primitive arguments.
+ std::unordered_map<int, dnnl::memory> matmul_args;
+ matmul_args.insert({ DNNL_ARG_SRC, a_mem });
+ matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem });
+ matmul_args.insert({ DNNL_ARG_DST, c_mem });
+
+ matmul_prim.execute(stream, matmul_args);
+ }
+
+
+ static inline void row_gemm(const dnnl::stream& stream, bool a_trans,
+ bool b_trans, int m, int n, int k,
+ const void* a, dt at, const void* b, dt bt, void* c, dt ct)
+ {
+ auto const eng = stream.get_engine();
+ dnnl::memory::dims a_dims = { m, k };
+ dnnl::memory::dims b_dims = { k, n };
+ dnnl::memory::dims c_dims = { m, n };
+ const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
+ const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
+ const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
+ auto a_mem = dnnl::memory(a_in_md, eng, (void*)a);
+ auto b_mem = dnnl::memory(b_in_md, eng, (void*)b);
+ auto matmul_pd = dnnl::matmul::primitive_desc(eng, a_in_md, b_in_md, c_md);
+ auto c_mem = dnnl::memory(matmul_pd.dst_desc(), eng, c);
+
+ // Create the primitive.
+ auto matmul_prim = dnnl::matmul(matmul_pd);
+ // Primitive arguments.
+ std::unordered_map<int, dnnl::memory> matmul_args;
+ matmul_args.insert({ DNNL_ARG_SRC, a_mem });
+ matmul_args.insert({ DNNL_ARG_WEIGHTS, b_mem });
+ matmul_args.insert({ DNNL_ARG_DST, c_mem });
+
+ matmul_prim.execute(stream, matmul_args);
+ }
+};
+
+#endif
+
+#endif // GGML_SYCL_GEMM_HPP