]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
sycl: Use syclcompat::dp4a (#10267)
authorRomain Biessy <redacted>
Fri, 15 Nov 2024 03:09:12 +0000 (04:09 +0100)
committerGitHub <redacted>
Fri, 15 Nov 2024 03:09:12 +0000 (11:09 +0800)
* sycl: Use syclcompat::dp4a

* Using the syclcompat version allow the compiler to optimize the
  operation with native function

* Update news section

* Update CI Windows oneAPI version to 2025.0

* Reword doc

* Call syclcompat::dp4a inside dpct::dp4a

This reverts commit 90cb61d692d61360b46954a1c7f780bd2e569b73.

.github/workflows/build.yml
docs/backend/SYCL.md
ggml/src/ggml-sycl/dpct/helper.hpp
ggml/src/ggml-sycl/vecdotq.hpp

index d6a7b66a511f84b7f7c089229e44da30874448a4..c770bbd155c4c3df23e45f132f95c32725c7d02d 100644 (file)
@@ -930,7 +930,7 @@ jobs:
         shell: bash
 
     env:
-      WINDOWS_BASEKIT_URL: https://registrationcenter-download.intel.com/akdlm/IRC_NAS/7dff44ba-e3af-4448-841c-0d616c8da6e7/w_BaseKit_p_2024.1.0.595_offline.exe
+      WINDOWS_BASEKIT_URL: https://registrationcenter-download.intel.com/akdlm/IRC_NAS/b380d914-366b-4b77-a74a-05e3c38b3514/intel-oneapi-base-toolkit-2025.0.0.882_offline.exe
       WINDOWS_DPCPP_MKL: intel.oneapi.win.cpp-dpcpp-common:intel.oneapi.win.mkl.devel
       ONEAPI_ROOT: "C:/Program Files (x86)/Intel/oneAPI"
     steps:
index bc8c0f88647c206a912b1e80dae03e66c441f530..38185f73897ee9e435c23db2b6af0bc897f5baf3 100644 (file)
@@ -41,6 +41,8 @@ The following release is verified with good quality:
 
 ## News
 
+- 2024.11
+  - Use syclcompat to improve the performance on some platforms. This requires to use oneAPI 2025.0 or newer.
 
 - 2024.8
   - Use oneDNN as the default GEMM library, improve the compatibility for new Intel GPUs.
index fe4a8f744e2e03cbacb6b34f5c8338ed17bd142c..c2f28bb49579e9877cfe0042eafd95b2ef2055fe 100644 (file)
@@ -15,6 +15,7 @@
 
 #include <sycl/sycl.hpp>
 #include <sycl/half_type.hpp>
+#include <syclcompat/math.hpp>
 #include <oneapi/mkl.hpp>
 #include <map>
 
@@ -1830,31 +1831,10 @@ namespace dpct
                                            : id);
     }
 
-    template <typename T>
-    sycl::vec<T, 4> extract_and_sign_or_zero_extend4(T val)
-    {
-        return sycl::vec<T, 1>(val)
-            .template as<sycl::vec<
-                std::conditional_t<std::is_signed_v<T>, int8_t, uint8_t>, 4>>()
-            .template convert<T>();
-    }
-
-    template <typename T1, typename T2>
-    using dot_product_acc_t =
-        std::conditional_t<std::is_unsigned_v<T1> && std::is_unsigned_v<T2>,
-                           uint32_t, int32_t>;
-
     template <typename T1, typename T2, typename T3>
     inline auto dp4a(T1 a, T2 b, T3 c)
     {
-        dot_product_acc_t<T1, T2> res = c;
-        auto va = extract_and_sign_or_zero_extend4(a);
-        auto vb = extract_and_sign_or_zero_extend4(b);
-        res += va[0] * vb[0];
-        res += va[1] * vb[1];
-        res += va[2] * vb[2];
-        res += va[3] * vb[3];
-        return res;
+        return syclcompat::dp4a(a, b, c);
     }
 
     struct sub_sat
index d2dccade20bfd690b94bd4733acce8f546bf98ee..c5942008adfbdeaa6f43785d887a6babd4d6e3b6 100644 (file)
@@ -968,8 +968,8 @@ vec_dot_iq3_xxs_q8_1(const void *__restrict__ vbq,
             grid1[0] ^ signs[0], signs[0], std::minus<>());
         const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
             grid2[0] ^ signs[1], signs[1], std::minus<>());
-        sumi = dpct::dp4a(grid_l, *((int *)q8 + 0), sumi);
-        sumi = dpct::dp4a(grid_h, *((int *)q8 + 1), sumi);
+        sumi = dpct::dp4a(grid_l, *((const int *)q8 + 0), sumi);
+        sumi = dpct::dp4a(grid_h, *((const int *)q8 + 1), sumi);
         q8 += 8;
         aux32 >>= 7;
     }
@@ -1009,8 +1009,8 @@ vec_dot_iq3_s_q8_1(const void *__restrict__ vbq,
             grid1[0] ^ signs0, signs0, std::minus<>());
         const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
             grid2[0] ^ signs1, signs1, std::minus<>());
-        sumi = dpct::dp4a(grid_l, *((int *)q8 + 0), sumi);
-        sumi = dpct::dp4a(grid_h, *((int *)q8 + 1), sumi);
+        sumi = dpct::dp4a(grid_l, *((const int *)q8 + 0), sumi);
+        sumi = dpct::dp4a(grid_h, *((const int *)q8 + 1), sumi);
         q8 += 8;
     }
     const float d =