]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Remove support for Nvidia & AMD GPU, because the oneAPI plugin for Nvidia & AMD GPU...
authorNeo Zhang <redacted>
Mon, 2 Feb 2026 13:06:21 +0000 (21:06 +0800)
committerGitHub <redacted>
Mon, 2 Feb 2026 13:06:21 +0000 (21:06 +0800)
User can't build up the software for Nvidia & AMD GPU.
rm the oneMath since it is only used in NV and AMD code path.

docs/backend/SYCL.md
ggml/src/ggml-sycl/CMakeLists.txt
ggml/src/ggml-sycl/dpct/helper.hpp
ggml/src/ggml-sycl/ggml-sycl.cpp
ggml/src/ggml-sycl/outprod.cpp
ggml/src/ggml-sycl/rope.cpp
ggml/src/ggml-sycl/wkv.cpp

index 10cb02ff2cb66d9366d87ef2beab72433acb8f85..b3cff96604ec503237f7e3fed2f0d3aa139da8ad 100644 (file)
 - **DPCPP** *(Data Parallel C++)*: The primary oneAPI SYCL implementation, which includes the icpx/icx Compilers.
 - **oneAPI Libraries**: A set of highly optimized libraries targeting multiple domains *(e.g. Intel oneMKL, oneMath and oneDNN)*.
 - **oneAPI LevelZero**: A high performance low level interface for fine-grained control over Intel iGPUs and dGPUs.
-- **Nvidia & AMD Plugins**: These are plugins extending oneAPI's DPCPP support to SYCL on Nvidia and AMD GPU targets.
 
 ### Llama.cpp + SYCL
 
 The llama.cpp SYCL backend is primarily designed for **Intel GPUs**.
-SYCL cross-platform capabilities enable support for Nvidia GPUs as well, with limited support for AMD.
+SYCL cross-platform capabilities enable support for other vendor GPUs as well.
 
 ## Recommended Release
 
@@ -42,6 +41,9 @@ The following releases are verified and recommended:
 
 ## News
 
+- 2026.02
+  - Remove support for Nvidia & AMD GPU, because the oneAPI plugin for Nvidia & AMD GPU is unavailable: download/installation channels are out of work. User can't build up the software for Nvidia & AMD GPU.
+
 - 2025.11
   - Support malloc memory on device more than 4GB.
 
@@ -111,8 +113,8 @@ On older Intel GPUs, you may try [OpenCL](/docs/backend/OPENCL.md) although the
 |-------------------------------|---------|---------------------------------------|
 | Intel Data Center Max Series  | Support | Max 1550, 1100                        |
 | Intel Data Center Flex Series | Support | Flex 170                              |
-| Intel Arc A-Series              | Support | Arc A770, Arc A730M, Arc A750         |
-| Intel Arc B-Series              | Support | Arc B580         |
+| Intel Arc A-Series            | Support | Arc A770, Arc A730M, Arc A750         |
+| Intel Arc B-Series            | Support | Arc B580                              |
 | Intel built-in Arc GPU        | Support | built-in Arc GPU in Meteor Lake, Arrow Lake, Lunar Lake |
 | Intel iGPU                    | Support | iGPU in 13700k, 13400, i5-1250P, i7-1260P, i7-1165G7  |
 
@@ -127,20 +129,7 @@ On older Intel GPUs, you may try [OpenCL](/docs/backend/OPENCL.md) although the
 
 ### Other Vendor GPU
 
-**Verified devices**
-
-| Nvidia GPU               | Status    | Verified Model |
-|--------------------------|-----------|----------------|
-| Ampere Series            | Supported | A100, A4000    |
-| Ampere Series *(Mobile)* | Supported | RTX 40 Series  |
-
-| AMD GPU                  | Status       | Verified Model |
-|--------------------------|--------------|----------------|
-| Radeon Pro               | Experimental | W6800          |
-| Radeon RX                | Experimental | 6700 XT        |
-
-Note: AMD GPU support is highly experimental and is incompatible with F16.
-Additionally, it only supports GPUs with a sub_group_size (warp size) of 32.
+NA
 
 ## Docker
 
@@ -149,11 +138,11 @@ The docker build option is currently limited to *Intel GPU* targets.
 ### Build image
 
 ```sh
-# Using FP16
-docker build -t llama-cpp-sycl --build-arg="GGML_SYCL_F16=ON" --target light -f .devops/intel.Dockerfile .
-
 # Using FP32
 docker build -t llama-cpp-sycl --build-arg="GGML_SYCL_F16=OFF" --target light -f .devops/intel.Dockerfile .
+
+# Using FP16
+docker build -t llama-cpp-sycl --build-arg="GGML_SYCL_F16=ON" --target light -f .devops/intel.Dockerfile .
 ```
 
 *Notes*:
@@ -212,14 +201,6 @@ Platform #0: Intel(R) OpenCL HD Graphics
  `-- Device #0: Intel(R) Iris(R) Xe Graphics [0x9a49]
 ```
 
-- **Nvidia GPU**
-
-In order to target Nvidia GPUs through SYCL, please make sure the CUDA/CUBLAS native requirements *-found [here](README.md#cuda)-* are installed.
-
-- **AMD GPU**
-
-To target AMD GPUs with SYCL, the ROCm stack must be installed first.
-
 2. **Install IntelĀ® oneAPI Base toolkit**
 
 SYCL backend depends on:
@@ -248,23 +229,6 @@ Upon a successful installation, SYCL is enabled for the available intel devices,
 |2025.1|
 |2024.1|
 
-- **Adding support to Nvidia GPUs**
-
-**oneAPI Plugin**: In order to enable SYCL support on Nvidia GPUs, please install the [Codeplay oneAPI Plugin for Nvidia GPUs](https://developer.codeplay.com/products/oneapi/nvidia/download). User should also make sure the plugin version matches the installed base toolkit one *(previous step)* for a seamless "oneAPI on Nvidia GPU" setup.
-
-**oneDNN**: The current oneDNN releases *(shipped with the oneAPI base-toolkit)* do not include the NVIDIA backend. Therefore, oneDNN must be compiled from source to enable the NVIDIA target:
-
-```sh
-git clone https://github.com/oneapi-src/oneDNN.git
-cd oneDNN
-cmake -GNinja -Bbuild-nvidia -DDNNL_CPU_RUNTIME=DPCPP -DDNNL_GPU_RUNTIME=DPCPP -DDNNL_GPU_VENDOR=NVIDIA -DONEDNN_BUILD_GRAPH=OFF -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx
-cmake --build build-nvidia --config Release
-```
-
-- **Adding support to AMD GPUs**
-
-**oneAPI Plugin**: In order to enable SYCL support on AMD GPUs, please install the [Codeplay oneAPI Plugin for AMD GPUs](https://developer.codeplay.com/products/oneapi/amd/download). As with Nvidia GPUs, the user should also make sure the plugin version matches the installed base toolkit.
-
 3. **Verify installation and environment**
 
 In order to check the available SYCL devices on the machine, please use the `sycl-ls` command.
@@ -285,25 +249,6 @@ When targeting an intel GPU, the user should expect one or more devices among th
 [opencl:gpu][opencl:2] Intel(R) OpenCL Graphics, Intel(R) UHD Graphics 730 OpenCL 3.0 NEO  [24.39.31294]
 ```
 
-- **Nvidia GPU**
-
-Similarly, user targeting Nvidia GPUs should expect at least one SYCL-CUDA device [`cuda:gpu`] as below:
-
-```
-[opencl:acc][opencl:0] Intel(R) FPGA Emulation Platform for OpenCL(TM), Intel(R) FPGA Emulation Device OpenCL 1.2  [2023.16.12.0.12_195853.xmain-hotfix]
-[opencl:cpu][opencl:1] Intel(R) OpenCL, Intel(R) Xeon(R) Gold 6326 CPU @ 2.90GHz OpenCL 3.0 (Build 0) [2023.16.12.0.12_195853.xmain-hotfix]
-[cuda:gpu][cuda:0] NVIDIA CUDA BACKEND, NVIDIA A100-PCIE-40GB 8.0 [CUDA 12.5]
-```
-
-- **AMD GPU**
-
-For AMD GPUs we should expect at least one SYCL-HIP device [`hip:gpu`]:
-
-```
-[opencl:cpu][opencl:0] Intel(R) OpenCL, 12th Gen Intel(R) Core(TM) i9-12900K OpenCL 3.0 (Build 0) [2024.18.6.0.02_160000]
-[hip:gpu][hip:0] AMD HIP BACKEND, AMD Radeon PRO W6800 gfx1030 [HIP 60140.9]
-```
-
 ### II. Build llama.cpp
 
 #### Intel GPU
@@ -332,47 +277,6 @@ It is possible to come across some precision issues when running tests that stem
 instructions, which can be circumvented by setting the environment variable `SYCL_PROGRAM_COMPILE_OPTIONS`
 as `-cl-fp32-correctly-rounded-divide-sqrt`
 
-#### Nvidia GPU
-
-The SYCL backend depends on [oneMath](https://github.com/uxlfoundation/oneMath) for Nvidia and AMD devices.
-By default it is automatically built along with the project. A specific build can be provided by setting the CMake flag `-DoneMath_DIR=/path/to/oneMath/install/lib/cmake/oneMath`.
-
-```sh
-# Build LLAMA with Nvidia BLAS acceleration through SYCL
-# Setting GGML_SYCL_DEVICE_ARCH is optional but can improve performance
-GGML_SYCL_DEVICE_ARCH=sm_80 # Example architecture
-
-# Option 1: Use FP32 (recommended for better performance in most cases)
-cmake -B build -DGGML_SYCL=ON -DGGML_SYCL_TARGET=NVIDIA -DGGML_SYCL_DEVICE_ARCH=${GGML_SYCL_DEVICE_ARCH} -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DDNNL_DIR=/path/to/oneDNN/build-nvidia/install/lib/cmake/dnnl
-
-# Option 2: Use FP16
-cmake -B build -DGGML_SYCL=ON -DGGML_SYCL_TARGET=NVIDIA -DGGML_SYCL_DEVICE_ARCH=${GGML_SYCL_DEVICE_ARCH} -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DGGML_SYCL_F16=ON -DDNNL_DIR=/path/to/oneDNN/build-nvidia/install/lib/cmake/dnnl
-
-# build all binary
-cmake --build build --config Release -j -v
-```
-
-It is possible to come across some precision issues when running tests that stem from using faster
-instructions, which can be circumvented by passing the `-fno-fast-math` flag to the compiler.
-
-#### AMD GPU
-
-The SYCL backend depends on [oneMath](https://github.com/uxlfoundation/oneMath) for Nvidia and AMD devices.
-By default it is automatically built along with the project. A specific build can be provided by setting the CMake flag `-DoneMath_DIR=/path/to/oneMath/install/lib/cmake/oneMath`.
-
-```sh
-# Build LLAMA with rocBLAS acceleration through SYCL
-
-## AMD
-# Use FP32, FP16 is not supported
-# Find your GGML_SYCL_DEVICE_ARCH with rocminfo, under the key 'Name:'
-GGML_SYCL_DEVICE_ARCH=gfx90a # Example architecture
-cmake -B build -DGGML_SYCL=ON -DGGML_SYCL_TARGET=AMD -DGGML_SYCL_DEVICE_ARCH=${GGML_SYCL_DEVICE_ARCH} -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx
-
-# build all binary
-cmake --build build --config Release -j -v
-```
-
 ### III. Run the inference
 
 #### Retrieve and prepare model
@@ -766,15 +670,15 @@ use 1 SYCL GPUs: [0] with Max compute units:512
 | Name               | Value                                 | Function                                    |
 |--------------------|---------------------------------------|---------------------------------------------|
 | GGML_SYCL          | ON (mandatory)                        | Enable build with SYCL code path.           |
-| GGML_SYCL_TARGET   | INTEL *(default)* \| NVIDIA \| AMD    | Set the SYCL target device type.            |
-| GGML_SYCL_DEVICE_ARCH | Optional (except for AMD)             | Set the SYCL device architecture, optional except for AMD. Setting the device architecture can improve the performance. See the table [--offload-arch](https://github.com/intel/llvm/blob/sycl/sycl/doc/design/OffloadDesign.md#--offload-arch) for a list of valid architectures. |
+| GGML_SYCL_TARGET   | INTEL *(default)*                     | Set the SYCL target device type.            |
+| GGML_SYCL_DEVICE_ARCH | Optional                           | Set the SYCL device architecture. Setting the device architecture can improve the performance. See the table [--offload-arch](https://github.com/intel/llvm/blob/sycl/sycl/doc/design/OffloadDesign.md#--offload-arch) for a list of valid architectures. |
 | GGML_SYCL_F16      | OFF *(default)* \|ON *(optional)*     | Enable FP16 build with SYCL code path. (1.) |
-| GGML_SYCL_GRAPH    | ON *(default)* \|OFF *(Optional)*     | Enable build with [SYCL Graph extension](https://github.com/intel/llvm/blob/sycl/sycl/doc/extensions/experimental/sycl_ext_oneapi_graph.asciidoc). |
+| GGML_SYCL_GRAPH    | OFF *(default)* \|ON *(Optional)*     | Enable build with [SYCL Graph extension](https://github.com/intel/llvm/blob/sycl/sycl/doc/extensions/experimental/sycl_ext_oneapi_graph.asciidoc). |
 | GGML_SYCL_DNN      | ON *(default)* \|OFF *(Optional)*     | Enable build with oneDNN.                   |
 | CMAKE_C_COMPILER   | `icx` *(Linux)*, `icx/cl` *(Windows)* | Set `icx` compiler for SYCL code path.      |
 | CMAKE_CXX_COMPILER | `icpx` *(Linux)*, `icx` *(Windows)*   | Set `icpx/icx` compiler for SYCL code path. |
 
-1. FP16 is recommended for better prompt processing performance on quantized models. Performance is equivalent in text generation but set `GGML_SYCL_F16=OFF` if you are experiencing issues with FP16 builds.
+1. FP32 or FP16 have different performance impact to LLM. Recommended to test them for better prompt processing performance on your models. You need to rebuild the code after change `GGML_SYCL_F16=OFF/ON`.
 
 #### Runtime
 
@@ -782,7 +686,7 @@ use 1 SYCL GPUs: [0] with Max compute units:512
 |-------------------|------------------|---------------------------------------------------------------------------------------------------------------------------|
 | GGML_SYCL_DEBUG   | 0 (default) or 1 | Enable log function by macro: GGML_SYCL_DEBUG                                                                             |
 | GGML_SYCL_DISABLE_OPT | 0 (default) or 1 | Disable optimize features for Intel GPUs. (Recommended to 1 for intel devices older than Gen 10) |
-| GGML_SYCL_DISABLE_GRAPH | 0 or 1 (default) | Disable running computations through SYCL Graphs feature. Disabled by default because graph performance isn't yet better than non-graph performance. |
+| GGML_SYCL_DISABLE_GRAPH | 0 or 1 (default) | Disable running computations through SYCL Graphs feature. Disabled by default because SYCL Graph is still on development, no better performance. |
 | GGML_SYCL_DISABLE_DNN | 0 (default) or 1 | Disable running computations through oneDNN and always use oneMKL. |
 | ZES_ENABLE_SYSMAN | 0 (default) or 1 | Support to get free memory of GPU by sycl::aspect::ext_intel_free_memory.<br>Recommended to use when --split-mode = layer |
 | UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS | 0 (default) or 1 | Support malloc device memory more than 4GB.|
index 5a89d8dd688973174cd8b76467c6324e161a789d..eefdd9725ca8f977b57a8e27eceede4e65ffadda 100644 (file)
@@ -1,7 +1,7 @@
 message(STATUS  "GGML_SYCL_TARGET=${GGML_SYCL_TARGET}")
 
-if (NOT GGML_SYCL_TARGET MATCHES "^(INTEL|NVIDIA|AMD)$")
-    message(FATAL_ERROR "Invalid backend chosen, supported options are INTEL, NVIDIA, or AMD")
+if (NOT GGML_SYCL_TARGET MATCHES "^(INTEL)$")
+    message(FATAL_ERROR "GGML_SYCL_TARGET: Invalid target, the supported options are [INTEL]")
 endif()
 
 check_cxx_compiler_flag("-fsycl" SUPPORTS_SYCL)
@@ -125,106 +125,27 @@ endif()
 target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_DNNL=${GGML_SYCL_DNNL})
 
 if (GGML_SYCL_F16)
-    if (GGML_SYCL_TARGET STREQUAL "AMD")
-        message(WARNING "AMD target does not entirely support FP16 in the SYCL backend.")
-    endif()
     add_compile_definitions(GGML_SYCL_F16)
 endif()
 
 if (GGML_SYCL_TARGET STREQUAL "INTEL")
     add_compile_definitions(GGML_SYCL_WARP_SIZE=16)
     target_link_options(ggml-sycl PRIVATE  -Xs   -ze-intel-greater-than-4GB-buffer-required)
-elseif (GGML_SYCL_TARGET STREQUAL "NVIDIA")
-    add_compile_definitions(GGML_SYCL_WARP_SIZE=32)
-elseif (GGML_SYCL_TARGET STREQUAL "AMD")
-    # INFO: Allowed Sub_group_sizes are not consistent through all
-    # hip targets. For example, 64 is used for certain models, but the backend
-    # does not support it.
-    # Target archs tested working: gfx1030, gfx1031, (Only tested sub_group_size = 32)
-    add_compile_definitions(GGML_SYCL_WARP_SIZE=32)
-else()
-    # default for other target
-    add_compile_definitions(GGML_SYCL_WARP_SIZE=32)
-endif()
-
-if (GGML_SYCL_GRAPH)
-    target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_GRAPH)
-endif()
 
-# Link against Intel oneMKL or oneMath
-if (GGML_SYCL_TARGET STREQUAL "INTEL")
-    # Intel devices use Intel oneMKL directly instead of oneMath to avoid the limitation of linking Intel oneMKL statically
-    # See https://github.com/uxlfoundation/oneMath/issues/654
+    # Link against Intel oneMKL
     if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
         set(SYCL_COMPILER ON)
     endif()
     find_package(MKL REQUIRED)
     target_link_libraries(ggml-sycl PRIVATE MKL::MKL_SYCL::BLAS)
-    target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_USE_INTEL_ONEMKL)
 else()
-    find_package(oneMath QUIET)
-    if (NOT oneMath_FOUND)
-        message(STATUS "oneMath not found: oneMath will be automatically downloaded")
-        # Use FetchContent to automatically pull and build oneMath
-        include(FetchContent)
-        set(BUILD_FUNCTIONAL_TESTS False)
-        set(BUILD_EXAMPLES False)
-        set(TARGET_DOMAINS blas)
-        if (GGML_SYCL_TARGET STREQUAL "NVIDIA")
-            set(ENABLE_MKLCPU_BACKEND False)
-            set(ENABLE_MKLGPU_BACKEND False)
-            set(ENABLE_CUBLAS_BACKEND True)
-        elseif (GGML_SYCL_TARGET STREQUAL "AMD")
-            set(ENABLE_MKLCPU_BACKEND False)
-            set(ENABLE_MKLGPU_BACKEND False)
-            set(ENABLE_ROCBLAS_BACKEND True)
-            # Ensure setting a string variable here is not overriden by oneMath CACHE variables
-            cmake_policy(SET CMP0126 NEW)
-            # Setting the device architecture is only needed and useful for AMD devices in oneMath
-            set(HIP_TARGETS ${GGML_SYCL_DEVICE_ARCH} CACHE STRING "oneMath HIP target" FORCE)
-        endif()
-        FetchContent_Declare(
-            ONEMATH
-            GIT_REPOSITORY https://github.com/uxlfoundation/oneMath.git
-            GIT_TAG 8efe85f5aaebb37f1d8c503b7af66315feabf142
-        )
-        FetchContent_MakeAvailable(ONEMATH)
-        # Create alias to match with find_package targets name
-        function(onemath_alias target)
-            if (TARGET ${target}_obj)
-                # Silence verbose warnings from external libraries
-                target_compile_options(${target}_obj PRIVATE -w)
-            endif()
-            if (TARGET ${target})
-                add_library(ONEMATH::${target} ALIAS ${target})
-            endif()
-        endfunction()
-        onemath_alias(onemath)
-        onemath_alias(onemath_blas_mklcpu)
-        onemath_alias(onemath_blas_mklgpu)
-        onemath_alias(onemath_blas_cublas)
-        onemath_alias(onemath_blas_rocblas)
-    endif()
+    # default for other target
+    message(FATAL_ERROR "GGML_SYCL_TARGET is not supported")
+    add_compile_definitions(GGML_SYCL_WARP_SIZE=32)
+endif()
 
-    # Below oneMath compile-time dispatching is used for better performance
-    if (GGML_SYCL_TARGET STREQUAL "NVIDIA")
-        target_link_libraries(ggml-sycl PRIVATE ONEMATH::onemath_blas_cublas)
-        target_compile_options(ggml-sycl PRIVATE "-fsycl-targets=nvptx64-nvidia-cuda")
-        target_link_options(ggml-sycl PRIVATE "-fsycl-targets=nvptx64-nvidia-cuda")
-        target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_NVIDIA)
-    elseif (GGML_SYCL_TARGET STREQUAL "AMD")
-        if (NOT GGML_SYCL_DEVICE_ARCH)
-            message(FATAL_ERROR "Can't enable SYCL hip backend, GGML_SYCL_DEVICE_ARCH has not been set.")
-        endif()
-        target_link_libraries(ggml-sycl PRIVATE ONEMATH::onemath_blas_rocblas)
-        target_compile_options(ggml-sycl PRIVATE "-fsycl-targets=amdgcn-amd-amdhsa")
-        target_link_options(ggml-sycl PRIVATE "-fsycl-targets=amdgcn-amd-amdhsa")
-        target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_AMD)
-    else()
-        # Fallback to oneMath runtime dispatcher
-        target_link_libraries(ggml-sycl PRIVATE ONEMATH::onemath)
-        target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_GENERIC)
-    endif()
+if (GGML_SYCL_GRAPH)
+    target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_GRAPH)
 endif()
 
 if (GGML_SYCL_DEVICE_ARCH)
index 8ae8098717d4ca75edb491b4f8468149c8a7b2d4..ece66a7ac1f438ee7a58f75c0953447a0271a6f5 100644 (file)
 
 #include <sycl/sycl.hpp>
 #include <sycl/half_type.hpp>
-#include <map>
-
-#ifdef GGML_SYCL_USE_INTEL_ONEMKL
 #include <oneapi/mkl.hpp>
-// Allow to use the same namespace for Intel oneMKL and oneMath
-namespace oneapi {
-    namespace math = mkl;
-}
-#else
-#include <oneapi/math.hpp>
-#endif
+
+#include <map>
 
 #include "ggml.h"
 
@@ -91,32 +83,13 @@ inline std::string get_device_backend_and_type(const sycl::device &device) {
 }
 
 template <typename Ts> struct matrix_info_t {
-    oneapi::math::transpose transpose_info[2];
+    oneapi::mkl::transpose transpose_info[2];
     Ts                     value_info[2];
     std::int64_t           size_info[3];
     std::int64_t           ld_info[3];
     std::int64_t           groupsize_info;
 };
 
-inline auto get_onemath_backend(sycl::queue& queue)
-#if defined(GGML_SYCL_GENERIC) || defined(GGML_SYCL_USE_INTEL_ONEMKL)
-  -> sycl::queue&
-#endif
-{
-// If the backend is known at compile-time, use oneMath backend_selector to use
-// compile-time dispatching and avoid the need to dlopen libraries. Otherwise
-// fallback to runtime dispatching.
-#if defined(GGML_SYCL_NVIDIA)
-    return oneapi::math::backend_selector<oneapi::math::backend::cublas>{ queue };
-#elif defined(GGML_SYCL_AMD)
-    return oneapi::math::backend_selector<oneapi::math::backend::rocblas>{ queue };
-#elif defined(GGML_SYCL_GENERIC) || defined(GGML_SYCL_USE_INTEL_ONEMKL)
-    return queue;
-#else
-    static_assert(false, "Unsupported backend");
-#endif
-}
-
 namespace dpct
 {
     typedef sycl::queue *queue_ptr;
@@ -1734,7 +1707,7 @@ namespace dpct
     namespace detail
     {
     template <class Ta, class Tb, class Tc, class Ts>
-    inline void gemm_impl(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m,
+    inline void gemm_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m,
                           int n, int k, const void * alpha, const void * a, int lda, const void * b, int ldb,
                           const void * beta, void * c, int ldc) {
         Ts   alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
@@ -1742,7 +1715,7 @@ namespace dpct
         auto data_a      = get_memory<const Ta>(a);
         auto data_b      = get_memory<const Tb>(b);
         auto data_c      = get_memory<Tc>(c);
-        oneapi::math::blas::column_major::gemm(get_onemath_backend(q), a_trans, b_trans, m, n, k, alpha_value, data_a,
+        oneapi::mkl::blas::column_major::gemm(q, a_trans, b_trans, m, n, k, alpha_value, data_a,
                                                lda, data_b, ldb, beta_value, data_c, ldc);
     }
 
@@ -1774,7 +1747,7 @@ namespace dpct
         };
 
         template <class Ta, class Tb, class Tc, class Ts>
-        inline void gemm_batch_impl(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans,
+        inline void gemm_batch_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans,
                                     int m, int n, int k, const void * alpha, const void ** a, int lda, const void ** b,
                                     int ldb, const void * beta, void ** c, int ldc, int batch_size,
                                     matrix_info_t<float> * matrix_info) {
@@ -1793,8 +1766,8 @@ namespace dpct
             matrix_info->ld_info[2] = ldc;
             matrix_info->groupsize_info = batch_size;
 
-            sycl::event e = oneapi::math::blas::column_major::gemm_batch(
-                get_onemath_backend(q), matrix_info->transpose_info, matrix_info->transpose_info + 1,
+            sycl::event e = oneapi::mkl::blas::column_major::gemm_batch(
+                q, matrix_info->transpose_info, matrix_info->transpose_info + 1,
                 matrix_info->size_info, matrix_info->size_info + 1, matrix_info->size_info + 2,
                 reinterpret_cast<Ts *>(matrix_info->value_info), reinterpret_cast<const Ta **>(a), matrix_info->ld_info,
                 reinterpret_cast<const Tb **>(b), matrix_info->ld_info + 1,
@@ -1803,7 +1776,7 @@ namespace dpct
         }
 
         template <class Ta, class Tb, class Tc, class Ts>
-        inline void gemm_batch_impl(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans,
+        inline void gemm_batch_impl(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans,
                                     int m, int n, int k, const void * alpha, const void * a, int lda,
                                     long long int stride_a, const void * b, int ldb, long long int stride_b,
                                     const void * beta, void * c, int ldc, long long int stride_c, int batch_size) {
@@ -1812,7 +1785,7 @@ namespace dpct
             auto data_a = get_memory<const Ta>(a);
             auto data_b = get_memory<const Tb>(b);
             auto data_c = get_memory<Tc>(c);
-            oneapi::math::blas::column_major::gemm_batch(get_onemath_backend(q), a_trans, b_trans, m, n, k, alpha_value,
+            oneapi::mkl::blas::column_major::gemm_batch(q, a_trans, b_trans, m, n, k, alpha_value,
                                                          data_a, lda, stride_a, data_b, ldb, stride_b, beta_value,
                                                          data_c, ldc, stride_c, batch_size);
         }
@@ -2299,7 +2272,7 @@ namespace dpct
                            sycl::range<3>(x, y, 1), direction);
     }
 
-    inline void gemm(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m, int n,
+    inline void gemm(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m, int n,
                      int k, const void * alpha, const void * a, library_data_t a_type, int lda, const void * b,
                      library_data_t b_type, int ldb, const void * beta, void * c, library_data_t c_type, int ldc,
                      library_data_t scaling_type) {
@@ -2366,7 +2339,7 @@ namespace dpct
             library_data_t::real_bfloat16, library_data_t::real_bfloat16,
             library_data_t::real_float, library_data_t::real_float):
         {
-            detail::gemm_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, float, float>(
+            detail::gemm_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float, float>(
                 q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
             break;
         }
@@ -2405,7 +2378,7 @@ namespace dpct
             library_data_t::real_bfloat16, library_data_t::real_bfloat16,
             library_data_t::real_bfloat16, library_data_t::real_float):
         {
-            detail::gemm_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, oneapi::math::bfloat16, float>(
+            detail::gemm_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float>(
                 q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
             break;
         }
@@ -2447,7 +2420,7 @@ namespace dpct
     /// \param [in] ldc Leading dimension of C.
     /// \param [in] batch_size Specifies the number of matrix multiply operations to perform.
     /// \param [in] scaling_type Data type of the scaling factors.
-    inline void gemm_batch(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m,
+    inline void gemm_batch(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m,
                            int n, int k, const void * alpha, const void * a[], library_data_t a_type, int lda,
                            const void * b[], library_data_t b_type, int ldb, const void * beta, void * c[],
                            library_data_t c_type, int ldc, int batch_size, library_data_t scaling_type,
@@ -2485,7 +2458,7 @@ namespace dpct
             library_data_t::real_bfloat16, library_data_t::real_bfloat16,
             library_data_t::real_bfloat16, library_data_t::real_float):
         {
-            detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, oneapi::math::bfloat16, float>(
+            detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float>(
                 q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
             break;
         }
@@ -2493,7 +2466,7 @@ namespace dpct
             library_data_t::real_bfloat16, library_data_t::real_bfloat16,
             library_data_t::real_float, library_data_t::real_float):
         {
-            detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, float, float>(
+            detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float, float>(
                 q, a_trans, b_trans, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_size, matrix_info);
             break;
         }
@@ -2569,7 +2542,7 @@ namespace dpct
     /// \param [in] stride_c Stride between the different C matrices.
     /// \param [in] batch_size Specifies the number of matrix multiply operations to perform.
     /// \param [in] scaling_type Data type of the scaling factors.
-    inline void gemm_batch(sycl::queue & q, oneapi::math::transpose a_trans, oneapi::math::transpose b_trans, int m,
+    inline void gemm_batch(sycl::queue & q, oneapi::mkl::transpose a_trans, oneapi::mkl::transpose b_trans, int m,
                            int n, int k, const void * alpha, const void * a, library_data_t a_type, int lda,
                            long long int stride_a, const void * b, library_data_t b_type, int ldb,
                            long long int stride_b, const void * beta, void * c, library_data_t c_type, int ldc,
@@ -2642,7 +2615,7 @@ namespace dpct
             library_data_t::real_bfloat16, library_data_t::real_bfloat16,
             library_data_t::real_bfloat16, library_data_t::real_float):
         {
-            detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, oneapi::math::bfloat16, float>(
+            detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float>(
                 q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c,
                 batch_size);
             break;
@@ -2651,7 +2624,7 @@ namespace dpct
             library_data_t::real_bfloat16, library_data_t::real_bfloat16,
             library_data_t::real_float, library_data_t::real_float):
         {
-            detail::gemm_batch_impl<oneapi::math::bfloat16, oneapi::math::bfloat16, float, float>(
+            detail::gemm_batch_impl<oneapi::mkl::bfloat16, oneapi::mkl::bfloat16, float, float>(
                 q, a_trans, b_trans, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c,
                 batch_size);
             break;
index c5139fd3dd2d8c39b2b9e835e01efc55e1db3eeb..a03d26d7f200380fc997e00dbcf127a26568b308 100644 (file)
@@ -2167,8 +2167,8 @@ inline void ggml_sycl_op_mul_mat_sycl(
             const sycl::half alpha_f16 = 1.0f;
             const sycl::half beta_f16  = 0.0f;
             SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
-                *stream, oneapi::math::transpose::trans,
-                oneapi::math::transpose::nontrans, row_diff, src1_ncols, ne10,
+                *stream, oneapi::mkl::transpose::trans,
+                oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
                 &alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00,
                 src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16,
                 dst_f16.get(), dpct::library_data_t::real_half, ldc,
@@ -2211,8 +2211,8 @@ inline void ggml_sycl_op_mul_mat_sycl(
         {
             const float alpha = 1.0f;
             const float beta  = 0.0f;
-            SYCL_CHECK(CHECK_TRY_ERROR(oneapi::math::blas::column_major::gemm(
-                get_onemath_backend(*stream), oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, row_diff,
+            SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
+                *stream, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, row_diff,
                 src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10,
                 dpct::get_value(&beta, *stream), dst_dd_i, ldc)));
         }
@@ -3165,8 +3165,8 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
             const int64_t smb = ne12 == 1 ? s13       : s12;
 
             // there is no broadcast and src0, src1 are contiguous across dims 2, 3
-            SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::math::transpose::trans,
-                                                        oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
+            SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::mkl::transpose::trans,
+                                                        oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
                                                         src0_f16, dpct::library_data_t::real_half, nb01 / nb00, sma,
                                                         src1_f16, dpct::library_data_t::real_half, s11, smb, beta, dst_ddf,
                                                         mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type)));
@@ -3190,7 +3190,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
             });
 
             SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
-                *queue, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
+                *queue, oneapi::mkl::transpose::trans, oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
                 (const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00,
                 (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, s11, beta,
                 (void **) (ptrs_dst.get() + 0 * ne23), mkl_data_type, ne0, ne23, mkl_compute_type, matrix_info.get())));
@@ -3524,12 +3524,11 @@ static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor
     use_mul_mat_q = use_mul_mat_q && (src1->ne[1] <= MMQ_MAX_BATCH_SIZE);
 #endif // SYCL_USE_XMX
 
-    // mmvq path is faster in the CUDA backend.
-    if (!g_ggml_sycl_prioritize_dmmv && (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda
-        // Dispatch becomes obscure with the reorder, MMVQ when the reorder optimization
-        // is enabled takes precedence over DMMV, the current if-else implementation
-        // requires disabling DMMV if both conditions are met
-        || (should_reorder_tensor(ctx, dst) && ggml_sycl_supports_reorder_mmvq(src0->type)))) {
+    // Dispatch becomes obscure with the reorder, MMVQ when the reorder optimization
+    // is enabled takes precedence over DMMV, the current if-else implementation
+    // requires disabling DMMV if both conditions are met
+    if (!g_ggml_sycl_prioritize_dmmv && ((should_reorder_tensor(ctx, dst) &&
+                                          ggml_sycl_supports_reorder_mmvq(src0->type)))) {
         use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;
     }
 
@@ -4189,16 +4188,6 @@ void ggml_backend_sycl_get_device_memory(int device, size_t *free,
     GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_get_device_memory\n");
     ggml_sycl_set_device(device);
 
-    /*
-    DPCT1009:218: SYCL uses exceptions to report errors and does not use the
-    error codes. The original code was commented out and a warning string was
-    inserted. You need to rewrite this code.
-    */
-    /*
-    DPCT1106:217: 'cudaMemGetInfo' was migrated with the Intel extensions for
-    device information which may not be supported by all compilers or runtimes.
-    You may need to adjust the code.
-    */
     SYCL_CHECK(CHECK_TRY_ERROR(
         dpct::dev_mgr::instance().get_device(device).get_memory_info(*free, *total)));
 }
index 3a17f3a1b88abf3c3b09d3f51f87752c1a02ab47..f52b11f0d6ed73ebce2f07a496653ab811619763 100644 (file)
@@ -32,12 +32,12 @@ void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, ggml_tensor* dst) {
 
     // Handle transposition of src1
     const bool src1_T = ggml_is_transposed(src1);
-    const oneapi::math::transpose src1_op = src1_T ? oneapi::math::transpose::nontrans : oneapi::math::transpose::trans;
+    const oneapi::mkl::transpose src1_op = src1_T ? oneapi::mkl::transpose::nontrans : oneapi::mkl::transpose::trans;
     const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float);
 
     try {
-        // Perform matrix multiplication using oneMath GEMM
-        oneapi::math::blas::column_major::gemm(get_onemath_backend(*stream), oneapi::math::transpose::nontrans, src1_op,
+        // Perform matrix multiplication using oneMKL GEMM
+        oneapi::mkl::blas::column_major::gemm(*stream, oneapi::mkl::transpose::nontrans, src1_op,
                                                ne0, ne1, ne01, alpha, src0_d, ne00, src1_d, ldb, beta, dst_d, ne0);
     }
     catch (sycl::exception const& exc) {
index 69140b19a4c072c0eb2eb20114aea5b1fa35910a..aeaa58b95b3281c5906182385321814da6e4ebe1 100644 (file)
@@ -207,7 +207,6 @@ static void rope_vision(const T * x, T * dst, const int ne0, const int ne1, cons
         const int p = sector;
         theta_base  = pos[channel_x] * sycl::pow(theta_scale, (float) p);
     } else {
-        // Simplified from CUDA backend code: if (sector >= sections.v[0] && sector < sec_w) which is just sector >= sections.v[0]
         const int p = sector - sections.v[0];
         theta_base  = pos[channel_x + ne2] * sycl::pow(theta_scale, (float) p);
     }
index c10e2f7645e89e045ca25e86a8598e734179ed26..b56e0c2400f4e2aa25d2a813c4157ae1c2d5c05a 100644 (file)
@@ -1,7 +1,7 @@
 #include <sycl/sycl.hpp>
 #include "wkv.hpp"
 
-constexpr int WKV_BLOCK_SIZE = 64;  // Matching CUDA_WKV_BLOCK_SIZE
+constexpr int WKV_BLOCK_SIZE = 64;
 
 // Helper function for the main kernel
 template <int block_size>