]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
SYCL: Rename oneMKL to oneMath (llama/12192)
authorRomain Biessy <redacted>
Tue, 1 Apr 2025 08:24:29 +0000 (10:24 +0200)
committerGeorgi Gerganov <redacted>
Wed, 2 Apr 2025 12:51:57 +0000 (15:51 +0300)
* Rename oneMKL Interface to oneMath

* Use oneMath for Intel vendor

* Rename occurences to mkl

* clang-format

* Silence verbose warnings

* Set oneMath HIP_TARGETS

* Fix silence warnings

* Remove step to build oneMath from build instructions

* Use fixed oneMath version

* Remove INTEL_CPU

* Fold CMake oneDNN conditions

* Use Intel oneMKL for Intel devices

* Improve CMake message

* Link against MKL::MKL_SYCL::BLAS only

* Move oneMath documentation to Nvidia and AMD sections

ggml/src/ggml-sycl/CMakeLists.txt
ggml/src/ggml-sycl/dpct/helper.hpp
ggml/src/ggml-sycl/ggml-sycl.cpp
ggml/src/ggml-sycl/outprod.cpp

index f713fbe46e0125ed1cfe63300e74060ab356fe55..6747fd88361f7d0b1e01a9c0da300740819a0dff 100644 (file)
@@ -23,6 +23,23 @@ ggml_add_backend_library(ggml-sycl
                          ../../include/ggml-sycl.h
                         )
 
+file(GLOB   GGML_HEADERS_SYCL "*.hpp")
+file(GLOB   GGML_SOURCES_SYCL "*.cpp")
+target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL} ${GGML_SOURCES_SYCL})
+
+find_package(IntelSYCL)
+if (IntelSYCL_FOUND)
+    # Use oneAPI CMake when possible
+    target_link_libraries(ggml-sycl PRIVATE IntelSYCL::SYCL_CXX)
+else()
+    # Fallback to the simplest way of enabling SYCL when using intel/llvm nightly for instance
+    target_compile_options(ggml-sycl PRIVATE "-fsycl")
+    target_link_options(ggml-sycl PRIVATE "-fsycl")
+endif()
+
+target_compile_options(ggml-sycl PRIVATE "-Wno-narrowing")
+
+# Link against oneDNN
 find_package(DNNL)
 set(GGML_SYCL_DNNL 0)
 if(DNNL_FOUND)
@@ -62,8 +79,6 @@ if (GGML_SYCL_F16)
     add_compile_definitions(GGML_SYCL_F16)
 endif()
 
-set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-narrowing -fsycl")
-
 if (GGML_SYCL_TARGET STREQUAL "NVIDIA")
     add_compile_definitions(GGML_SYCL_WARP_SIZE=32)
 elseif (GGML_SYCL_TARGET STREQUAL "AMD")
@@ -76,34 +91,84 @@ else()
     add_compile_definitions(GGML_SYCL_WARP_SIZE=16)
 endif()
 
-file(GLOB   GGML_HEADERS_SYCL "*.hpp")
-file(GLOB   GGML_SOURCES_SYCL "*.cpp")
-target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL} ${GGML_SOURCES_SYCL})
-
+if (GGML_SYCL_GRAPH)
+    target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_GRAPH)
+endif()
 
-if (WIN32)
-    find_package(IntelSYCL REQUIRED)
+# Link against Intel oneMKL or oneMath
+if (GGML_SYCL_TARGET STREQUAL "INTEL")
+    # Intel devices use Intel oneMKL directly instead of oneMath to avoid the limitation of linking Intel oneMKL statically
+    # See https://github.com/uxlfoundation/oneMath/issues/654
     find_package(MKL REQUIRED)
-    target_link_libraries(ggml-sycl PRIVATE IntelSYCL::SYCL_CXX MKL::MKL MKL::MKL_SYCL)
+    target_link_libraries(ggml-sycl PRIVATE MKL::MKL_SYCL::BLAS)
+    target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_USE_INTEL_ONEMKL)
 else()
-    if (GGML_SYCL_GRAPH)
-        add_compile_definitions(GGML_SYCL_GRAPH)
+    find_package(oneMath QUIET)
+    if (NOT oneMath_FOUND)
+        message(STATUS "oneMath not found: oneMath will be automatically downloaded")
+        # Use FetchContent to automatically pull and build oneMath
+        include(FetchContent)
+        set(BUILD_FUNCTIONAL_TESTS False)
+        set(BUILD_EXAMPLES False)
+        set(TARGET_DOMAINS blas)
+        if (GGML_SYCL_TARGET STREQUAL "NVIDIA")
+            set(ENABLE_MKLCPU_BACKEND False)
+            set(ENABLE_MKLGPU_BACKEND False)
+            set(ENABLE_CUBLAS_BACKEND True)
+        elseif (GGML_SYCL_TARGET STREQUAL "AMD")
+            set(ENABLE_MKLCPU_BACKEND False)
+            set(ENABLE_MKLGPU_BACKEND False)
+            set(ENABLE_ROCBLAS_BACKEND True)
+            # Ensure setting a string variable here is not overriden by oneMath CACHE variables
+            cmake_policy(SET CMP0126 NEW)
+            # Setting the device architecture is only needed and useful for AMD devices in oneMath
+            set(HIP_TARGETS ${GGML_SYCL_DEVICE_ARCH} CACHE STRING "oneMath HIP target" FORCE)
+        endif()
+        FetchContent_Declare(
+            ONEMATH
+            GIT_REPOSITORY https://github.com/uxlfoundation/oneMath.git
+            GIT_TAG c255b1b4c41e2ee3059455c1f96a965d6a62568a
+        )
+        FetchContent_MakeAvailable(ONEMATH)
+        # Create alias to match with find_package targets name
+        function(onemath_alias target)
+            if (TARGET ${target}_obj)
+                # Silence verbose warnings from external libraries
+                target_compile_options(${target}_obj PRIVATE -w)
+            endif()
+            if (TARGET ${target})
+                add_library(ONEMATH::${target} ALIAS ${target})
+            endif()
+        endfunction()
+        onemath_alias(onemath)
+        onemath_alias(onemath_blas_mklcpu)
+        onemath_alias(onemath_blas_mklgpu)
+        onemath_alias(onemath_blas_cublas)
+        onemath_alias(onemath_blas_rocblas)
     endif()
-    if (GGML_SYCL_TARGET STREQUAL "INTEL")
-        target_link_libraries(ggml-sycl PRIVATE sycl OpenCL mkl_core pthread m dl mkl_sycl_blas mkl_intel_ilp64 mkl_tbb_thread)
-    elseif (GGML_SYCL_TARGET STREQUAL "NVIDIA")
-        set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl-targets=nvptx64-nvidia-cuda")
-        add_compile_definitions(GGML_SYCL_NVIDIA)
-        target_link_libraries(ggml-sycl PRIVATE sycl pthread m dl onemkl_blas_cublas)
+
+    # Below oneMath compile-time dispatching is used for better performance
+    if (GGML_SYCL_TARGET STREQUAL "NVIDIA")
+        target_link_libraries(ggml-sycl PRIVATE ONEMATH::onemath_blas_cublas)
+        target_compile_options(ggml-sycl PRIVATE "-fsycl-targets=nvptx64-nvidia-cuda")
+        target_link_options(ggml-sycl PRIVATE "-fsycl-targets=nvptx64-nvidia-cuda")
+        target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_NVIDIA)
     elseif (GGML_SYCL_TARGET STREQUAL "AMD")
         if (NOT GGML_SYCL_DEVICE_ARCH)
             message(ERROR "Can't enable SYCL hip backend, GGML_SYCL_DEVICE_ARCH has not been set.")
         endif()
-        set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsycl-targets=amdgcn-amd-amdhsa")
-        target_link_libraries(ggml-sycl PRIVATE sycl pthread m dl onemkl)
+        target_link_libraries(ggml-sycl PRIVATE ONEMATH::onemath_blas_rocblas)
+        target_compile_options(ggml-sycl PRIVATE "-fsycl-targets=amdgcn-amd-amdhsa")
+        target_link_options(ggml-sycl PRIVATE "-fsycl-targets=amdgcn-amd-amdhsa")
+        target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_AMD)
+    else()
+        # Fallback to oneMath runtime dispatcher
+        target_link_libraries(ggml-sycl PRIVATE ONEMATH::onemath)
+        target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_GENERIC)
     endif()
+endif()
 
-    if (GGML_SYCL_DEVICE_ARCH)
-      set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Xsycl-target-backend --offload-arch=${GGML_SYCL_DEVICE_ARCH}")
-  endif()
+if (GGML_SYCL_DEVICE_ARCH)
+    target_compile_options(ggml-sycl PRIVATE -Xsycl-target-backend --offload-arch=${GGML_SYCL_DEVICE_ARCH})
+    target_link_options(ggml-sycl PRIVATE -Xsycl-target-backend --offload-arch=${GGML_SYCL_DEVICE_ARCH})
 endif()
index c96395be613126282a6ed425a6aba9400787bd29..d538965b096bf3e14ef483b4e0a2ee466aece834 100644 (file)
 #include <sycl/sycl.hpp>
 #include <sycl/half_type.hpp>
 #include <syclcompat/math.hpp>
-#include <oneapi/mkl.hpp>
 #include <map>
 
+#ifdef GGML_SYCL_USE_INTEL_ONEMKL
+#include <oneapi/mkl.hpp>
+// Allow to use the same namespace for Intel oneMKL and oneMath
+namespace oneapi {
+    namespace math = mkl;
+}
+#else
+#include <oneapi/math.hpp>
+#endif
+
 #include "ggml.h"
 
 #if defined(__linux__)
@@ -83,13 +92,32 @@ inline std::string get_device_backend_and_type(const sycl::device &device) {
 }
 
 template <typename Ts> struct matrix_info_t {
-    oneapi::mkl::transpose transpose_info[2];
+    oneapi::math::transpose transpose_info[2];
     Ts                     value_info[2];
     std::int64_t           size_info[3];
     std::int64_t           ld_info[3];
     std::int64_t           groupsize_info;
 };
 
+inline auto get_onemath_backend(sycl::queue& queue)
+#if defined(GGML_SYCL_GENERIC) || defined(GGML_SYCL_USE_INTEL_ONEMKL)
+  -> sycl::queue&
+#endif
+{
+// If the backend is known at compile-time, use oneMath backend_selector to use
+// compile-time dispatching and avoid the need to dlopen libraries. Otherwise
+// fallback to runtime dispatching.
+#if defined(GGML_SYCL_NVIDIA)
+    return oneapi::math::backend_selector<oneapi::math::backend::cublas>{ queue };
+#elif defined(GGML_SYCL_AMD)
+    return oneapi::math::backend_selector<oneapi::math::backend::rocblas>{ queue };
+#elif defined(GGML_SYCL_GENERIC) || defined(GGML_SYCL_USE_INTEL_ONEMKL)
+    return queue;
+#else
+    static_assert(false, "Unsupported backend");
+#endif
+}
+
 namespace dpct
 {
     typedef sycl::queue *queue_ptr;
@@ -1686,26 +1714,18 @@ namespace dpct
 
     namespace detail
     {
-        template <class Ta, class Tb, class Tc, class Ts>
-        inline void gemm_impl(sycl::queue &q, oneapi::mkl::transpose a_trans,
-                              oneapi::mkl::transpose b_trans, int m, int n, int k,
-                              const void *alpha, const void *a, int lda, const void *b,
-                              int ldb, const void *beta, void *c, int ldc)
-        {
-            Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
-            Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
-            auto data_a = get_memory<const Ta>(a);
-            auto data_b = get_memory<const Tb>(b);
-            auto data_c = get_memory<Tc>(c);
-#ifdef GGML_SYCL_NVIDIA
-            oneapi::mkl::blas::column_major::gemm(oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q },
-                                                  a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb,
-                                                  beta_value, data_c, ldc);
-#else
-            oneapi::mkl::blas::column_major::gemm(q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda, data_b, ldb,
-                                                  beta_value, data_c, ldc);
-#endif
-        }
+    template <class Ta, class Tb, class Tc, class Ts>
+    inline void gemm_impl(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m,
+                          int n, int k, const void * alpha, const void * a, int lda, const void * b, int ldb,
+                          const void * beta, void * c, int ldc) {
+        Ts   alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
+        Ts   beta_value  = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
+        auto data_a      = get_memory<const Ta>(a);
+        auto data_b      = get_memory<const Tb>(b);
+        auto data_c      = get_memory<Tc>(c);
+        oneapi::math::blas::column_major::gemm(get_onemath_backend(q), a_trans, b_trans, m, n, k, alpha_value, data_a,
+                                               lda, data_b, ldb, beta_value, data_c, ldc);
+    }
 
         template <typename VecT, class BinaryOperation, class = void>
         class vectorized_binary
@@ -1735,7 +1755,7 @@ namespace dpct
         };
 
         template <class Ta, class Tb, class Tc, class Ts>
-        inline void gemm_batch_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans,
+        inline void gemm_batch_impl(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans,
                                     int m, int n, int k, const void * alpha, const void ** a, int lda, const void ** b,
                                     int ldb, const void * beta, void ** c, int ldc, int batch_size,
                                     matrix_info_t<float> * matrix_info) {
@@ -1754,48 +1774,28 @@ namespace dpct
             matrix_info->ld_info[2] = ldc;
             matrix_info->groupsize_info = batch_size;
 
-#ifdef GGML_SYCL_NVIDIA
-            sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
-                oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, matrix_info->transpose_info,
-                matrix_info->transpose_info + 1, matrix_info->size_info, matrix_info->size_info + 1,
-                matrix_info->size_info + 2, reinterpret_cast<Ts *>(matrix_info->value_info),
-                reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b),
-                matrix_info->ld_info + 1, reinterpret_cast<Ts *>(matrix_info->value_info + 1),
-                reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
-#else
-            sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
-                q, matrix_info->transpose_info, matrix_info->transpose_info + 1, matrix_info->size_info,
-                matrix_info->size_info + 1, matrix_info->size_info + 2, reinterpret_cast<Ts *>(matrix_info->value_info),
-                reinterpret_cast<const Ta **>(a), matrix_info->ld_info, reinterpret_cast<const Tb **>(b),
-                matrix_info->ld_info + 1, reinterpret_cast<Ts *>(matrix_info->value_info + 1),
-                reinterpret_cast<Tc **>(c), matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
-#endif
+            sycl::event e = oneapi::math::blas::column_major::gemm_batch(
+                get_onemath_backend(q), matrix_info->transpose_info, matrix_info->transpose_info + 1,
+                matrix_info->size_info, matrix_info->size_info + 1, matrix_info->size_info + 2,
+                reinterpret_cast<Ts *>(matrix_info->value_info), reinterpret_cast<const Ta **>(a), matrix_info->ld_info,
+                reinterpret_cast<const Tb **>(b), matrix_info->ld_info + 1,
+                reinterpret_cast<Ts *>(matrix_info->value_info + 1), reinterpret_cast<Tc **>(c),
+                matrix_info->ld_info + 2, 1, &(matrix_info->groupsize_info));
         }
 
         template <class Ta, class Tb, class Tc, class Ts>
-        inline void
-        gemm_batch_impl(sycl::queue &q, oneapi::mkl::transpose a_trans,
-                        oneapi::mkl::transpose b_trans, int m, int n,
-                        int k, const void *alpha, const void *a, int lda,
-                        long long int stride_a, const void *b, int ldb,
-                        long long int stride_b, const void *beta, void *c,
-                        int ldc, long long int stride_c, int batch_size)
-        {
+        inline void gemm_batch_impl(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans,
+                                    int m, int n, int k, const void * alpha, const void * a, int lda,
+                                    long long int stride_a, const void * b, int ldb, long long int stride_b,
+                                    const void * beta, void * c, int ldc, long long int stride_c, int batch_size) {
             Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
             Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
             auto data_a = get_memory<const Ta>(a);
             auto data_b = get_memory<const Tb>(b);
             auto data_c = get_memory<Tc>(c);
-#ifdef GGML_SYCL_NVIDIA
-            oneapi::mkl::blas::column_major::gemm_batch(
-                oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ q }, a_trans, b_trans, m, n, k,
-                alpha_value, data_a, lda, stride_a, data_b, ldb, stride_b, beta_value, data_c, ldc, stride_c,
-                batch_size);
-#else
-            oneapi::mkl::blas::column_major::gemm_batch(q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
-                                                        stride_a, data_b, ldb, stride_b, beta_value, data_c, ldc,
-                                                        stride_c, batch_size);
-#endif
+            oneapi::math::blas::column_major::gemm_batch(get_onemath_backend(q), a_trans, b_trans, m, n, k, alpha_value,
+                                                         data_a, lda, stride_a, data_b, ldb, stride_b, beta_value,
+                                                         data_c, ldc, stride_c, batch_size);
         }
 
     } // namespace detail
@@ -2259,13 +2259,10 @@ namespace dpct
                            sycl::range<3>(x, y, 1), direction);
     }
 
-    inline void gemm(sycl::queue &q, oneapi::mkl::transpose a_trans,
-                     oneapi::mkl::transpose b_trans, int m, int n, int k,
-                     const void *alpha, const void *a, library_data_t a_type,
-                     int lda, const void *b, library_data_t b_type, int ldb,
-                     const void *beta, void *c, library_data_t c_type, int ldc,
-                     library_data_t scaling_type)
-    {
+    inline void gemm(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m, int n,
+                     int k, const void * alpha, const void * a, library_data_t a_type, int lda, const void * b,
+                     library_data_t b_type, int ldb, const void * beta, void * c, library_data_t c_type, int ldc,
+                     library_data_t scaling_type) {
         if (scaling_type == library_data_t::real_float &&
             c_type == library_data_t::complex_float)
         {
@@ -2329,9 +2326,8 @@ namespace dpct
             library_data_t::real_bfloat16, library_data_t::real_bfloat16,
             library_data_t::real_float, library_data_t::real_float):
         {
-            detail::gemm_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float,
-                              float>(q, a_trans, b_trans, m, n, k, alpha, a, lda, b,
-                                     ldb, beta, c, ldc);
+            detail::gemm_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, float, float>(
+                q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
             break;
         }
         case detail::get_type_combination_id(
@@ -2369,8 +2365,7 @@ namespace dpct
             library_data_t::real_bfloat16, library_data_t::real_bfloat16,
             library_data_t::real_bfloat16, library_data_t::real_float):
         {
-            detail::gemm_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16,
-                              oneapi::mkl::bfloat16, float>(
+            detail::gemm_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, oneapi::math::bfloat16, float>(
                 q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
             break;
         }
@@ -2390,7 +2385,7 @@ namespace dpct
         default:
             throw std::runtime_error("the combination of data type is unsupported");
         }
-    } // gemm()
+    }  // gemm()
 
     /// Computes a batch of matrix-matrix product with general matrices.
     /// \param [in] q The queue where the routine should be executed.
@@ -2412,7 +2407,7 @@ namespace dpct
     /// \param [in] ldc Leading dimension of C.
     /// \param [in] batch_size Specifies the number of matrix multiply operations to perform.
     /// \param [in] scaling_type Data type of the scaling factors.
-    inline void gemm_batch(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m,
+    inline void gemm_batch(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m,
                            int n, int k, const void * alpha, const void * a[], library_data_t a_type, int lda,
                            const void * b[], library_data_t b_type, int ldb, const void * beta, void * c[],
                            library_data_t c_type, int ldc, int batch_size, library_data_t scaling_type,
@@ -2450,7 +2445,7 @@ namespace dpct
             library_data_t::real_bfloat16, library_data_t::real_bfloat16,
             library_data_t::real_bfloat16, library_data_t::real_float):
         {
-            detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float>(
+            detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, oneapi::math::bfloat16, float>(
                 q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
             break;
         }
@@ -2458,7 +2453,7 @@ namespace dpct
             library_data_t::real_bfloat16, library_data_t::real_bfloat16,
             library_data_t::real_float, library_data_t::real_float):
         {
-            detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float, float>(
+            detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, float, float>(
                 q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
             break;
         }
@@ -2534,15 +2529,11 @@ namespace dpct
     /// \param [in] stride_c Stride between the different C matrices.
     /// \param [in] batch_size Specifies the number of matrix multiply operations to perform.
     /// \param [in] scaling_type Data type of the scaling factors.
-    inline void gemm_batch(sycl::queue &q, oneapi::mkl::transpose a_trans,
-                           oneapi::mkl::transpose b_trans, int m, int n, int k,
-                           const void *alpha, const void *a, library_data_t a_type,
-                           int lda, long long int stride_a, const void *b,
-                           library_data_t b_type, int ldb, long long int stride_b,
-                           const void *beta, void *c, library_data_t c_type,
-                           int ldc, long long int stride_c, int batch_size,
-                           library_data_t scaling_type)
-    {
+    inline void gemm_batch(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m,
+                           int n, int k, const void * alpha, const void * a, library_data_t a_type, int lda,
+                           long long int stride_a, const void * b, library_data_t b_type, int ldb,
+                           long long int stride_b, const void * beta, void * c, library_data_t c_type, int ldc,
+                           long long int stride_c, int batch_size, library_data_t scaling_type) {
         if (scaling_type == library_data_t::real_float &&
             c_type == library_data_t::complex_float)
         {
@@ -2611,20 +2602,18 @@ namespace dpct
             library_data_t::real_bfloat16, library_data_t::real_bfloat16,
             library_data_t::real_bfloat16, library_data_t::real_float):
         {
-            detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16,
-                                    oneapi::mkl::bfloat16, float>(
-                q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b,
-                beta, c, ldc, stride_c, batch_size);
+            detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, oneapi::math::bfloat16, float>(
+                q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c,
+                batch_size);
             break;
         }
         case detail::get_type_combination_id(
             library_data_t::real_bfloat16, library_data_t::real_bfloat16,
             library_data_t::real_float, library_data_t::real_float):
         {
-            detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float,
-                                    float>(q, a_trans, b_trans, m, n, k, alpha, a, lda,
-                                           stride_a, b, ldb, stride_b, beta, c, ldc,
-                                           stride_c, batch_size);
+            detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, float, float>(
+                q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c,
+                batch_size);
             break;
         }
 #endif
index ab8efba8165e991c15cc8e17068473c1b9c0679d..dff9f8d4c4ac27c28c2a4b94dc8f2c02e38fd70e 100644 (file)
@@ -2059,8 +2059,8 @@ inline void ggml_sycl_op_mul_mat_sycl(
         const sycl::half alpha_f16 = 1.0f;
         const sycl::half beta_f16  = 0.0f;
         SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
-            *stream, oneapi::mkl::transpose::trans,
-            oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
+            *stream, oneapi::math::transpose::trans,
+            oneapi::math::transpose::nontrans, row_diff, src1_ncols, ne10,
             &alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00,
             src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16,
             dst_f16.get(), dpct::library_data_t::real_half, ldc,
@@ -2097,17 +2097,10 @@ inline void ggml_sycl_op_mul_mat_sycl(
 #if !GGML_SYCL_DNNL
         const float alpha = 1.0f;
         const float beta  = 0.0f;
-#    ifdef GGML_SYCL_NVIDIA
-        SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
-            oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ *stream }, oneapi::mkl::transpose::trans,
-            oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i,
-            ne00, src1_ddf1_i, ne10, dpct::get_value(&beta, *stream), dst_dd_i, ldc)));
-#    else
-        SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
-            *stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
-            dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10, dpct::get_value(&beta, *stream),
-            dst_dd_i, ldc)));
-#    endif
+        SYCL_CHECK(CHECK_TRY_ERROR(oneapi::math::blas::column_major::gemm(
+            get_onemath_backend(*stream), oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, row_diff,
+            src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10,
+            dpct::get_value(&beta, *stream), dst_dd_i, ldc)));
 #else
         DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i,
                                   DnnlGemmWrapper::to_dt<float>(), src0_ddf_i, DnnlGemmWrapper::to_dt<float>(),
@@ -2836,14 +2829,10 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
     if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
         // there is no broadcast and src0, src1 are contiguous across dims 2, 3
         SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
-            *main_stream, oneapi::mkl::transpose::trans,
-            oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
-            (const char *)src0_as_f16, dpct::library_data_t::real_half,
-            nb01 / nb00, nb02 / nb00,
-            (const char *)src1_f16, dpct::library_data_t::real_half,
-            nb11 / nb10, nb12 / nb10, beta,
-            (char *)dst_t, cu_data_type, ne01, nb2 / nb0,
-            ne12 * ne13, cu_compute_type)));
+            *main_stream, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
+            (const char *) src0_as_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00,
+            (const char *) src1_f16, dpct::library_data_t::real_half, nb11 / nb10, nb12 / nb10, beta, (char *) dst_t,
+            cu_data_type, ne01, nb2 / nb0, ne12 * ne13, cu_compute_type)));
     } else {
         const int ne23 = ne12*ne13;
 
@@ -2878,7 +2867,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
             });
         }
         SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
-            *main_stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
+            *main_stream, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
             (const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00,
             (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, nb11 / nb10, beta,
             (void **) (ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23, cu_compute_type, matrix_info.get())));
index 8e8347ff4f95ef2c6eea947ef8cc78b7adf1b2bc..b60415784f32db1b02f1b833e5e5baefe1971cd6 100644 (file)
@@ -1,8 +1,5 @@
-#include <sycl/sycl.hpp>
-#include <oneapi/mkl.hpp>
 #include "outprod.hpp"
 
-
 void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
     const ggml_tensor *src0 = dst->src[0];
     const ggml_tensor *src1 = dst->src[1];
@@ -34,20 +31,13 @@ void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
 
     // Handle transposition of src1
     const bool src1_T = ggml_is_transposed(src1);
-    const oneapi::mkl::transpose src1_op =
-        src1_T ? oneapi::mkl::transpose::nontrans : oneapi::mkl::transpose::trans;
+    const oneapi::math::transpose src1_op = src1_T ? oneapi::math::transpose::nontrans : oneapi::math::transpose::trans;
     const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float);
 
     try {
-        // Perform matrix multiplication using oneMKL GEMM
-#ifdef GGML_SYCL_NVIDIA
-        oneapi::mkl::blas::column_major::gemm(oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ *stream },
-                                              oneapi::mkl::transpose::nontrans, src1_op, ne0, ne1, ne01, alpha, src0_d,
-                                              ne00, src1_d, ldb, beta, dst_d, ne0);
-#else
-        oneapi::mkl::blas::column_major::gemm(*stream, oneapi::mkl::transpose::nontrans, src1_op, ne0, ne1, ne01, alpha,
-                                              src0_d, ne00, src1_d, ldb, beta, dst_d, ne0);
-#endif
+        // Perform matrix multiplication using oneMath GEMM
+        oneapi::math::blas::column_major::gemm(get_onemath_backend(*stream), oneapi::math::transpose::nontrans, src1_op,
+                                               ne0, ne1, ne01, alpha, src0_d, ne00, src1_d, ldb, beta, dst_d, ne0);
     }
     catch (sycl::exception const& exc) {
         std::cerr << exc.what() << std::endl;