]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
supprt Flash Attention for fp32/fp16/Q4/Q5/Q8 (llama/20190)
authorNeo Zhang <redacted>
Sun, 8 Mar 2026 04:00:07 +0000 (12:00 +0800)
committerGeorgi Gerganov <redacted>
Mon, 16 Mar 2026 11:10:15 +0000 (13:10 +0200)
* support flash-attention for fp32/fp16/Q4/Q5/Q8

* rm warining

* update for JIT

62 files changed:
ggml/src/ggml-sycl/CMakeLists.txt
ggml/src/ggml-sycl/backend.hpp
ggml/src/ggml-sycl/common.hpp
ggml/src/ggml-sycl/convert.cpp
ggml/src/ggml-sycl/convert.hpp
ggml/src/ggml-sycl/count-equal.cpp
ggml/src/ggml-sycl/dpct/helper.hpp
ggml/src/ggml-sycl/fattn-common.hpp [new file with mode: 0644]
ggml/src/ggml-sycl/fattn-tile.cpp [new file with mode: 0644]
ggml/src/ggml-sycl/fattn-tile.hpp [new file with mode: 0644]
ggml/src/ggml-sycl/fattn-vec.hpp [new file with mode: 0644]
ggml/src/ggml-sycl/fattn.cpp [new file with mode: 0644]
ggml/src/ggml-sycl/fattn.hpp [new file with mode: 0644]
ggml/src/ggml-sycl/ggml-sycl.cpp
ggml/src/ggml-sycl/presets.hpp
ggml/src/ggml-sycl/softmax.cpp
ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp [new file with mode: 0644]
ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp [new file with mode: 0644]
ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp [new file with mode: 0644]
ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp [new file with mode: 0644]
ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp [new file with mode: 0644]
ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp [new file with mode: 0644]
ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp [new file with mode: 0644]
ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp [new file with mode: 0644]
ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp [new file with mode: 0644]
ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp [new file with mode: 0644]
ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp [new file with mode: 0644]
ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp [new file with mode: 0644]
ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp [new file with mode: 0644]
ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp [new file with mode: 0644]
ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp [new file with mode: 0644]
ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp [new file with mode: 0644]
ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp [new file with mode: 0644]
ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp [new file with mode: 0644]
ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp [new file with mode: 0644]
ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp [new file with mode: 0644]
ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp [new file with mode: 0644]
ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp [new file with mode: 0644]
ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp [new file with mode: 0644]
ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp [new file with mode: 0644]
ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp [new file with mode: 0644]
ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp [new file with mode: 0644]
ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp [new file with mode: 0644]
ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp [new file with mode: 0644]
ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp [new file with mode: 0644]
ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp [new file with mode: 0644]
ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp [new file with mode: 0644]
ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp [new file with mode: 0644]
ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp [new file with mode: 0644]
ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp [new file with mode: 0644]
ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp [new file with mode: 0644]
ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp [new file with mode: 0644]
ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp [new file with mode: 0644]
ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp [new file with mode: 0644]
ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp [new file with mode: 0644]
ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp [new file with mode: 0644]
ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp [new file with mode: 0644]
ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp [new file with mode: 0644]
ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp [new file with mode: 0644]
ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp [new file with mode: 0644]
ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp [new file with mode: 0644]
ggml/src/ggml-sycl/vecdotq.hpp

index eefdd9725ca8f977b57a8e27eceede4e65ffadda..7b07b227874fb01f17769e40ced9db7bc47072ac 100644 (file)
@@ -25,6 +25,11 @@ ggml_add_backend_library(ggml-sycl
 
 file(GLOB   GGML_HEADERS_SYCL "*.hpp")
 file(GLOB   GGML_SOURCES_SYCL "*.cpp")
+file(GLOB   SRCS "template-instances/fattn-tile*.cpp")
+list(APPEND GGML_SOURCES_SYCL ${SRCS})
+file(GLOB   SRCS "template-instances/fattn-vec*.cpp")
+list(APPEND GGML_SOURCES_SYCL ${SRCS})
+
 target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL} ${GGML_SOURCES_SYCL})
 
 if (WIN32)
@@ -145,6 +150,7 @@ else()
 endif()
 
 if (GGML_SYCL_GRAPH)
+    message(STATUS "find GGML_SYCL_GRAPH")
     target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_GRAPH)
 endif()
 
index 75657f3fca2e7e2e7eec40c6fc173bb48a474fbd..b30b7f2beb74f177e54bb115ef81565e91720e47 100644 (file)
@@ -23,6 +23,7 @@
 #include "dequantize.hpp"
 #include "dmmv.hpp"
 #include "element_wise.hpp"
+#include "fattn.hpp"
 #include "gla.hpp"
 #include "im2col.hpp"
 #include "mmq.hpp"
index 04c9e1d786452f68c33ebfd8554c41114a997bf7..298fddc103875eab1029ed1e7cbd7d55d56a054b 100644 (file)
 #include <string>
 
 #include "dpct/helper.hpp"
+#include "ggml.h"
+#include "ggml-impl.h"
 #include "ggml-sycl.h"
 #include "presets.hpp"
 #include "sycl_hw.hpp"
 
+namespace syclexp = sycl::ext::oneapi::experimental;
 
 #if GGML_SYCL_DNNL
 #include "dnnl.hpp"
@@ -31,6 +34,9 @@
 
 #define GGML_COMMON_DECL_SYCL
 #define GGML_COMMON_IMPL_SYCL
+#define SYCL_FLASH_ATTN //remove it to disable FLASH_ATTENTION in building.
+#define SYCL_FAST_FP16  //don't change. remove it will break fattn-tile.hpp building
+
 /* suppress warning spam */
 #pragma clang diagnostic push
 #pragma clang diagnostic ignored "-Wnested-anon-types"
@@ -45,6 +51,8 @@ void ggml_sycl_host_free(void* ptr);
 extern int g_ggml_sycl_debug;
 extern int g_ggml_sycl_disable_optimize;
 extern int g_ggml_sycl_prioritize_dmmv;
+extern int g_ggml_sycl_enable_flash_attention;
+
 
 #if defined(__clang__) && __has_builtin(__builtin_expect)
 // Hint the optimizer to pipeline the more likely following instruction in branches
@@ -170,6 +178,10 @@ static size_t g_scratch_offset = 0;
 
 int get_current_device_id();
 
+inline int ggml_sycl_get_device() {
+    return get_current_device_id();
+}
+
 inline dpct::err0 ggml_sycl_set_device(const int device) try {
   int current_device_id;
   SYCL_CHECK(CHECK_TRY_ERROR(current_device_id = get_current_device_id()));
@@ -194,11 +206,14 @@ struct optimize_feature {
 };
 
 struct sycl_device_info {
-    int     cc;                 // compute capability
+    int cc;  // compute capability
     int nsm; // number of streaming multiprocessors (CUDA) maps to the maximum
              // number of compute units on a SYCL device.
     // size_t  smpb;               // max. shared memory per block
     size_t  smpbo;              // max. shared memory per block (with opt-in)
+    int warp_size;     // max sub_group_size of SYCL
+    int max_wg_per_cu; // max work groups per compute unit - refer to
+                       // cudaOccupancyMaxActiveBlocksPerMultiprocessor
     bool    vmm;                // virtual memory support
     size_t  total_vram;
     //sycl_hw_info hw_info;     \\ device id and aarch, currently not used
@@ -435,13 +450,15 @@ warp_reduce_sum(sycl::float2 a, const sycl::nd_item<3>& item_ct1) {
     return a;
 }
 
-template <int width = WARP_SIZE>
+/* use WARP_SIZE or WARP_32_SIZE*/
+template <int width>
 static __dpct_inline__ int warp_reduce_sum(int x) {
   return sycl::reduce_over_group(
       sycl::ext::oneapi::this_work_item::get_sub_group(), x, sycl::plus<>());
 }
 
-template <int width = WARP_SIZE>
+/* use WARP_SIZE or WARP_32_SIZE*/
+template <int width>
 static __dpct_inline__ float warp_reduce_sum(float x) {
 #pragma unroll
   for (int offset = width / 2; offset > 0; offset >>= 1) {
@@ -451,7 +468,19 @@ static __dpct_inline__ float warp_reduce_sum(float x) {
   return x;
 }
 
-template <int width = WARP_SIZE>
+/* use WARP_SIZE or WARP_32_SIZE*/
+template <int width>
+static __dpct_inline__ float warp_reduce_sum(float x, const sycl::nd_item<3>& item_ct1) {
+#pragma unroll
+  for (int offset = width / 2; offset > 0; offset >>= 1) {
+    x += dpct::permute_sub_group_by_xor(
+        item_ct1.get_sub_group(), x, offset);
+  }
+  return x;
+}
+
+/* use WARP_SIZE or WARP_32_SIZE*/
+template <int width>
 static __dpct_inline__ sycl::float2 warp_reduce_sum(sycl::float2 a) {
 #pragma unroll
   for (int offset = width / 2; offset > 0; offset >>= 1) {
@@ -465,7 +494,8 @@ static __dpct_inline__ sycl::float2 warp_reduce_sum(sycl::float2 a) {
   return a;
 }
 
-template <int width = WARP_SIZE>
+/* use WARP_SIZE or WARP_32_SIZE*/
+template <int width>
 static __dpct_inline__ sycl::half2 warp_reduce_sum(sycl::half2 a) {
 #pragma unroll
   for (int offset = width / 2; offset > 0; offset >>= 1) {
@@ -481,7 +511,52 @@ static constexpr int ggml_sycl_get_physical_warp_size() {
   return WARP_SIZE;
 }
 
-template <int width = WARP_SIZE>
+/* use WARP_SIZE or WARP_32_SIZE*/
+template <int width>
+static __dpct_inline__ int warp_reduce_all(int x) {
+    if (width == ggml_sycl_get_physical_warp_size()) {
+        return sycl::all_of_group(
+            sycl::ext::oneapi::this_work_item::get_sub_group(),
+            (~0xffffffff &
+             (0x1 << sycl::ext::oneapi::this_work_item::get_sub_group()
+                         .get_local_linear_id())) ||
+                x);
+    } else {
+#pragma unroll
+        for (int offset = width / 2; offset > 0; offset >>= 1) {
+            x = dpct::permute_sub_group_by_xor(
+                    sycl::ext::oneapi::this_work_item::get_sub_group(), x,
+                    offset, width) &&
+                x;
+        }
+        return x;
+    }
+}
+
+/* use WARP_SIZE or WARP_32_SIZE*/
+template <int width>
+static __dpct_inline__ int warp_reduce_any(int x) {
+    if (width == ggml_sycl_get_physical_warp_size()) {
+        return sycl::any_of_group(
+            sycl::ext::oneapi::this_work_item::get_sub_group(),
+            (0xffffffff &
+             (0x1 << sycl::ext::oneapi::this_work_item::get_sub_group()
+                         .get_local_linear_id())) &&
+                x);
+    } else {
+#pragma unroll
+        for (int offset = width / 2; offset > 0; offset >>= 1) {
+            x = dpct::permute_sub_group_by_xor(
+                    sycl::ext::oneapi::this_work_item::get_sub_group(), x,
+                    offset, width) ||
+                x;
+        }
+        return x;
+    }
+}
+
+/* use WARP_SIZE or WARP_32_SIZE*/
+template <int width>
 static __dpct_inline__ float warp_reduce_max(float x) {
 #pragma unroll
   for (int offset = width / 2; offset > 0; offset >>= 1) {
@@ -629,6 +704,42 @@ static const sycl::uint3 init_fastdiv_values(uint32_t d) {
     return sycl::uint3(mp, L, d);
 }
 
+// Maximum number of bytes that can be copied in a single instruction.
+// Set by test result.
+static constexpr int ggml_sycl_get_max_cpy_bytes() {
+    return 16;
+}
+
+// Aligned memory transfers of 8/16 bytes can be faster than 2 transfers with 4 bytes.
+template <int nbytes, int alignment = 0>
+static __dpct_inline__ void ggml_sycl_memcpy_1(void * dst, const void * src) {
+    if constexpr (alignment != 0) {
+        static_assert(nbytes % alignment == 0, "bad alignment");
+    }
+    constexpr int nb_per_cpy = alignment == 0 ? nbytes : alignment;
+
+#pragma unroll
+    for (int i = 0; i < nbytes/nb_per_cpy; ++i) {
+        if constexpr (nb_per_cpy == 1) {
+            ((char *) dst)[i] = ((const char *) src)[i];
+        } else if constexpr (nb_per_cpy == 2) {
+            ((short *) dst)[i] = ((const short *) src)[i];
+        } else if constexpr (nb_per_cpy == 4) {
+            ((int *) dst)[i] = ((const int *) src)[i];
+        } else if constexpr (nb_per_cpy == 8) {
+            ((sycl::int2 *) dst)[i] = ((const sycl::int2 *) src)[i];
+        } else if constexpr (nb_per_cpy == 16) {
+            ((sycl::int4 *) dst)[i] = ((const sycl::int4 *) src)[i];
+        } else {
+            static_assert(nbytes == 0 && nbytes == -1, "bad nbytes");
+        }
+    }
+}
+template <typename T>
+sycl::half2 __dpct_inline__ make_half2( T x, T y) {
+    sycl::half2 res(static_cast<sycl::half>(x),static_cast<sycl::half>(y));
+    return res;
+}
 
 static __dpct_inline__ uint32_t fastdiv(uint32_t n, const sycl::uint3 fastdiv_values) {
     const uint32_t hi = sycl::mul_hi<unsigned>(n, fastdiv_values.x());
@@ -636,6 +747,17 @@ static __dpct_inline__ uint32_t fastdiv(uint32_t n, const sycl::uint3 fastdiv_va
 }
 
 
+template <typename T>
+sycl::float2 __dpct_inline__ make_float2( T x, T y) {
+    sycl::float2 res(static_cast<float>(x),static_cast<float>(y));
+    return res;
+}
+
+sycl::float2 __dpct_inline__ __half22float2(sycl::half2 &H) {
+    sycl::float2 float2_value(static_cast<float>(H.x()), static_cast<float>(H.y()));
+    return float2_value;
+}
+
 static __dpct_inline__ sycl::uint2 fast_div_modulo(uint32_t n, const sycl::uint3 fastdiv_values) {
     const uint32_t div_val = fastdiv(n, fastdiv_values);
     const uint32_t mod_val = n - div_val * fastdiv_values.z();
@@ -659,5 +781,97 @@ static __dpct_inline__ float ggml_sycl_e8m0_to_fp32(uint8_t x) {
     return result;
 }
 
+sycl::float2 __dpct_inline__ __half22float2(const sycl::half2 &H) {
+    sycl::float2 float2_value(static_cast<float>(H.x()), static_cast<float>(H.y()));
+    return float2_value;
+}
+
+float __dpct_inline__ __half2float(sycl::half H) {
+    return static_cast<float>(H);
+}
+
+static __dpct_inline__ void ggml_sycl_mad(float & acc, const float v, const float u) {
+    acc += v*u;
+}
+
+static __dpct_inline__ void ggml_sycl_mad(float & acc, const sycl::float2 v, const sycl::float2 u) {
+    acc += v.x() * u.x();
+    acc += v.y() * u.y();
+}
+
+static __dpct_inline__ void ggml_sycl_mad(float & acc, const sycl::half2 v, const sycl::half2 u) {
+#ifdef GGML_SYCL_F16
+    const sycl::float2 tmp = (v * u).template convert<float, sycl::rounding_mode::automatic>();
+    acc += tmp.x() + tmp.y();
+#else
+    const sycl::float2 tmpv = __half22float2(v);
+    const sycl::float2 tmpu = __half22float2(u);
+    acc += tmpv.x() * tmpu.x();
+    acc += tmpv.y() * tmpu.y();
+#endif // GGML_SYCL_F16
+}
+
+static __dpct_inline__ void ggml_sycl_mad(sycl::half2 & acc, const sycl::half2 v, const sycl::half2 u) {
+#ifdef GGML_SYCL_F16
+    acc += v*u;
+#else
+    const sycl::float2 tmpv = __half22float2(v);
+    const sycl::float2 tmpu = __half22float2(u);
+    sycl::float2 tmpacc = __half22float2(acc);
+    // tmpacc.x += tmpv.x() * tmpu.x();
+    // tmpacc.y += tmpv.y() * tmpu.y();
+    sycl::float2 tmp1(tmpacc.x() + tmpv.x() * tmpu.x(), tmpacc.y() + tmpv.y() * tmpu.y());
+    acc = make_half2(tmp1.x(), tmp1.y());
+#endif // GGML_SYCL_F16
+}
+
+template <int n>
+struct ggml_sycl_unroll {
+    template <typename Func, typename... Args>
+    void operator()(const Func & f, Args... args) const {
+        f(n - 1, args...);
+        ggml_sycl_unroll<n - 1>{}(f, args...);
+    }
+};
+
+template <>
+struct ggml_sycl_unroll<1> {
+    template <typename Func, typename... Args>
+    void operator()(const Func & f, Args... args) const {
+        f(0, args...);
+    }
+};
+
+static __dpct_inline__ sycl::half2 ggml_sycl_hmax2(const sycl::half2 a, const sycl::half2 b) {
+    sycl::half2 ret;
+    reinterpret_cast<sycl::half &>(ret.x()) =
+        sycl::vec<float, 1>(sycl::fmax(a[0], b[0])).convert<sycl::half, sycl::rounding_mode::automatic>()[0];
+    reinterpret_cast<sycl::half &>(ret.y()) =
+        sycl::vec<float, 1>(sycl::fmax(a[1], b[1])).convert<sycl::half, sycl::rounding_mode::automatic>()[0];
+    return ret;
+}
+
+static __dpct_inline__ sycl::half ggml_sycl_hmax(const sycl::half a, const sycl::half b) {
+    return sycl::vec<float, 1>(
+               sycl::fmax(sycl::vec<sycl::half, 1>(a).convert<float, sycl::rounding_mode::automatic>()[0],
+                          sycl::vec<sycl::half, 1>(b).convert<float, sycl::rounding_mode::automatic>()[0]))
+        .convert<sycl::half, sycl::rounding_mode::automatic>()[0];
+}
+
+static __dpct_inline__ uint32_t __hgt2_mask(const sycl::half2 a, const sycl::half2 b) {
+    const uint32_t mask_low  = 0x0000FFFF * (float(a[0]) > float(b[0]));
+    const uint32_t mask_high = 0xFFFF0000 * (float(a[1]) > float(b[1]));
+    return mask_low | mask_high;
+}
+
+static __dpct_inline__ uint32_t fastmodulo(uint32_t n, const sycl::uint3 fastdiv_values) {
+    // expects  fastdiv_values to contain <mp, L, divisor> in <x, y, z> (see init_fastdiv_values)
+    return n - fastdiv(n, fastdiv_values) * fastdiv_values.z();
+}
+
+static bool fast_fp16_available(const int cc) {
+    GGML_UNUSED(cc);
+    return true;   //Intel GPUs always support FP16.
+}
 
 #endif // GGML_SYCL_COMMON_HPP
index 8bdae36458ce79ecb7a3edbb6755d1d44c315367..d17aca2cac4eac239965805f802879fbf5b857b9 100644 (file)
@@ -482,6 +482,63 @@ static void dequantize_row_mxfp4_sycl(const void * vx, dst_t * y, const int64_t
         });
 }
 
+template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
+static void dequantize_block_nc(const void * __restrict__ vx, dst_t * __restrict__ y,
+        const int64_t ne00, const int64_t ne01, const int64_t ne02,
+        const int64_t s01, const int64_t s02, const int64_t s03) {
+    auto          item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+    const int64_t i00 = 2 * (int64_t(item_ct1.get_local_range(2)) * item_ct1.get_group(2) + item_ct1.get_local_id(2));
+
+    if (i00 >= ne00) {
+        return;
+    }
+
+    const int64_t i01 = item_ct1.get_group(1);
+    const int64_t i02 = item_ct1.get_group(0) % ne02;
+    const int64_t i03 = item_ct1.get_group(0) / ne02;
+
+    const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01;
+
+    const int64_t ib = ibx0 + i00/qk; // block index
+    const int64_t iqs = (i00%qk)/qr; // quant index
+    const int64_t iybs = i00 - i00%qk; // y block start index
+    const int64_t y_offset = qr == 1 ? 1 : qk/2;
+
+    // dequantize
+    #ifdef GGML_SYCL_F16
+        sycl::half2 v;
+    #else
+        sycl::float2 v;
+    #endif
+
+    dequantize_kernel(vx, ib, iqs, v);
+
+    const int64_t iy0 = ((i03*ne02 + i02)*ne01 + i01)*ne00 + iybs + iqs;
+    y[iy0 + 0]        = ggml_sycl_cast<dst_t>(v.x());
+    y[iy0 + y_offset] = ggml_sycl_cast<dst_t>(v.y());
+}
+
+
+template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
+static void dequantize_block_nc_sycl(const void *    vx,
+                                  dst_t *         y,
+                                  const int64_t   ne00,
+                                  const int64_t   ne01,
+                                  const int64_t   ne02,
+                                  const int64_t   ne03,
+                                  const int64_t   s01,
+                                  const int64_t   s02,
+                                  const int64_t   s03,
+                                  dpct::queue_ptr stream) {
+    const dpct::dim3 num_blocks((ne00 + 2 * SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / (2 * SYCL_DEQUANTIZE_BLOCK_SIZE), ne01,
+                                ne02 * ne03);
+    stream->parallel_for(sycl::nd_range<3>(num_blocks * sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE),
+                                           sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE)),
+                         [=](sycl::nd_item<3> item_ct1) {
+                             GGML_UNUSED(item_ct1);
+                             dequantize_block_nc<qk, qr, dequantize_kernel>(vx, y, ne00, ne01, ne02, s01, s02, s03);
+                         });
+}
 template <typename src_t, typename dst_t>
 static void convert_unary_nc(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01,
                           const int64_t ne02, const int64_t s01, const int64_t s02, const int64_t s03,
@@ -662,7 +719,8 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) {
     }
 }
 
-to_fp16_nc_sycl_t get_to_fp16_nc_sycl(ggml_type type) {
+
+to_fp16_nc_sycl_t ggml_get_to_fp16_nc_sycl(ggml_type type) {
     switch (type) {
         case GGML_TYPE_F32:
             return convert_unary_nc_sycl<float>;
@@ -670,6 +728,16 @@ to_fp16_nc_sycl_t get_to_fp16_nc_sycl(ggml_type type) {
         case GGML_TYPE_BF16:
             return convert_unary_nc_sycl<sycl::ext::oneapi::bfloat16>;
 #endif
+        case GGML_TYPE_Q4_0:
+            return dequantize_block_nc_sycl<QK4_0, QR4_0, dequantize_q4_0>;
+        case GGML_TYPE_Q4_1:
+            return dequantize_block_nc_sycl<QK4_1, QR4_1, dequantize_q4_1>;
+        case GGML_TYPE_Q5_0:
+            return dequantize_block_nc_sycl<QK5_0, QR5_0, dequantize_q5_0>;
+        case GGML_TYPE_Q5_1:
+            return dequantize_block_nc_sycl<QK5_1, QR5_1, dequantize_q5_1>;
+        case GGML_TYPE_Q8_0:
+            return dequantize_block_nc_sycl<QK8_0, QR8_0, dequantize_q8_0>;
         default:
             return nullptr;
     }
index f8cb573e3688bc470aecac9f5bbddf232a1c028b..f93bd0df7d7e40240814f61456a256df2288524c 100644 (file)
@@ -29,6 +29,21 @@ using to_t_nc_sycl_t = void (*)(const void * x, T * y, int64_t ne00, int64_t ne0
                                    int64_t s01, int64_t s02, int64_t s03, dpct::queue_ptr queue);
 
 typedef to_t_nc_sycl_t<sycl::half> to_fp16_nc_sycl_t;
-to_fp16_nc_sycl_t get_to_fp16_nc_sycl(ggml_type type);
+to_fp16_nc_sycl_t ggml_get_to_fp16_nc_sycl(ggml_type type);
+
+template<typename dst_t, typename src_t>
+ inline dst_t ggml_sycl_cast(src_t x) {
+    if constexpr (std::is_same_v<dst_t, src_t>) {
+        return x;
+    } else if constexpr (std::is_same_v<dst_t, sycl::ext::oneapi::bfloat16>) {
+        return sycl::ext::oneapi::bfloat16(float(x));
+    } else if constexpr (std::is_same_v<src_t, sycl::ext::oneapi::bfloat16>) {
+        return static_cast<float>(x);
+    } else if constexpr(std::is_same_v<dst_t, int32_t>) {
+        return int32_t(x);
+    } else {
+        return float(x);
+    }
+}
 
 #endif  // GGML_SYCL_CONVERT_HPP
index b0a8b4820de22058707f6ea215b18263399e648d..4580354cd9d00b05671d632ae79f6ceacd96b945 100644 (file)
@@ -18,7 +18,7 @@ static void count_equal(const T *__restrict__ x, const T *__restrict__ y,
         nequal += xi == yi;
     }
 
-    nequal = warp_reduce_sum(nequal);
+    nequal = warp_reduce_sum<WARP_SIZE>(nequal);
 
     if (item_ct1.get_local_id(2) != 0) {
         return;
index ece66a7ac1f438ee7a58f75c0953447a0271a6f5..791d3cac52e11c311f29e91169c9b0f68ccc2699 100644 (file)
@@ -2997,6 +2997,778 @@ namespace dpct
       return 0;
     }
 
+    template <int n_nondefault_params, int n_default_params, typename T>
+    class args_selector;
+
+    /// args_selector is a helper class for extracting arguments from an
+    /// array of pointers to arguments or buffer of arguments to pass to a
+    /// kernel function.
+    ///
+    /// \param R(Ts...) The type of the kernel
+    /// \param n_nondefault_params The number of nondefault parameters of the
+    /// kernel (excluding parameters that like sycl::nd_item, etc.) \param
+    /// n_default_params The number of default parameters of the kernel
+    ///
+    /// Example usage:
+    /// With the following kernel:
+    ///   void foo(sycl::float2 *x, int n, sycl::nd_item<3> item_ct1, float
+    ///   f=.1) {}
+    /// and with the declaration:
+    ///   args_selector<2, 1, decltype(foo)> selector(kernelParams, extra);
+    /// we have:
+    ///   selector.get<0>() returns a reference to sycl::float*,
+    ///   selector.get<1>() returns a reference to int,
+    ///   selector.get<2>() returns a reference to float
+    template <int n_nondefault_params, int n_default_params, typename R,
+              typename... Ts>
+    class args_selector<n_nondefault_params, n_default_params, R(Ts...)> {
+      private:
+        void **kernel_params;
+        char *args_buffer;
+
+        template <int i> static constexpr int account_for_default_params() {
+            constexpr int n_total_params = sizeof...(Ts);
+            if constexpr (i >= n_nondefault_params) {
+                return n_total_params - n_default_params +
+                       (i - n_nondefault_params);
+            } else {
+                return i;
+            }
+        }
+
+      public:
+        /// Get the type of the ith argument of R(Ts...)
+        /// \param [in] i Index of parameter to get
+        /// \returns Type of ith parameter
+        template <int i>
+        using arg_type = std::tuple_element_t<account_for_default_params<i>(),
+                                              std::tuple<Ts...>>;
+        static constexpr int params_num = sizeof...(Ts);
+
+      private:
+        template <int i> static constexpr int get_offset() {
+            if constexpr (i == 0) {
+                // we can assume args_buffer is properly aligned to the
+                // first argument
+                return 0;
+            } else {
+                constexpr int prev_off = get_offset<i - 1>();
+                constexpr int prev_past_end =
+                    prev_off + sizeof(arg_type<i - 1>);
+                using T = arg_type<i>;
+                // is the past-the-end of the i-1st element properly aligned
+                // with the ith element's alignment?
+                if constexpr (prev_past_end % alignof(T) == 0) {
+                    return prev_past_end;
+                }
+                // otherwise bump prev_past_end to match alignment
+                else {
+                    return prev_past_end +
+                           (alignof(T) - (prev_past_end % alignof(T)));
+                }
+            }
+        }
+
+        static char *get_args_buffer(void **extra) {
+            if (!extra)
+                return nullptr;
+            for (; (std::size_t)*extra != 0; ++extra) {
+                if ((std::size_t)*extra == 1) {
+                    return static_cast<char *>(*(extra + 1));
+                }
+            }
+            return nullptr;
+        }
+
+      public:
+        /// If kernel_params is nonnull, then args_selector will
+        /// extract arguments from kernel_params. Otherwise, it
+        /// will extract them from extra.
+        /// \param [in] kernel_params Array of pointers to arguments
+        /// a or null pointer.
+        /// \param [in] extra Array containing pointer to argument buffer.
+        args_selector(void **kernel_params, void **extra)
+            : kernel_params(kernel_params),
+              args_buffer(get_args_buffer(extra)) {}
+
+        /// Get a reference to the ith argument extracted from kernel_params
+        /// or extra.
+        /// \param [in] i Index of argument to get
+        /// \returns Reference to the ith argument
+        template <int i> arg_type<i> &get() {
+            if (kernel_params) {
+                return *static_cast<arg_type<i> *>(kernel_params[i]);
+            } else {
+                return *reinterpret_cast<arg_type<i> *>(args_buffer +
+                                                        get_offset<i>());
+            }
+        }
+    }; // COPY from DPCT head file
+       // /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/util.hpp
+
+    /// Utility class for launching SYCL kernels through kernel
+    /// function wrapper.
+    /// For example:
+    /// A SYCL kernel function:
+    ///   void kernel_func(int *ptr, sycl::nd_item<3> item);
+    /// Kernel function wrapper:
+    ///   void kernel_func_wrapper(int *ptr) {
+    ///     sycl::queue queue = *dpct::kernel_launcher::_que;
+    ///     unsigned int localMemSize = dpct::kernel_launcher::_local_mem_size;
+    ///     sycl::nd_range<3> nr = dpct::kernel_launcher::_nr;
+    ///     queue.parallel_for(
+    ///       nr,
+    ///       [=](sycl::nd_item<3> item_ct1) {
+    ///         kernel_func(ptr, item_ct1);
+    ///       });
+    ///   }
+    /// Then launch the kernel through wrapper like:
+    ///   typedef void(*fpt)(int *);
+    ///   fpt fp = kernel_func_wrapper;
+    ///   dpct::kernel_launcher::launch(fp, dpct::dim3(1), dpct::dim3(1), 0, 0,
+    ///   device_ptr);
+    /// If the origin function type is erased, then need to register it first:
+    ///   void *fp = (void *)wrapper_register(&kernel_func_wrapper).get();
+    ///   dpct::kernel_launcher::launch(fp, dpct::dim3(1), dpct::dim3(1), args,
+    ///   0, 0);
+    class kernel_launcher {
+        template <typename FuncT, typename ArgSelector, std::size_t... Index>
+        static void launch_helper(FuncT &&func, ArgSelector &selector,
+                                  std::index_sequence<Index...>) {
+            func(selector.template get<Index>()...);
+        }
+        static void set_execution_config(dim3 group_range, dim3 local_range,
+                                         unsigned int local_mem_size,
+                                         queue_ptr que) {
+            if (que) {
+                _que = que;
+            } else {
+                _que = &get_default_queue();
+            }
+            _nr = sycl::nd_range<3>(
+                static_cast<sycl::range<3>>(group_range * local_range),
+                static_cast<sycl::range<3>>(local_range));
+            _local_mem_size = local_mem_size;
+
+
+        };
+        static inline std::mutex kernel_function_ptr_map_mutex;
+
+      public:
+        /// Variables for storing execution configuration.
+        static inline thread_local sycl::queue *_que = nullptr;
+        static inline thread_local sycl::nd_range<3> _nr = sycl::nd_range<3>();
+        static inline thread_local unsigned int _local_mem_size = 0;
+        /// Map for retrieving launchable functor from a raw pointer.
+        static inline std::map<
+            const void *,
+            std::function<void(dim3, dim3, void **, unsigned int, queue_ptr)>>
+            kernel_function_ptr_map = {};
+
+        /// Registers a kernel function pointer with a corresponding launchable
+        /// functor.
+        /// \param [in] func Pointer to the kernel function.
+        /// \param [in] launcher Functor to handle kernel invocation.
+        static void register_kernel_ptr(
+            const void *func,
+            std::function<void(dim3, dim3, void **, unsigned int, queue_ptr)>
+                launcher) {
+            std::lock_guard<std::mutex> lock(kernel_function_ptr_map_mutex);
+            kernel_function_ptr_map[func] = std::move(launcher);
+        }
+        /// Launches a kernel function with arguments provided directly through
+        /// kernel function wrapper.
+        /// \tparam FuncT Type of the kernel function wrapper.
+        /// \tparam ArgsT Types of kernel arguments.
+        /// \param [in] func Pointer to the kernel function wrapper.
+        /// \param [in] group_range SYCL group range.
+        /// \param [in] local_range SYCL local range.
+        /// \param [in] local_mem_size The size of local memory required by the
+        /// kernel function. \param [in] que SYCL queue used to execute kernel.
+        /// \param [in] args Kernel arguments.
+        template <typename FuncT, typename... ArgsT>
+        static std::enable_if_t<std::is_invocable_v<FuncT *, ArgsT...>, void>
+        launch(FuncT *func, dim3 group_range, dim3 local_range,
+               unsigned int local_mem_size, queue_ptr que, ArgsT... args) {
+            set_execution_config(group_range, local_range, local_mem_size, que);
+            func(args...);
+        }
+        /// Launches a kernel function through registered kernel function
+        /// wrapper. \param [in] func Pointer to the registered kernel function
+        /// wrapper. \param [in] group_range SYCL group range. \param [in]
+        /// local_range SYCL local range. \param [in] args Array of pointers to
+        /// kernel arguments. \param [in] local_mem_size The size of local
+        /// memory required by the kernel function. \param [in] que SYCL queue
+        /// used to execute kernel.
+        static void launch(const void *func, dim3 group_range, dim3 local_range,
+                           void **args, unsigned int local_mem_size,
+                           queue_ptr que) {
+            std::lock_guard<std::mutex> lock(kernel_function_ptr_map_mutex);
+            auto Iter = kernel_function_ptr_map.find(func);
+            if (Iter == kernel_function_ptr_map.end()) {
+                throw std::runtime_error("dpct::launch() : no registered "
+                                         "kernel function wrapper found.");
+            }
+            (Iter->second)(group_range, local_range, args, local_mem_size, que);
+        }
+        /// Launches a kernel function with packed arguments through kernel
+        /// function wrapper.
+        /// \tparam FuncT Type of the kernel function wrapper.
+        /// \param [in] func Pointer to the kernel function wrapper.
+        /// \param [in] group_range SYCL group range.
+        /// \param [in] local_range SYCL local range.
+        /// \param [in] args Array of pointers to kernel arguments.
+        /// \param [in] local_mem_size The size of local memory required by the
+        /// kernel function. \param [in] que SYCL queue used to execute kernel.
+        template <typename FuncT>
+        static std::enable_if_t<std::is_function_v<FuncT>, void>
+        launch(FuncT *func, dim3 group_range, dim3 local_range, void **args,
+               unsigned int local_mem_size, queue_ptr que) {
+            constexpr size_t p_num = args_selector<0, 0, FuncT>::params_num;
+            set_execution_config(group_range, local_range, local_mem_size, que);
+            args_selector<p_num, p_num, FuncT> selector(args, nullptr);
+            launch_helper(func, selector, std::make_index_sequence<p_num>{});
+        }
+    }; // COPY from DPCT head file
+       // /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/kernel.hpp
+
+    // /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/util.hpp
+    template <typename T>
+    T select_from_sub_group(
+        sycl::sub_group g,
+        T x,
+        int remote_local_id,
+        int logical_sub_group_size = 32) {
+      unsigned int start_index = g.get_local_linear_id() /
+                                 logical_sub_group_size *
+                                 logical_sub_group_size;
+      return sycl::select_from_group(
+          g, x, start_index + remote_local_id % logical_sub_group_size);
+    }
+
+    // /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/math.hpp
+    template <typename T>
+    void ldmatrix(uintptr_t addr, T* m, bool trans = false, unsigned mat = 0) {
+      auto sg = sycl::ext::oneapi::this_work_item::get_sub_group();
+      int lane = sg.get_local_linear_id();
+
+      int lane_group8_row = lane / 8;
+      int lane_group8_col = lane % 8;
+
+      if (!trans) {
+        // calculate the source lane
+        int src_lane = 2 * lane_group8_row;
+        if (lane_group8_col >= 4)
+          src_lane += 1;
+
+        // Broadcast the address from the source lane
+        auto recv_addr_uintp =
+            dpct::select_from_sub_group(sg, addr, mat * 8 + src_lane);
+
+        // Cast the received address from uintptr_t to the type of 'm'
+        auto recv_addr = reinterpret_cast<T*>(recv_addr_uintp);
+
+        // Non-transposed load
+        *m = recv_addr[lane_group8_col % 4];
+      } else {
+        // calculate the source lane
+        int src_lane = (lane % 4) * 2;
+
+        // Broadcast the address from the source lane
+        auto recv_addr_uintp_1 =
+            dpct::select_from_sub_group(sg, addr, mat * 8 + src_lane);
+        auto recv_addr_uintp_2 =
+            dpct::select_from_sub_group(sg, addr, mat * 8 + src_lane + 1);
+
+        // Cast the received address from uintptr_t to 'half *'
+        auto recv_addr_1 = reinterpret_cast<sycl::half*>(recv_addr_uintp_1);
+        auto recv_addr_2 = reinterpret_cast<sycl::half*>(recv_addr_uintp_2);
+
+        // Transposed load
+        int index = lane / 4;
+        sycl::half val0 = recv_addr_1[index];
+        sycl::half val1 = recv_addr_2[index];
+
+        // Combine the two 16-bits into one 32-bit value
+        sycl::half2 val = sycl::half2(val0, val1);
+        *m = *reinterpret_cast<T*>(&val);
+      }
+    }
+
+    template <typename T>
+    void ldmatrix(uintptr_t addr, T* m1, T* m2, bool trans = false) {
+      // Load 1st matrix
+      ldmatrix(addr, m1, trans, 0);
+      // Load 2nd matrix
+      ldmatrix(addr, m2, trans, 1);
+    }
+
+    template <typename T>
+    void ldmatrix(
+        uintptr_t addr, T* m1, T* m2, T* m3, T* m4, bool trans = false) {
+      // Load 1st matrix
+      ldmatrix(addr, m1, trans, 0);
+      // Load 2nd matrix
+      ldmatrix(addr, m2, trans, 1);
+      // Load 3rd matrix
+      ldmatrix(addr, m3, trans, 2);
+      // Load 4th matrix
+      ldmatrix(addr, m4, trans, 3);
+    }
+
+    // /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/math.hpp
+
+    /// A helper struct that defines the pack type for the input matrix
+    /// fragments
+    /// of mma() function based on the type of input matrix fragments.
+    /// The MMAType struct is specialized for different types of input matrices.
+    /// Currently, the specialization for f16, bf16 and s8 types is defined
+    /// below. \tparam [in] T The type of the input matrix fragments
+    template <typename T>
+    struct MMAType {
+      using PackType = uint32_t;
+    };
+
+    /// Each work item of a sub-group (limited to size 32) calling this function
+    /// calculates a subset fragment for the output matrix D using MAD operation
+    /// on A, B & C matrix fragments (D = A * B + C). Current supported shapes &
+    /// types:
+    /// - m8n8k4 (f32.f16.f16.f32)
+    /// - m8n8k16 (s32.s8.s8.s32)
+    /// - m16n8k8 (f32.f16.f16.f32 & f32.bf16.bf16.f32)
+    /// - m16n8k16 (f32.f16.f16.f32 & s32.s8.s8.s32)
+    /// - m16n8k32 (s32.s8.s8.s32)
+    /// Here, m, n & k define the shapes of A, B & C matrices respectively
+    /// (A = [m x k], B = [k x n], C = [m x n]).
+    /// \tparam [in] M The rows of A, C & D matrices
+    /// \tparam [in] N The columns of B, C, D matrices
+    /// \tparam [in] K The columns & rows of A & B matrices respectively
+    /// \tparam [in] ABType The type of the input matrix (A & B) fragment
+    /// \tparam [in] CDType The type of the output matrix (C & D) fragment
+    /// \param [out] d_mat_frag The fragment of the output matrix D to store the
+    /// result of A * B + C
+    /// \param [in] a_mat_frag The fragment of the input matrix A to be
+    /// multiplied with B matrix fragment \param [in] b_mat_frag The fragment of
+    /// the input matrix B to be multiplied with A matrix fragment \param [in]
+    /// c_mat_frag The fragment of the input matrix C to be added with the
+    /// result of A * B fragments
+    template <int M, int N, int K, typename ABType, typename CDType>
+    void mma(
+        volatile void** d_mat_frag,
+        void* a_mat_frag,
+        void* b_mat_frag,
+        void* c_mat_frag) {
+      auto d = reinterpret_cast<volatile CDType**>(d_mat_frag);
+      auto a =
+          reinterpret_cast<typename MMAType<ABType>::PackType*>(a_mat_frag);
+      auto b =
+          reinterpret_cast<typename MMAType<ABType>::PackType*>(b_mat_frag);
+      auto c = reinterpret_cast<CDType*>(c_mat_frag);
+
+      auto sg = sycl::ext::oneapi::this_work_item::get_sub_group();
+      int lane = sg.get_local_linear_id();
+
+      static_assert(
+          (M == 8 && N == 8 && K == 4) || (M == 8 && N == 8 && K == 16) ||
+              (M == 16 && N == 8 && K == 8) || (M == 16 && N == 8 && K == 16) ||
+              (M == 16 && N == 8 && K == 32),
+          "Unsupported MMA shape!");
+
+      short row_load_offset = 4 * (lane >> 2);
+      short col_load_offset = 8 * (lane % 4);
+
+      if constexpr (M == 8 && N == 8 && K == 4) {
+        if constexpr (std::is_floating_point_v<CDType>) {
+          col_load_offset = row_load_offset % 16;
+
+          // Init D matrix with fragments of C matrix
+          *d[0] = c[0];
+          *d[1] = c[1];
+          *d[2] = c[2];
+          *d[3] = c[3];
+          *d[4] = c[4];
+          *d[5] = c[5];
+          *d[6] = c[6];
+          *d[7] = c[7];
+
+          // Calculate the row and col offset indices to iterate through the row
+          // & col fragments of A & B matrices
+          int r_ind = (lane % 2) ? 1 : 0;
+          int c_ind = ((lane % 4) / 2) ? 2 : 0;
+
+          // Each sub-group is responsible for computing a fragment size of 8*8
+          // elements of matrix D for each of 4 MMA computations.
+          // Each work item computes 8 elements of matrix D by gathering
+          // their corresponding col & row matrix fragments of length k (4)
+          // from A & B matrices respectively using below mapping logic:
+          // row0 = (i % 4) if (lane < 16) else (i % 4) + 4
+          // col0 = (lane % 4)
+          // As each row & col fragment of A & B matrices is distributed across
+          // 4 work items, each iteration of below loop loads a partial fragment
+          // of matrix A (row) and matrix B (col) using the row & col offsets.
+          typename MMAType<ABType>::PackType recv_a[2], recv_b[2];
+
+          for (int i = 0; i < 4; i++) {
+            // Load partial fragment from col0 of matrix A ({a0, a1})
+            recv_a[0] =
+                dpct::select_from_sub_group(sg, a[0], row_load_offset + i);
+            // Load partial fragment from col0 of matrix A ({a2, a3})
+            recv_a[1] =
+                dpct::select_from_sub_group(sg, a[1], row_load_offset + i);
+
+            // Load partial fragment from row0 of matrix B ({b0, b1})
+            recv_b[0] =
+                dpct::select_from_sub_group(sg, b[0], col_load_offset + i);
+            // Load partial fragment from row0 of matrix B ({b2, b3})
+            recv_b[1] =
+                dpct::select_from_sub_group(sg, b[1], col_load_offset + i);
+
+            auto ra = reinterpret_cast<ABType*>(recv_a);
+            auto rb = reinterpret_cast<ABType*>(recv_b);
+
+            // Each work item calculates a partial product of A & B matrix
+            // fragments and adds it to the corresponding D matrix fragment (for
+            // even work item indices) d0 += col0{ a0 } * row0{ b0 } d1 += col0{
+            // a0 } * row0{ b1 } d2 += col1{ a2 } * row0{ b0 } d3 += col1{ a2 }
+            // * row0{ b1 } (for odd work item indices) d0 += col0{ a1 } * row0{
+            // b2 } d1 += col0{ a1 } * row0{ b3 } d2 += col1{ a3 } * row0{ b2 }
+            // d3 += col1{ a3 } * row0{ b3 }
+            *d[0] +=
+                static_cast<float>(ra[r_ind]) * static_cast<float>(rb[c_ind]);
+            *d[1] += static_cast<float>(ra[r_ind]) *
+                     static_cast<float>(rb[c_ind + 1]);
+            *d[2] += static_cast<float>(ra[r_ind + 2]) *
+                     static_cast<float>(rb[c_ind]);
+            *d[3] += static_cast<float>(ra[r_ind + 2]) *
+                     static_cast<float>(rb[c_ind + 1]);
+
+            // Load partial fragment from row1 of matrix B ({b0, b1})
+            recv_b[0] =
+                dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 16);
+            // Load partial fragment from row1 of matrix B ({b2, b3})
+            recv_b[1] =
+                dpct::select_from_sub_group(sg, b[1], col_load_offset + i + 16);
+
+            // (for even work item indices)
+            // d0 += col0{ a0 } * row1{ b0 }
+            // d1 += col0{ a0 } * row1{ b1 }
+            // d2 += col1{ a2 } * row1{ b0 }
+            // d3 += col1{ a2 } * row1{ b1 }
+            // (for odd work item indices)
+            // d0 += col0{ a1 } * row1{ b2 }
+            // d1 += col0{ a1 } * row1{ b3 }
+            // d2 += col1{ a3 } * row1{ b2 }
+            // d3 += col1{ a3 } * row1{ b3 }
+            *d[4] +=
+                static_cast<float>(ra[r_ind]) * static_cast<float>(rb[c_ind]);
+            *d[5] += static_cast<float>(ra[r_ind]) *
+                     static_cast<float>(rb[c_ind + 1]);
+            *d[6] += static_cast<float>(ra[r_ind + 2]) *
+                     static_cast<float>(rb[c_ind]);
+            *d[7] += static_cast<float>(ra[r_ind + 2]) *
+                     static_cast<float>(rb[c_ind + 1]);
+          }
+        }
+      } else if constexpr (M == 8 && N == 8 && K == 16) {
+        if constexpr (std::is_integral_v<ABType>) {
+          // Init D matrix with fragments of C matrix
+          *d[0] = c[0];
+          *d[1] = c[1];
+
+          // Each sub-group is responsible for computing a fragment size of 16*8
+          // elements of matrix D.
+          // Each work item computes 2 elements of matrix D by gathering
+          // their corresponding row & col matrix fragments of length k (16)
+          // from A & B matrices respectively using below mapping logic:
+          // row0 = ((lane % 4) * 4) + i
+          // col0 = (lane >> 2)
+          // As each row & col fragment of A & B matrices is distributed across
+          // 4 work items, each iteration of below loop loads a partial fragment
+          // of matrix A (row) and matrix B (col) using the row & col offsets.
+          for (int i = 0; i < 4; i++) {
+            typename MMAType<ABType>::PackType recv_a, recv_b[2];
+
+            // Load partial fragment from row0 of matrix A ({a0, a1, a2, a3})
+            recv_a = dpct::select_from_sub_group(sg, a[0], row_load_offset + i);
+            // Load partial fragment from col0 of matrix B ({b0, b1, b2, b3})
+            recv_b[0] =
+                dpct::select_from_sub_group(sg, b[0], col_load_offset + i);
+            // Load partial fragment from col1 of matrix B ({b0, b1, b2, b3})
+            recv_b[1] =
+                dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 4);
+
+            auto a = reinterpret_cast<ABType*>(&recv_a);
+            auto b = reinterpret_cast<ABType*>(recv_b);
+
+            // Each work item calculates a partial product of A & B matrix
+            // fragments and adds it to the corresponding D matrix fragment d0
+            // += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 } d1 += row0{
+            // a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 } d2 += row0{ a0, a1, a2,
+            // a3 } * col0{ b0, b1, b2, b3 } d3 += row0{ a0, a1, a2, a3 } *
+            // col1{ b0, b1, b2, b3 }
+            for (int j = 0; j < 4; j++) {
+              *d[0] += a[j] * b[j];
+              *d[1] += a[j] * b[j + 4];
+            }
+          }
+        }
+      } else if constexpr (M == 16 && N == 8 && K == 8) {
+        if constexpr (std::is_floating_point_v<CDType>) {
+          // Init D matrix fragment with C matrix fragment
+          *d[0] = c[0];
+          *d[1] = c[1];
+          *d[2] = c[2];
+          *d[3] = c[3];
+
+          // Each sub-group is responsible for computing a fragment size of 16*8
+          // elements of matrix D.
+          // Each work item computes 4 elements of matrix D by gathering
+          // their corresponding row & col matrix fragments of length k (8)
+          // from A & B matrices respectively using below mapping logic:
+          // row0 = (lane >> 2) & row1 = (lane >> 2) + 8
+          // col0 = (lane % 4) * 2 + (i & 0x1)
+          // As each row & col fragment of A & B matrices is distributed across
+          // 4 work items, each iteration of below loop loads a partial fragment
+          // of matrix A (row) and matrix B (col) using the row & col offsets.
+          for (int i = 0; i < 4; i++) {
+            typename MMAType<ABType>::PackType recv_a[2], recv_b[2];
+
+            // Load partial fragment from row0 of matrix A ({a0, a1})
+            recv_a[0] =
+                dpct::select_from_sub_group(sg, a[0], row_load_offset + i);
+            // Load partial fragment from row1 of matrix A ({a2, a3})
+            recv_a[1] =
+                dpct::select_from_sub_group(sg, a[1], row_load_offset + i);
+            // Load partial fragment from col0 of matrix B ({b0, b1})
+            recv_b[0] =
+                dpct::select_from_sub_group(sg, b[0], col_load_offset + i);
+            // Load partial fragment from col1 of matrix B ({b0, b1})
+            recv_b[1] =
+                dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 4);
+
+            auto ra = reinterpret_cast<ABType*>(recv_a);
+            auto rb = reinterpret_cast<ABType*>(recv_b);
+
+            // Each work item calculates a partial product of A & B matrix
+            // fragments and adds it to the corresponding D matrix fragment d0
+            // += row0{ a0, a1 } * col0{ b0, b1 } d1 += row0{ a0, a1 } * col1{
+            // b0, b1 } d2 += row1{ a2, a3 } * col0{ b0, b1 } d3 += row1{ a2, a3
+            // } * col1{ b0, b1 }
+            for (int j = 0; j < 2; j++) {
+              *d[0] += static_cast<float>(ra[j]) * static_cast<float>(rb[j]);
+              *d[1] +=
+                  static_cast<float>(ra[j]) * static_cast<float>(rb[j + 2]);
+              *d[2] +=
+                  static_cast<float>(ra[j + 2]) * static_cast<float>(rb[j]);
+              *d[3] +=
+                  static_cast<float>(ra[j + 2]) * static_cast<float>(rb[j + 2]);
+            }
+          }
+        }
+      } else if constexpr (M == 16 && N == 8 && K == 16) {
+        if constexpr (std::is_floating_point_v<CDType>) {
+          // Init D matrix fragment with C matrix fragment
+          *d[0] = c[0];
+          *d[1] = c[1];
+          *d[2] = c[2];
+          *d[3] = c[3];
+
+          // Each sub-group is responsible for computing a fragment size of 16*8
+          // elements of matrix D.
+          // Each work item computes 4 elements of matrix D by gathering
+          // their corresponding row & col matrix fragments of length k (8)
+          // from A & B matrices respectively using below mapping logic:
+          // row0 = (lane >> 2)    & row1 = (lane >> 2) + 8
+          // col0 = (lane % 4) * 2 & col1 = (lane % 4) * 2 + 1
+          // As each row & col fragment of A & B matrices is distributed across
+          // 4 work items, each iteration of below loop loads a partial fragment
+          // of matrix A (row) and matrix B (col) using the row & col offsets.
+          for (int i = 0; i < 4; i++) {
+            typename MMAType<ABType>::PackType recv_a[4], recv_b[4];
+
+            // Load partial fragment from row0 of matrix A ({a0, a1})
+            recv_a[0] =
+                dpct::select_from_sub_group(sg, a[0], row_load_offset + i);
+            // Load partial fragment from row0 of matrix A ({a2, a3})
+            recv_a[1] =
+                dpct::select_from_sub_group(sg, a[2], row_load_offset + i);
+            // Load partial fragment from row1 of matrix A ({a0, a1})
+            recv_a[2] =
+                dpct::select_from_sub_group(sg, a[1], row_load_offset + i);
+            // Load partial fragment from row1 of matrix A ({a2, a3})
+            recv_a[3] =
+                dpct::select_from_sub_group(sg, a[3], row_load_offset + i);
+
+            // Load partial fragment from col0 of matrix B ({b0, b1})
+            recv_b[0] =
+                dpct::select_from_sub_group(sg, b[0], col_load_offset + i);
+            // Load partial fragment from col0 of matrix B ({b2, b3})
+            recv_b[1] =
+                dpct::select_from_sub_group(sg, b[1], col_load_offset + i);
+            // Load partial fragment from col1 of matrix B ({b0, b1})
+            recv_b[2] =
+                dpct::select_from_sub_group(sg, b[0], col_load_offset + 4 + i);
+            // Load partial fragment from col1 of matrix B ({b2, b3})
+            recv_b[3] =
+                dpct::select_from_sub_group(sg, b[1], col_load_offset + 4 + i);
+
+            auto ra = reinterpret_cast<ABType*>(recv_a);
+            auto rb = reinterpret_cast<ABType*>(recv_b);
+
+            // Each work item calculates a partial product of A & B matrix
+            // fragments and adds it to the corresponding D matrix fragment d0
+            // += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 } d1 += row0{
+            // a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 } d2 += row1{ a0, a1, a2,
+            // a3 } * col0{ b0, b1, b2, b3 } d3 += row1{ a0, a1, a2, a3 } *
+            // col1{ b0, b1, b2, b3 }
+            for (int j = 0; j < 4; j++) {
+              *d[0] += static_cast<CDType>(ra[j]) * static_cast<CDType>(rb[j]);
+              *d[1] +=
+                  static_cast<CDType>(ra[j]) * static_cast<CDType>(rb[j + 4]);
+              *d[2] +=
+                  static_cast<CDType>(ra[j + 4]) * static_cast<CDType>(rb[j]);
+              *d[3] += static_cast<CDType>(ra[j + 4]) *
+                       static_cast<CDType>(rb[j + 4]);
+            }
+          }
+        } else if constexpr (std::is_integral_v<ABType>) {
+          // Init D matrix with fragments of C matrix
+          *d[0] = c[0];
+          *d[1] = c[1];
+          *d[2] = c[2];
+          *d[3] = c[3];
+
+          // Each sub-group is responsible for computing a fragment size of 16*8
+          // elements of matrix D.
+          // Each work item computes 4 elements of matrix D by gathering
+          // their corresponding row & col matrix fragments of length k (8)
+          // from A & B matrices respectively using below mapping logic:
+          // row0 = (lane >> 2)    & row1 = (lane >> 2) + 8
+          // col0 = (lane % 4) * 2 & col1 = (lane % 4) * 2 + 1
+          // As each row & col fragment of A & B matrices is distributed across
+          // 4 work items, each iteration of below loop loads a partial fragment
+          // of matrix A (row) and matrix B (col) using the row & col offsets.
+          for (int i = 0; i < 4; i++) {
+            typename MMAType<ABType>::PackType recv_a[2], recv_b[2];
+
+            // Load partial fragment from row0 of matrix A ({a0, a1, a2, a3})
+            recv_a[0] =
+                dpct::select_from_sub_group(sg, a[0], row_load_offset + i);
+            // Load partial fragment from row1 of matrix A ({a4, a5, a6, a7})
+            recv_a[1] =
+                dpct::select_from_sub_group(sg, a[1], row_load_offset + i);
+            // Load partial fragment from col0 of matrix B ({b0, b1, b2, b3})
+            recv_b[0] =
+                dpct::select_from_sub_group(sg, b[0], col_load_offset + i);
+            // Load partial fragment from col1 of matrix B ({b4, b5, b6, b7})
+            recv_b[1] =
+                dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 4);
+
+            auto ra = reinterpret_cast<ABType*>(recv_a);
+            auto rb = reinterpret_cast<ABType*>(recv_b);
+
+            // Each work item calculates a partial product of A & B matrix
+            // fragments and adds it to the corresponding D matrix fragment d0
+            // += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 } d1 += row0{
+            // a0, a1, a2, a3 } * col1{ b4, b5, b6, b7 } d2 += row1{ a4, a5, a6,
+            // a7 } * col0{ b0, b1, b2, b3 } d3 += row1{ a4, a5, a6, a7 } *
+            // col1{ b4, b5, b6, b7 }
+            for (int i = 0; i < 4; i++) {
+              *d[0] += ra[i] * rb[i];
+              *d[1] += ra[i] * rb[i + 4];
+              *d[2] += ra[i + 4] * rb[i];
+              *d[3] += ra[i + 4] * rb[i + 4];
+            }
+          }
+        }
+      } else if constexpr (M == 16 && N == 8 && K == 32) {
+        if constexpr (std::is_integral_v<ABType>) {
+          // Init D matrix with fragments of C matrix
+          *d[0] = c[0];
+          *d[1] = c[1];
+          *d[2] = c[2];
+          *d[3] = c[3];
+
+          // Each sub-group is responsible for computing a fragment size of 16*8
+          // elements of matrix D.
+          // Each work item computes 4 elements of matrix D by gathering
+          // their corresponding row & col matrix fragments of length k (32)
+          // from A & B matrices respectively using below mapping logic:
+          // row0 = (lane >> 2)    & row1 = (lane >> 2) + 8
+          // col0 = ((lane % 4) * 4) + (i & 0x3) & col1 = ((lane % 4) * 4) + (i
+          // & 0x3) As each row & col fragment of A & B matrices is distributed
+          // across 4 work items, each iteration of below loop loads a partial
+          // fragment of matrix A (row) and matrix B (col) using the row & col
+          // offsets.
+          for (int i = 0; i < 4; i++) {
+            typename MMAType<ABType>::PackType recv_a[2], recv_b[2];
+
+            // Load partial fragment from row0 of matrix A ({a0, a1, a2, a3})
+            recv_a[0] =
+                dpct::select_from_sub_group(sg, a[0], row_load_offset + i);
+            // Load partial fragment from row1 of matrix A ({a4, a5, a6, a7})
+            recv_a[1] =
+                dpct::select_from_sub_group(sg, a[1], row_load_offset + i);
+            // Load partial fragment from col0 of matrix B ({b0, b1, b2, b3})
+            recv_b[0] =
+                dpct::select_from_sub_group(sg, b[0], col_load_offset + i);
+            // Load partial fragment from col1 of matrix B ({b0, b1, b2, b3})
+            recv_b[1] =
+                dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 4);
+
+            auto a = reinterpret_cast<ABType*>(recv_a);
+            auto b = reinterpret_cast<ABType*>(recv_b);
+
+            // Each work item calculates a partial product of A & B matrix
+            // fragments and adds it to the corresponding D matrix fragment d0
+            // += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 } d1 += row0{
+            // a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 } d2 += row1{ a4, a5, a6,
+            // a7 } * col0{ b0, b1, b2, b3 } d3 += row1{ a4, a5, a6, a7 } *
+            // col1{ b0, b1, b2, b3 }
+            for (int j = 0; j < 4; j++) {
+              *d[0] += a[j] * b[j];
+              *d[1] += a[j] * b[j + 4];
+              *d[2] += a[j + 4] * b[j];
+              *d[3] += a[j + 4] * b[j + 4];
+            }
+          }
+
+          for (int i = 0; i < 4; i++) {
+            typename MMAType<ABType>::PackType recv_a[2], recv_b[2];
+
+            // Load partial fragment from row0 of matrix A ({a8, a9, a10, a11})
+            recv_a[0] =
+                dpct::select_from_sub_group(sg, a[2], row_load_offset + i);
+            // Load partial fragment from row1 of matrix A ({a12, a13, a14,
+            // a15})
+            recv_a[1] =
+                dpct::select_from_sub_group(sg, a[3], row_load_offset + i);
+            // Load partial fragment from col0 of matrix B ({b4, b5, b6, b7})
+            recv_b[0] =
+                dpct::select_from_sub_group(sg, b[1], col_load_offset + i);
+            // Load partial fragment from col1 of matrix B ({b4, b5, b6, b7})
+            recv_b[1] =
+                dpct::select_from_sub_group(sg, b[1], col_load_offset + i + 4);
+
+            auto a = reinterpret_cast<ABType*>(recv_a);
+            auto b = reinterpret_cast<ABType*>(recv_b);
+
+            // Each work item calculates a partial product of A & B matrix
+            // fragments and adds it to the corresponding D matrix fragment d0
+            // += row0{ a8, a9, a10, a11 } * col0{ b4, b5, b6, b7 } d1 += row0{
+            // a8, a9, a10, a11 } * col1{ b4, b5, b6, b7 } d2 += row1{ a12, a13,
+            // a14, a15 } * col0{ b4, b5, b6, b7 } d3 += row1{ a12, a13, a14,
+            // a15 } * col1{ b4, b5, b6, b7 }
+            for (int j = 0; j < 4; j++) {
+              *d[0] += a[j] * b[j];
+              *d[1] += a[j] * b[j + 4];
+              *d[2] += a[j + 4] * b[j];
+              *d[3] += a[j + 4] * b[j + 4];
+            }
+          }
+        }
+      }
+    }
 } // COPY from DPCT head files
 
 #endif // GGML_SYCL_DPCT_HELPER_HPP
diff --git a/ggml/src/ggml-sycl/fattn-common.hpp b/ggml/src/ggml-sycl/fattn-common.hpp
new file mode 100644 (file)
index 0000000..ed00d03
--- /dev/null
@@ -0,0 +1,1179 @@
+#pragma once
+
+#include <sycl/sycl.hpp>
+#include "dpct/helper.hpp"
+#include "common.hpp"
+#include "convert.hpp"
+#include "vecdotq.hpp"
+
+#include "ggml.h"
+
+#include <cstdint>
+#include <cmath>
+#include <float.h>
+
+
+#define FATTN_KQ_STRIDE       256
+#define HALF_MAX_HALF         sycl::half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
+#define SOFTMAX_FTZ_THRESHOLD -20.0f                   // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.
+#define FATTN_KQ_MAX_OFFSET (3.0f*0.6931f)
+
+typedef void (*fattn_kernel_t)(
+    const char* Q,
+    const char* K,
+    const char* V,
+    const char* mask,
+    const char* sinks,
+    const int* KV_max,
+    float* dst,
+    sycl::float2* dst_meta,
+    const float scale,
+    const float max_bias,
+    const float m0,
+    const float m1,
+    const uint32_t n_head_log2,
+    const float logit_softcap,
+    const int32_t ne00,
+    const sycl::uint3 ne01,
+    const int32_t ne02,
+    const int32_t ne03,
+    const int32_t nb01,
+    const int32_t nb02,
+    const int32_t nb03,
+    const int32_t ne10,
+    const int32_t ne11,
+    const int32_t ne12,
+    const int32_t ne13,
+    const int32_t nb11,
+    const int32_t nb12,
+    const int64_t nb13,
+    const int32_t nb21,
+    const int32_t nb22,
+    const int64_t nb23,
+    const int32_t ne31,
+    const int32_t ne32,
+    const int32_t ne33,
+    const int32_t nb31,
+    const int32_t nb32,
+    const int64_t nb33);
+
+typedef float (*vec_dot_KQ_t)(
+    const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);
+
+template <int D, int nthreads>
+static __dpct_inline__ float vec_dot_fattn_vec_KQ_f16(const char * __restrict__ K_c,
+                                                      const void * __restrict__ Q_v,
+                                                      const int * __restrict__ Q_q8,
+                                                      const void * __restrict__ Q_ds_v) {
+    const sycl::half2 * K_h2 = (const sycl::half2 *) K_c;
+    GGML_UNUSED(Q_q8);
+    GGML_UNUSED(Q_ds_v);
+
+    constexpr int cpy_nb = ggml_sycl_get_max_cpy_bytes();
+    constexpr int cpy_ne = cpy_nb / 4;
+
+    float sum = 0.0f;
+
+#pragma unroll
+    for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += nthreads*cpy_ne) {
+        sycl::half2 tmp[cpy_ne];
+        ggml_sycl_memcpy_1<sizeof(tmp)>(
+            tmp,
+            K_h2 + k_KQ_0 + (sycl::ext::oneapi::this_work_item::get_nd_item<3>().get_local_id(2) % nthreads) * cpy_ne);
+#pragma unroll
+        for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) {
+#ifdef GGML_SYCL_F16
+            ggml_sycl_mad(sum,                tmp[k_KQ_1] , ((const sycl::half2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);
+#else
+            ggml_sycl_mad(sum, __half22float2(tmp[k_KQ_1]), ((const sycl::float2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);
+#endif // GGML_SYCL_F16
+        }
+    }
+
+    return sum;
+}
+
+template <int D, int nthreads, int warp_size>
+static __dpct_inline__ float vec_dot_fattn_vec_KQ_q4_0(const char * __restrict__ K_c,
+                                                       const void * __restrict__ Q_v,
+                                                       const int * __restrict__ Q_q8,
+                                                       const void * __restrict__ Q_ds_v) {
+    auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+
+    const block_q4_0 * K_q4_0   = (const block_q4_0 *) K_c;
+    GGML_UNUSED(Q_v);
+
+    float sum = 0.0f;
+
+#pragma unroll
+    for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
+        const int k_KQ =
+            k_KQ_0 + (nthreads == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads);
+
+        const int ib    = k_KQ /  QI8_1;
+        const int iqs4  = k_KQ %  QI4_0;
+        const int shift = k_KQ & (QI8_1/2);
+
+        int v;
+        ggml_sycl_memcpy_1<sizeof(int), 2>(&v, K_q4_0[ib].qs + sizeof(int)*iqs4);
+        v = (v >> shift) & 0x0F0F0F0F;
+        const int u = Q_q8[k_KQ_0/nthreads];
+
+        const int sumi = ggml_sycl_dp4a(v, u, 0);
+
+        const sycl::float2 Q_ds = ((const sycl::float2 *) Q_ds_v)[k_KQ_0 / nthreads];
+        sum += __half2float(K_q4_0[ib].d) * (sumi*Q_ds.x() - (8/QI8_1)*Q_ds.y());
+    }
+
+    return sum;
+}
+
+template <int D, int nthreads , int warp_size>
+static __dpct_inline__ float vec_dot_fattn_vec_KQ_q4_1(const char * __restrict__ K_c,
+                                                       const void * __restrict__ Q_v,
+                                                       const int * __restrict__ Q_q8,
+                                                       const void * __restrict__ Q_ds_v) {
+    auto               item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+    const block_q4_1 * K_q4_1   = (const block_q4_1 *) K_c;
+    GGML_UNUSED(Q_v);
+
+    float sum = 0.0f;
+
+#pragma unroll
+    for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
+        const int k_KQ =
+            k_KQ_0 + (nthreads == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads);
+
+        const int ib    = k_KQ /  QI8_1;
+        const int iqs4  = k_KQ %  QI4_1;
+        const int shift = k_KQ & (QI8_1/2);
+
+        int v;
+        ggml_sycl_memcpy_1<sizeof(int)>(&v, K_q4_1[ib].qs + sizeof(int)*iqs4);
+        v = (v >> shift) & 0x0F0F0F0F;
+        const int u = Q_q8[k_KQ_0/nthreads];
+
+        const int sumi = ggml_sycl_dp4a(v, u, 0);
+
+        const sycl::float2 K_dm = (K_q4_1[ib].dm).template convert<float, sycl::rounding_mode::automatic>();
+        const sycl::float2 Q_ds = ((const sycl::float2 *) Q_ds_v)[k_KQ_0 / nthreads];
+
+        sum += K_dm.x()*Q_ds.x()*sumi + K_dm.y()*Q_ds.y()/QI8_1;
+    }
+
+    return sum;
+}
+
+template <int D, int nthreads, int warp_size>
+static __dpct_inline__ float vec_dot_fattn_vec_KQ_q5_0(const char * __restrict__ K_c,
+                                                       const void * __restrict__ Q_v,
+                                                       const int * __restrict__ Q_q8,
+                                                       const void * __restrict__ Q_ds_v) {
+    auto               item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+    const block_q5_0 * K_q5_0   = (const block_q5_0 *) K_c;
+    GGML_UNUSED(Q_v);
+
+    float sum = 0.0f;
+
+#pragma unroll
+    for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
+        const int k_KQ =
+            k_KQ_0 + (nthreads == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads);
+
+        const int ib    = k_KQ /  QI8_1;
+        const int iqs4  = k_KQ %  QI5_0;
+        const int iqs8  = k_KQ %  QI8_1;
+        const int shift = k_KQ & (QI8_1/2);
+
+        int v;
+        ggml_sycl_memcpy_1<sizeof(int), 2>(&v, K_q5_0[ib].qs + sizeof(int)*iqs4);
+        v = (v >> shift) & 0x0F0F0F0F;
+
+        {
+            int vh;
+            ggml_sycl_memcpy_1<sizeof(int), 2>(&vh, K_q5_0[ib].qh);
+            vh >>= iqs8 * QI5_0;
+
+            v |= (vh <<  4) & 0x00000010; // 0 ->  4
+            v |= (vh << 11) & 0x00001000; // 1 -> 12
+            v |= (vh << 18) & 0x00100000; // 2 -> 20
+            v |= (vh << 25) & 0x10000000; // 3 -> 28
+        }
+
+        const int u = Q_q8[k_KQ_0/nthreads];
+
+        const int sumi = ggml_sycl_dp4a(v, u, 0);
+
+        const sycl::float2 Q_ds = ((const sycl::float2 *) Q_ds_v)[k_KQ_0 / nthreads];
+
+        sum += __half2float(K_q5_0[ib].d) * (sumi*Q_ds.x() - (16/QI8_1)*Q_ds.y());
+    }
+
+    return sum;
+}
+
+template <int D, int nthreads, int warp_size>
+static __dpct_inline__ float vec_dot_fattn_vec_KQ_q5_1(const char * __restrict__ K_c,
+                                                       const void * __restrict__ Q_v,
+                                                       const int * __restrict__ Q_q8,
+                                                       const void * __restrict__ Q_ds_v) {
+    auto               item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+    const block_q5_1 * K_q5_1   = (const block_q5_1 *) K_c;
+    GGML_UNUSED(Q_v);
+
+    float sum = 0.0f;
+
+#pragma unroll
+    for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
+        const int k_KQ =
+            k_KQ_0 + (nthreads == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads);
+
+        const int ib    = k_KQ /  QI8_1;
+        const int iqs4  = k_KQ %  QI5_1;
+        const int iqs8  = k_KQ %  QI8_1;
+        const int shift = k_KQ & (QI8_1/2);
+
+        int v;
+        ggml_sycl_memcpy_1<sizeof(int)>(&v, K_q5_1[ib].qs + sizeof(int)*iqs4);
+        v = (v >> shift) & 0x0F0F0F0F;
+
+        {
+            int vh;
+            ggml_sycl_memcpy_1<sizeof(int)>(&vh, K_q5_1[ib].qh);
+            vh >>= iqs8 * QI5_0;
+
+            v |= (vh <<  4) & 0x00000010; // 0 ->  4
+            v |= (vh << 11) & 0x00001000; // 1 -> 12
+            v |= (vh << 18) & 0x00100000; // 2 -> 20
+            v |= (vh << 25) & 0x10000000; // 3 -> 28
+        }
+
+        const int u = Q_q8[k_KQ_0/nthreads];
+
+        const int sumi = ggml_sycl_dp4a(v, u, 0);
+
+        const sycl::float2 K_dm = (K_q5_1[ib].dm).template convert<float, sycl::rounding_mode::automatic>();
+        const sycl::float2 Q_ds = ((const sycl::float2 *) Q_ds_v)[k_KQ_0 / nthreads];
+
+        sum += K_dm.x()*Q_ds.x()*sumi + K_dm.y()*Q_ds.y()/QI8_1;
+    }
+
+    return sum;
+}
+
+template <int D, int nthreads, int warp_size>
+static __dpct_inline__ float vec_dot_fattn_vec_KQ_q8_0(const char * __restrict__ K_c,
+                                                       const void * __restrict__ Q_v,
+                                                       const int * __restrict__ Q_q8,
+                                                       const void * __restrict__ Q_ds_v) {
+    auto               item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+    const block_q8_0 * K_q8_0   = (const block_q8_0 *) K_c;
+    GGML_UNUSED(Q_v);
+
+    float sum = 0.0f;
+
+#pragma unroll
+    for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
+        const int k_KQ =
+            k_KQ_0 + (nthreads == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads);
+
+        const int ib  = k_KQ / QI8_0;
+        const int iqs = k_KQ % QI8_0;
+
+        int v;
+        ggml_sycl_memcpy_1<sizeof(v), 2>(&v, K_q8_0[ib].qs + 4*iqs);
+
+        const sycl::float2 * Q_ds = (const sycl::float2 *) Q_ds_v;
+        const float          Q_d  = Q_ds[k_KQ_0 / nthreads].x();
+
+        sum += vec_dot_q8_0_q8_1_impl<float, 1>(&v, &Q_q8[k_KQ_0/nthreads], K_q8_0[ib].d, Q_d);
+    }
+
+    return sum;
+}
+
+template <typename Tds, int ni, int warp_size>
+static __dpct_inline__ void quantize_q8_1_to_shared(const float * __restrict__ x,
+                                                    const float scale,
+                                                    int * __restrict__ yq32,
+                                                    void * __restrict__ yds) {
+    auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+
+    float vals[sizeof(int)] = { 0.0f };
+#pragma unroll
+    for (int l = 0; l < int(sizeof(int)); ++l) {
+        vals[l] =
+            (ni == warp_size || item_ct1.get_local_id(2) < ni) ? scale * x[4 * item_ct1.get_local_id(2) + l] : 0.0f;
+    }
+
+    float amax = sycl::fabs(vals[0]);
+    float sum  = vals[0];
+#pragma unroll
+    for (int l = 1; l < int(sizeof(int)); ++l) {
+        amax = sycl::fmax(amax, sycl::fabs(vals[l]));
+        sum += vals[l];
+    }
+#pragma unroll
+    for (int mask = QI8_1/2; mask > 0; mask >>= 1) {
+        amax = sycl::fmax(
+            amax, dpct::permute_sub_group_by_xor(sycl::ext::oneapi::this_work_item::get_sub_group(), amax, mask));
+        sum += dpct::permute_sub_group_by_xor(sycl::ext::oneapi::this_work_item::get_sub_group(), sum, mask);
+    }
+
+    const float d = amax / 127;
+    int q32 = 0;
+    int8_t * q8 = (int8_t *) &q32;
+
+    if (d != 0.0f) {
+#pragma unroll
+        for (int l = 0; l < int(sizeof(int)); ++l) {
+            q8[l] = sycl::round(vals[l] / d);
+        }
+    }
+
+    yq32[item_ct1.get_local_id(2)] = q32;
+    if (item_ct1.get_local_id(2) % QI8_1 == 0 && (ni == warp_size || item_ct1.get_local_id(2) < ni)) {
+        if (std::is_same<Tds, sycl::half2>::value) {
+            ((sycl::half2  *) yds)[item_ct1.get_local_id(2)/QI8_1] =  make_half2(d, sum);
+        } else {
+            ((sycl::float2 *) yds)[item_ct1.get_local_id(2)/QI8_1] = make_float2(d, sum);
+        }
+    }
+}
+
+typedef void (*dequantize_V_t)(const void *, void *, const int64_t);
+
+template <typename T, int ne>
+static __dpct_inline__ void dequantize_V_f16(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
+    if constexpr (std::is_same_v<T, sycl::half>) {
+        ggml_sycl_memcpy_1<ne * sizeof(sycl::half)>(dst, (const sycl::half *) vx + i0);
+    } else if constexpr (std::is_same_v<T, float>) {
+        static_assert(ne % 2 == 0, "bad ne");
+        sycl::half2 tmp[ne / 2];
+        ggml_sycl_memcpy_1<ne * sizeof(sycl::half)>(tmp, (const sycl::half *) vx + i0);
+        sycl::float2 * dst_f2 = (sycl::float2 *) dst;
+#pragma unroll
+        for (int l = 0; l < ne/2; ++l) {
+            dst_f2[l] = tmp[l].template convert<float, sycl::rounding_mode::automatic>();
+        }
+    } else {
+        static_assert(std::is_same_v<T, void>, "unsupported type");
+    }
+}
+
+template <typename T, int ne>
+static __dpct_inline__ void dequantize_V_q4_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
+    const block_q4_0 * x = (const block_q4_0 *) vx;
+
+    const int64_t ib    =  i0          /  QK4_0;
+    const int     iqs   =  i0          % (QK4_0/2);
+    const int     shift = (i0 % QK4_0) / (QK4_0/2);
+
+    int q;
+    static_assert(ne == 2 || ne == 4, "bad ne");
+    ggml_sycl_memcpy_1<ne, 2>(&q, x[ib].qs + iqs);
+    q >>= 4*shift;
+    q &= 0x0F0F0F0F;
+    q = dpct::vectorized_binary<sycl::char4>(q, 0x08080808, dpct::sub_sat());
+
+    const int8_t * q8 = (const int8_t *) &q;
+
+#ifdef GGML_SYCL_F16
+    if constexpr (std::is_same_v<T, sycl::half>) {
+        const sycl::half2 d = sycl::half2(x[ib].d);
+
+#pragma unroll
+        for (int l0 = 0; l0 < ne; l0 += 2) {
+            ((sycl::half2 *) dst)[l0 / 2] = d * sycl::half2(q8[l0 + 0], q8[l0 + 1]);
+        }
+    } else
+#endif // GGML_SYCL_F16
+    if constexpr (std::is_same_v<T, float>) {
+        const float d = x[ib].d;
+
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            ((float *) dst)[l] = d * q8[l];
+        }
+    } else {
+        static_assert(std::is_same_v<T, void>, "bad type");
+    }
+}
+
+template <typename T, int ne>
+static __dpct_inline__ void dequantize_V_q4_1(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
+    const block_q4_1 * x = (const block_q4_1 *) vx;
+
+    const int64_t ib    =  i0          /  QK4_1;
+    const int     iqs   =  i0          % (QK4_1/2);
+    const int     shift = (i0 % QK4_1) / (QK4_1/2);
+
+    int q;
+    static_assert(ne == 2 || ne == 4, "bad ne");
+    ggml_sycl_memcpy_1<ne>(&q, x[ib].qs + iqs);
+    q >>= 4*shift;
+    q &= 0x0F0F0F0F;
+
+    const int8_t * q8 = (const int8_t *) &q;
+
+#ifdef GGML_SYCL_F16
+    if constexpr (std::is_same_v<T, sycl::half>) {
+        const sycl::half2 dm = x[ib].dm;
+        const sycl::half2 d  = sycl::half2(dm[0]);
+        const sycl::half2 m  = sycl::half2(dm[1]);
+
+#pragma unroll
+        for (int l0 = 0; l0 < ne; l0 += 2) {
+            ((sycl::half2 *) dst)[l0 / 2] = d * sycl::half2(q8[l0 + 0], q8[l0 + 1]) + m;
+        }
+    } else
+#endif // GGML_SYCL_F16
+    if constexpr (std::is_same_v<T, float>) {
+        const sycl::float2 dm = (x[ib].dm).template convert<float, sycl::rounding_mode::automatic>();
+
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            ((float *) dst)[l] = dm.x() * q8[l] + dm.y();
+        }
+    } else {
+        static_assert(std::is_same_v<T, void>, "bad type");
+    }
+}
+
+template <typename T, int ne>
+static __dpct_inline__ void dequantize_V_q5_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
+    const block_q5_0 * x = (const block_q5_0 *) vx;
+
+    const int64_t ib    =  i0          /  QK5_0;
+    const int     idq   =  i0          %  QK5_0;
+    const int     iqs   =  i0          % (QK5_0/2);
+    const int     shift = (i0 % QK5_0) / (QK5_0/2);
+
+    int q;
+    static_assert(ne == 2 || ne == 4, "bad ne");
+    ggml_sycl_memcpy_1<ne, 2>(&q, x[ib].qs + iqs);
+    q >>= 4*shift;
+    q &= 0x0F0F0F0F;
+
+    {
+        int qh;
+        ggml_sycl_memcpy_1<ne, 2>(&qh, x[ib].qh);
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            q |= ((qh >> (idq + l)) & 0x00000001) << (8*l + 4);
+        }
+    }
+
+    q = dpct::vectorized_binary<sycl::char4>(q, 0x10101010, dpct::sub_sat());
+
+    const int8_t * q8 = (const int8_t *) &q;
+
+#ifdef GGML_SYCL_F16
+    if constexpr (std::is_same_v<T, sycl::half>) {
+        const sycl::half2 d = sycl::half2(x[ib].d);
+
+#pragma unroll
+        for (int l0 = 0; l0 < ne; l0 += 2) {
+            ((sycl::half2 *) dst)[l0 / 2] = d * sycl::half2(q8[l0 + 0], q8[l0 + 1]);
+        }
+    } else
+#endif // GGML_SYCL_F16
+    if constexpr (std::is_same_v<T, float>) {
+        const float d = x[ib].d;
+
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            ((float *) dst)[l] = d * q8[l];
+        }
+    } else {
+        static_assert(std::is_same_v<T, void>, "bad type");
+    }
+}
+
+template <typename T, int ne>
+static __dpct_inline__ void dequantize_V_q5_1(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
+    const block_q5_1 * x = (const block_q5_1 *) vx;
+
+    const int64_t ib    =  i0          /  QK5_1;
+    const int     idq   =  i0          %  QK5_1;
+    const int     iqs   =  i0          % (QK5_1/2);
+    const int     shift = (i0 % QK5_1) / (QK5_1/2);
+
+    int q;
+    static_assert(ne == 2 || ne == 4, "bad ne");
+    ggml_sycl_memcpy_1<ne>(&q, x[ib].qs + iqs);
+    q >>= 4*shift;
+    q &= 0x0F0F0F0F;
+
+    {
+        int qh;
+        ggml_sycl_memcpy_1<ne>(&qh, x[ib].qh);
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            q |= ((qh >> (idq + l)) & 0x00000001) << (8*l + 4);
+        }
+    }
+
+    const int8_t * q8 = (const int8_t *) &q;
+
+#ifdef GGML_SYCL_F16
+    if constexpr (std::is_same_v<T, sycl::half>) {
+        const sycl::half2 dm = x[ib].dm;
+        const sycl::half2 d  = sycl::half2(dm[0]);
+        const sycl::half2 m  = sycl::half2(dm[1]);
+
+#pragma unroll
+        for (int l0 = 0; l0 < ne; l0 += 2) {
+            ((sycl::half2 *) dst)[l0 / 2] = d * sycl::half2(q8[l0 + 0], q8[l0 + 1]) + m;
+        }
+    } else
+#endif // GGML_SYCL_F16
+    if constexpr (std::is_same_v<T, float>) {
+        const sycl::float2 dm = (x[ib].dm).template convert<float, sycl::rounding_mode::automatic>();
+
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            ((float *) dst)[l] = dm.x() * q8[l] + dm.y();
+        }
+    } else {
+        static_assert(std::is_same_v<T, void>, "bad type");
+    }
+}
+
+template <typename T, int ne>
+static __dpct_inline__ void dequantize_V_q8_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
+    const block_q8_0 * x = (const block_q8_0 *) vx;
+
+    const int64_t ib  = i0 / QK8_0;
+    const int     iqs = i0 % QK8_0;
+
+    static_assert(ne % 2 == 0, "bad ne");
+    int8_t qs[ne];
+    ggml_sycl_memcpy_1<ne, 2>(qs, x[ib].qs + iqs);
+
+#ifdef GGML_SYCL_F16
+    if constexpr (std::is_same<T, sycl::half>::value) {
+        const sycl::half2 d = sycl::half2(x[ib].d);
+
+#pragma unroll
+        for (int l0 = 0; l0 < ne; l0 += 2) {
+            ((sycl::half2 *) dst)[l0 / 2] = d * make_half2(qs[l0 + 0], qs[l0 + 1]);
+        }
+    } else
+#endif // GGML_SYCL_F16
+    if constexpr (std::is_same<T, float>::value) {
+        const float d = x[ib].d;
+
+#pragma unroll
+        for (int l = 0; l < ne; ++l) {
+            ((float *) dst)[l] = d * qs[l];
+        }
+    } else {
+        static_assert(std::is_same_v<T, void>, "unsupported type");
+    }
+}
+
+template <int type_K, int D, int nthreads, int warp_size>
+constexpr vec_dot_KQ_t get_vec_dot_KQ() {
+    if constexpr (type_K == GGML_TYPE_F16) {
+        return vec_dot_fattn_vec_KQ_f16<D, nthreads>;
+    } else if constexpr (type_K == GGML_TYPE_Q4_0) {
+        return vec_dot_fattn_vec_KQ_q4_0<D, nthreads, warp_size>;
+    } else if constexpr (type_K == GGML_TYPE_Q4_1) {
+        return vec_dot_fattn_vec_KQ_q4_1<D, nthreads, warp_size>;
+    } else if constexpr (type_K == GGML_TYPE_Q5_0) {
+        return vec_dot_fattn_vec_KQ_q5_0<D, nthreads, warp_size>;
+    } else if constexpr (type_K == GGML_TYPE_Q5_1) {
+        return vec_dot_fattn_vec_KQ_q5_1<D, nthreads, warp_size>;
+    } else if constexpr (type_K == GGML_TYPE_Q8_0) {
+        return vec_dot_fattn_vec_KQ_q8_0<D, nthreads, warp_size>;
+    } else {
+        static_assert(type_K == -1, "bad type");
+        return nullptr;
+    }
+}
+
+template <int type_V, typename T, int ne>
+constexpr dequantize_V_t get_dequantize_V() {
+    if constexpr (type_V == GGML_TYPE_F16) {
+        return dequantize_V_f16<T, ne>;
+    } else if constexpr (type_V == GGML_TYPE_Q4_0) {
+        return dequantize_V_q4_0<T, ne>;
+    } else if constexpr (type_V == GGML_TYPE_Q4_1) {
+        return dequantize_V_q4_1<T, ne>;
+    } else if constexpr (type_V == GGML_TYPE_Q5_0) {
+        return dequantize_V_q5_0<T, ne>;
+    } else if constexpr (type_V == GGML_TYPE_Q5_1) {
+        return dequantize_V_q5_1<T, ne>;
+    } else if constexpr (type_V == GGML_TYPE_Q8_0) {
+        return dequantize_V_q8_0<T, ne>;
+    } else {
+        static_assert(type_V == -1, "bad type");
+        return nullptr;
+    }
+}
+
+template <int ncols1, int warp_size>
+static void flash_attn_mask_to_KV_max(const sycl::half2 * __restrict__ mask,
+                                      int * __restrict__ KV_max,
+                                      const int ne30,
+                                      const int s31,
+                                      const int s33,
+                                      int *     buf_iw) {
+    auto      item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+    const int ne31     = item_ct1.get_group_range(2);
+    const int tid      = item_ct1.get_local_id(2);
+    const int sequence = item_ct1.get_group(1);
+    const int jt       = item_ct1.get_group(2);
+
+    mask += sequence*s33 + jt*ncols1*s31;
+
+    if (tid < warp_size) {
+        buf_iw[tid] = 1;
+    }
+    item_ct1.barrier(sycl::access::fence_space::local_space);
+
+    int KV_max_sj = (ne30 - 1) * FATTN_KQ_STRIDE;
+    for (; KV_max_sj >= 0; KV_max_sj -= FATTN_KQ_STRIDE) {
+        int all_inf = 1;
+
+#pragma unroll
+        for (int j = 0; j < ncols1; ++j) {
+            const sycl::float2 tmp =
+                mask[j * s31 + KV_max_sj / 2 + tid].template convert<float, sycl::rounding_mode::automatic>();
+            all_inf = all_inf && int(sycl::isinf((float) (tmp.x()))) && int(sycl::isinf((float) (tmp.y())));
+        }
+
+        all_inf = warp_reduce_all<warp_size>(all_inf);
+        if (tid % warp_size == 0) {
+            buf_iw[tid / warp_size] = all_inf;
+        }
+        item_ct1.barrier(sycl::access::fence_space::local_space);
+        all_inf = buf_iw[tid % warp_size];
+        item_ct1.barrier(sycl::access::fence_space::local_space);
+        all_inf = warp_reduce_all<warp_size>(all_inf);
+
+        if (!all_inf) {
+            break;
+        }
+    }
+
+    // If the break in the loop was not triggered, KV_max_sj is now -FATTN_KQ_STRIDE.
+    // If the break was triggered it's the lower edge of the tile with the first non-masked values.
+    // In either case, walk back the decrementation by FATTN_KQ_STRIDE.
+    KV_max_sj += FATTN_KQ_STRIDE;
+
+    if (item_ct1.get_local_id(2) != 0) {
+        return;
+    }
+
+    KV_max[sequence*ne31 + jt] = KV_max_sj;
+}
+
+template <int D, int ncols1, int ncols2>  // D == head size
+
+static void flash_attn_stream_k_fixup(float * __restrict__ dst,
+                                      const sycl::float2 * __restrict__ dst_fixup,
+                                      const int ne01,
+                                      const int ne02,
+                                      const int ne03,
+                                      const int ne11,
+                                      const int ne12,
+                                      const int nbatch_fa) {
+    auto          item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+    constexpr int ncols    = ncols1 * ncols2;
+
+    const int bidx0 = item_ct1.get_group(2);
+    const int j     = item_ct1.get_group(1);
+    const int c     = item_ct1.get_group(0);
+    const int jc    = j*ncols2 + c;
+    const int tid   = item_ct1.get_local_id(2);
+
+    const float * dst_fixup_data = ((const float *) dst_fixup) + item_ct1.get_group_range(2) * (2 * 2 * ncols);
+
+    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
+
+    const int iter_k     = (ne11      + (nbatch_fa - 1)) / nbatch_fa;
+    const int iter_j     = (ne01      + (ncols1    - 1)) / ncols1;
+    const int iter_z_gqa = (gqa_ratio + (ncols2    - 1)) / ncols2;
+
+    const int kbc0 = int64_t(bidx0 + 0) * (iter_k * iter_j * iter_z_gqa * ne12 * ne03) / item_ct1.get_group_range(2);
+    const int kbc0_stop =
+        int64_t(bidx0 + 1) * (iter_k * iter_j * iter_z_gqa * ne12 * ne03) / item_ct1.get_group_range(2);
+
+    const bool did_not_have_any_data   = kbc0 == kbc0_stop;
+    const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
+    const bool did_not_write_last      = kbc0/iter_k == kbc0_stop/iter_k && kbc0_stop % iter_k != 0;
+    if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) {
+        return;
+    }
+
+    // z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index
+    const int sequence =  kbc0 /(iter_k*iter_j*iter_z_gqa*ne12);
+    const int z_KV     = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa);
+    const int zt_gqa   = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j);
+    const int jt       = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k;
+
+    const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.
+
+    if (jt*ncols1 + j >= ne01 || zt_gqa*ncols2 + c >= gqa_ratio) {
+        return;
+    }
+
+    dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + zt_Q*D + (j*ne02 + c)*D + tid;
+
+    // Load the partial result that needs a fixup:
+    float dst_val = 0.0f;
+    float max_val = 0.0f;
+    float rowsum  = 0.0f;
+    {
+        dst_val = *dst;
+
+        const sycl::float2 tmp = dst_fixup[bidx0 * ncols + jc];
+        max_val                = tmp.x();
+        rowsum                 = tmp.y();
+    }
+
+    // Iterate over previous blocks and compute the combined results.
+    // All SYCL blocks that get here must have a previous block that needs a fixup.
+    int bidx = bidx0 - 1;
+    int kbc_stop = kbc0;
+    while(true) {
+        const int kbc = int64_t(bidx) * (iter_k * iter_j * iter_z_gqa * ne12 * ne03) / item_ct1.get_group_range(2);
+        if (kbc == kbc_stop) { // Did not have any data.
+            bidx--;
+            kbc_stop = kbc;
+            continue;
+        }
+
+        const float dst_add = dst_fixup_data[bidx*ncols*D + jc*D + tid];
+
+        const sycl::float2 tmp = dst_fixup[(item_ct1.get_group_range(2) + bidx) * ncols + jc];
+
+        // Scale the current and new value accumulators depending on the max. values.
+        const float max_val_new = sycl::fmax(max_val, tmp.x());
+
+        const float diff_val = max_val - max_val_new;
+        const float diff_add = tmp.x() - max_val_new;
+
+        const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? sycl::native::exp(diff_val) : 0.0f;
+        const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? sycl::native::exp(diff_add) : 0.0f;
+
+        dst_val = scale_val*dst_val + scale_add*dst_add;
+        rowsum  = scale_val * rowsum + scale_add * tmp.y();
+
+        max_val = max_val_new;
+
+        // If this block started in a previous tile we are done and don't need to combine additional partial results.
+        if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) {
+            break;
+        }
+        bidx--;
+        kbc_stop = kbc;
+    }
+
+    // Write back final result:
+    *dst = dst_val / rowsum;
+}
+
+template <int D>  // D == head size
+
+static void flash_attn_combine_results(const float * __restrict__ VKQ_parts,
+                                       const sycl::float2 * __restrict__ VKQ_meta,
+                                       float * __restrict__ dst,
+                                       const int parallel_blocks,
+                                       uint8_t * dpct_local) {
+    // Dimension 0: threadIdx.x
+    // Dimension 1: blockIdx.x
+    // Dimension 2: blockIdx.y
+    // Dimension 3: blockIdx.z
+    // Memory layout is permuted with [0, 2, 1, 3]
+
+    auto      item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+    const int ne01     = item_ct1.get_group_range(2);
+    const int ne02     = item_ct1.get_group_range(1);
+
+    const int col      = item_ct1.get_group(2);
+    const int head     = item_ct1.get_group(1);
+    const int sequence = item_ct1.get_group(0);
+
+    const int j_dst_unrolled = (sequence*ne01 + col)*ne02 + head;
+
+    VKQ_parts += j_dst_unrolled * parallel_blocks*D;
+    VKQ_meta  += j_dst_unrolled * parallel_blocks;
+    dst       += j_dst_unrolled *                 D;
+
+    const int tid = item_ct1.get_local_id(2);
+    __builtin_assume(tid < D);
+
+    auto meta = (sycl::float2 *) dpct_local;
+    for (int i = tid; i < 2*parallel_blocks; i += D) {
+        ((float *) meta)[i] = ((const float *)VKQ_meta) [i];
+    }
+
+    item_ct1.barrier(sycl::access::fence_space::local_space);
+
+    float kqmax = meta[0].x();
+    for (int l = 1; l < parallel_blocks; ++l) {
+        kqmax = sycl::max(kqmax, meta[l].x());
+    }
+
+    float VKQ_numerator   = 0.0f;
+    float VKQ_denominator = 0.0f;
+    for (int l = 0; l < parallel_blocks; ++l) {
+        const float KQ_max_scale = sycl::native::exp(meta[l].x() - kqmax);
+
+        VKQ_numerator   += KQ_max_scale * VKQ_parts[l*D + tid];
+        VKQ_denominator += KQ_max_scale * meta[l].y();
+    }
+
+    dst[tid] = VKQ_numerator / VKQ_denominator;
+}
+
+template <fattn_kernel_t fattn_kernel, int warp_size>
+static void lauch_kernel(
+    dpct::dim3 group_range,
+    dpct::dim3 local_range,
+    queue_ptr q,
+    unsigned int local_mem_size,
+    const char* __restrict__ Q,
+    const char* __restrict__ K,
+    const char* __restrict__ V,
+    const char* __restrict__ mask,
+    const char* __restrict__ sinks,
+    const int* __restrict__ KV_max,
+    float* __restrict__ dst,
+    sycl::float2* __restrict__ dst_meta,
+    const float scale,
+    const float max_bias,
+    const float m0,
+    const float m1,
+    const uint32_t n_head_log2,
+    const float logit_softcap,
+    const int32_t ne00,
+    const sycl::uint3 ne01,
+    const int32_t ne02,
+    const int32_t ne03,
+    const int32_t nb01,
+    const int32_t nb02,
+    const int32_t nb03,
+    const int32_t ne10,
+    const int32_t ne11,
+    const int32_t ne12,
+    const int32_t ne13,
+    const int32_t nb11,
+    const int32_t nb12,
+    const int64_t nb13,
+    const int32_t nb21,
+    const int32_t nb22,
+    const int64_t nb23,
+    const int32_t ne31,
+    const int32_t ne32,
+    const int32_t ne33,
+    const int32_t nb31,
+    const int32_t nb32,
+    const int64_t nb33) {
+    GGML_UNUSED(local_mem_size);
+    q->submit([&](sycl::handler &cgh) {
+        cgh.parallel_for(
+            sycl::nd_range<3>(
+                static_cast<sycl::range<3>>(group_range * local_range),
+                static_cast<sycl::range<3>>(local_range)),
+            [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(warp_size)]] {
+                GGML_UNUSED(item_ct1);
+                fattn_kernel(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
+                             max_bias, m0, m1, n_head_log2, logit_softcap, ne00,
+                             ne01, ne02, ne03, nb01, nb02, nb03, ne10, ne11,
+                             ne12, ne13, nb11, nb12, nb13, nb21, nb22, nb23,
+                             ne31, ne32, ne33, nb31, nb32, nb33);
+            });
+    });
+}
+
+template <int DV, int ncols1, int ncols2, fattn_kernel_t fattn_kernel, int warp_size>
+void launch_fattn(
+    ggml_backend_sycl_context & ctx, ggml_tensor * dst, const int nwarps, const size_t nbytes_shared,
+    const int nbatch_fa, const bool need_f16_K, const bool need_f16_V, const bool stream_k) {
+
+    constexpr int ncols = ncols1 * ncols2;
+
+    const ggml_tensor * Q = dst->src[0];
+    const ggml_tensor * K = dst->src[1];
+    const ggml_tensor * V = dst->src[2];
+
+    const bool V_is_K_view = V->view_src && (V->view_src == K || (V->view_src == K->view_src && V->view_offs == K->view_offs));
+
+    const ggml_tensor * mask  = dst->src[3];
+    const ggml_tensor * sinks = dst->src[4];
+
+    ggml_tensor * KQV = dst;
+
+    GGML_ASSERT(Q->type == GGML_TYPE_F32);
+    GGML_ASSERT(KQV->type == GGML_TYPE_F32);
+
+    GGML_ASSERT(Q->nb[0] == ggml_element_size(Q));
+    GGML_ASSERT(K->nb[0] == ggml_element_size(K));
+    GGML_ASSERT(V->nb[0] == ggml_element_size(V));
+
+    GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
+
+    ggml_sycl_pool & pool = ctx.pool();
+    dpct::queue_ptr  main_stream = ctx.stream();
+    const int id  = ggml_sycl_get_device();
+    const int nsm = ggml_sycl_info().devices[id].nsm;
+
+    ggml_sycl_pool_alloc<sycl::half>   K_f16(pool);
+    ggml_sycl_pool_alloc<sycl::half>   V_f16(pool);
+    ggml_sycl_pool_alloc<int>    KV_max(pool);
+    ggml_sycl_pool_alloc<float>  dst_tmp(pool);
+    ggml_sycl_pool_alloc<sycl::float2> dst_tmp_meta(pool);
+
+    const char * K_data = (const char *) K->data;
+    size_t nb11 = K->nb[1];
+    size_t nb12 = K->nb[2];
+    size_t nb13 = K->nb[3];
+
+    const char * V_data = (const char *) V->data;
+    size_t nb21 = V->nb[1];
+    size_t nb22 = V->nb[2];
+    size_t nb23 = V->nb[3];
+
+    if (need_f16_K && K->type != GGML_TYPE_F16) {
+        const size_t bs = ggml_blck_size(K->type);
+        const size_t ts = ggml_type_size(K->type);
+
+        K_f16.alloc(ggml_nelements(K));
+        if (ggml_is_contiguously_allocated(K)) {
+            to_fp16_sycl_t to_fp16 = ggml_get_to_fp16_sycl(K->type, dst);
+            to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream);
+
+            nb11 = nb11 * bs * sizeof(sycl::half) / ts;
+            nb12 = nb12 * bs * sizeof(sycl::half) / ts;
+            nb13 = nb13 * bs * sizeof(sycl::half) / ts;
+        } else {
+            GGML_ASSERT(K->nb[0] == ts);
+            to_fp16_nc_sycl_t to_fp16 = ggml_get_to_fp16_nc_sycl(K->type);
+            const int64_t s01 = nb11 / ts;
+            const int64_t s02 = nb12 / ts;
+            const int64_t s03 = nb13 / ts;
+            to_fp16(K_data, K_f16.ptr, K->ne[0], K->ne[1], K->ne[2], K->ne[3], s01, s02, s03, main_stream);
+
+            nb11 = K->ne[0] * sizeof(sycl::half);
+            nb12 = K->ne[1] * nb11;
+            nb13 = K->ne[2] * nb12;
+        }
+        K_data = (char *) K_f16.ptr;
+    }
+
+    if (need_f16_V && V->type != GGML_TYPE_F16) {
+        if (V_is_K_view) {
+            V_data = K_data;
+            nb21   = nb11;
+            nb22   = nb12;
+            nb23   = nb13;
+        } else {
+            const size_t bs = ggml_blck_size(V->type);
+            const size_t ts = ggml_type_size(V->type);
+
+            V_f16.alloc(ggml_nelements(V));
+            if (ggml_is_contiguously_allocated(V)) {
+                to_fp16_sycl_t to_fp16 = ggml_get_to_fp16_sycl(V->type, dst);
+                to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
+                V_data = (char *) V_f16.ptr;
+
+                nb21 = nb21 * bs * sizeof(sycl::half) / ts;
+                nb22 = nb22 * bs * sizeof(sycl::half) / ts;
+                nb23 = nb23 * bs * sizeof(sycl::half) / ts;
+            } else {
+                GGML_ASSERT(V->nb[0] == ts);
+                to_fp16_nc_sycl_t to_fp16 = ggml_get_to_fp16_nc_sycl(V->type);
+                const int64_t s01 = nb21 / ts;
+                const int64_t s02 = nb22 / ts;
+                const int64_t s03 = nb23 / ts;
+                to_fp16(V_data, V_f16.ptr, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream);
+
+                nb21 = V->ne[0] * sizeof(sycl::half);
+                nb22 = V->ne[1] * nb21;
+                nb23 = V->ne[2] * nb22;
+            }
+            V_data = (char *) V_f16.ptr;
+        }
+    }
+
+    const int ntiles_x     = ((Q->ne[1] + ncols1 - 1) / ncols1);
+    const int gqa_ratio    = Q->ne[2] / K->ne[2];
+    const int ntiles_z_gqa = ((gqa_ratio + ncols2 - 1) / ncols2);
+    const int ntiles_total = ntiles_x * ntiles_z_gqa * K->ne[2] * Q->ne[3];
+
+    // Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped.
+    // Only worth the overhead if there is at lease one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or
+    //     multiple sequences of possibly different lengths.
+    if (mask && K->ne[1] % FATTN_KQ_STRIDE == 0 && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) {
+        const int s31 = mask->nb[1] / sizeof(sycl::half2);
+        const int s33 = mask->nb[3] / sizeof(sycl::half2);
+
+        const dpct::dim3 blocks_num_KV_max(ntiles_x, Q->ne[3], 1);
+        const dpct::dim3 block_dim_KV_max(FATTN_KQ_STRIDE / 2, 1, 1);
+
+        const int ne_KV_max = blocks_num_KV_max.x*blocks_num_KV_max.y;
+        const int iter_k = K->ne[1] / FATTN_KQ_STRIDE;
+
+        KV_max.alloc(ne_KV_max);
+        {
+            dpct::has_capability_or_fail(main_stream->get_device(), { sycl::aspect::fp16 });
+
+            main_stream->submit([&](sycl::handler & cgh) {
+                sycl::local_accessor<int, 1> buf_iw_acc_ct1(sycl::range<1>(warp_size), cgh);
+
+                auto mask_data_ct0  = (const sycl::half2 *) mask->data;
+                auto KV_max_ptr_ct1 = KV_max.ptr;
+
+                cgh.parallel_for(sycl::nd_range<3>(blocks_num_KV_max * block_dim_KV_max, block_dim_KV_max),
+                                 [=](sycl::nd_item<3> item_ct1) {
+                                     GGML_UNUSED(item_ct1);
+                                     flash_attn_mask_to_KV_max<ncols1, warp_size>(
+                                         mask_data_ct0, KV_max_ptr_ct1, iter_k, s31, s33,
+                                         buf_iw_acc_ct1.get_multi_ptr<sycl::access::decorated::no>().get());
+                                 });
+            });
+        }
+        SYCL_CHECK(0);
+    }
+
+    const dpct::dim3 block_dim(warp_size, nwarps, 1);
+
+    // Max. number of active blocks limited by occupancy.
+    int max_blocks_per_sm = ggml_sycl_info().devices[id].max_wg_per_cu;
+    int parallel_blocks = max_blocks_per_sm;
+    dpct::dim3 blocks_num;
+    if (stream_k) {
+        // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
+        const int max_blocks = max_blocks_per_sm*nsm;
+        const int nblocks_stream_k = max_blocks;
+        const bool use_stream_k = true;
+
+        blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total;
+        blocks_num.y = 1;
+        blocks_num.z = 1;
+
+        if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
+            dst_tmp_meta.alloc((size_t(blocks_num.x) * ncols * (2 + DV/2)));
+        }
+    } else {
+        const int ntiles_KQ = (K->ne[1] + nbatch_fa - 1) / nbatch_fa; // Max. number of parallel blocks limited by tensor size.
+
+        // parallel_blocks must not be larger than what the tensor size allows:
+        parallel_blocks = std::min(parallel_blocks, ntiles_KQ);
+        // todo fix the hard code change
+        // parallel_blocks = ntiles_KQ;
+
+        // If ntiles_total % blocks_per_wave != 0 then some efficiency is lost due to tail effects.
+        // Test whether parallel_blocks can be set to a higher value for better efficiency.
+        const int blocks_per_wave = nsm * max_blocks_per_sm;
+        int nwaves_best = 0;
+        int efficiency_percent_best = 0;
+        for (int parallel_blocks_test = parallel_blocks; parallel_blocks_test <= ntiles_KQ; ++parallel_blocks_test) {
+            const int nblocks_total = ntiles_total * parallel_blocks_test;
+            const int nwaves = (nblocks_total + blocks_per_wave - 1) / blocks_per_wave;
+            const int efficiency_percent = 100 * nblocks_total / (nwaves*blocks_per_wave);
+
+            // Stop trying configurations with more waves if we already have good efficiency to avoid excessive overhead.
+            if (efficiency_percent_best >= 95 && nwaves > nwaves_best) {
+                break;
+            }
+
+            if (efficiency_percent > efficiency_percent_best) {
+                nwaves_best = nwaves;
+                efficiency_percent_best = efficiency_percent;
+                parallel_blocks = parallel_blocks_test;
+            }
+        }
+
+        blocks_num.x = ntiles_x;
+        blocks_num.y = parallel_blocks;
+        blocks_num.z = ntiles_z_gqa*K->ne[2]*Q->ne[3];
+
+        if (parallel_blocks > 1) {
+            dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
+            dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
+        }
+    }
+
+    float scale         = 1.0f;
+    float max_bias      = 0.0f;
+    float logit_softcap = 0.0f;
+
+    memcpy(&scale,         (const float *) KQV->op_params + 0, sizeof(float));
+    memcpy(&max_bias,      (const float *) KQV->op_params + 1, sizeof(float));
+    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
+    if (logit_softcap != 0.0f) {
+        scale /= logit_softcap;
+    }
+
+    const uint32_t n_head      = Q->ne[2];
+    const uint32_t n_head_log2 = 1u << uint32_t(floorf(log2f(float(n_head))));
+
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+    // TODO other tensor dimensions after removal of WMMA kernel:
+    const sycl::uint3 ne01 = init_fastdiv_values(Q->ne[1]);
+
+    GGML_ASSERT(block_dim.x % warp_size == 0);
+
+    lauch_kernel<fattn_kernel, warp_size>(
+        blocks_num, block_dim, main_stream, (unsigned int) nbytes_shared, (const char *) Q->data, K_data, V_data,
+        mask ? ((const char *) mask->data) : nullptr, sinks ? ((const char *) sinks->data) : nullptr, KV_max.ptr,
+        !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, (sycl::float2 *)dst_tmp_meta.ptr, scale, max_bias, m0, m1,
+        n_head_log2, logit_softcap, Q->ne[0], ne01, Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3], K->ne[0],
+        K->ne[1], K->ne[2], K->ne[3], nb11, nb12, nb13, nb21, nb22, nb23, mask ? mask->ne[1] : 0,
+        mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0, mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0,
+        mask ? mask->nb[3] : 0);
+    SYCL_CHECK(0);
+
+    if (stream_k) {
+        if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
+            const dpct::dim3 block_dim_combine(DV, 1, 1);
+            const dpct::dim3 blocks_num_combine = { blocks_num.x, ncols1, ncols2 };
+
+            main_stream->submit([&](sycl::handler & cgh) {
+                auto KQV_data_ct0         = (float *) KQV->data;
+                auto dst_tmp_meta_ptr_ct1 = dst_tmp_meta.ptr;
+                auto Q_ne_ct2             = Q->ne[1];
+                auto Q_ne_ct3             = Q->ne[2];
+                auto Q_ne_ct4             = Q->ne[3];
+                auto K_ne_ct5             = K->ne[1];
+                auto K_ne_ct6             = K->ne[2];
+
+                cgh.parallel_for(sycl::nd_range<3>(blocks_num_combine * block_dim_combine, block_dim_combine),
+                                 [=](sycl::nd_item<3> item_ct1) {
+                                     GGML_UNUSED(item_ct1);
+                                     flash_attn_stream_k_fixup<DV, ncols1, ncols2>(KQV_data_ct0, dst_tmp_meta_ptr_ct1,
+                                                                                   Q_ne_ct2, Q_ne_ct3, Q_ne_ct4,
+                                                                                   K_ne_ct5, K_ne_ct6, nbatch_fa);
+                                 });
+            });
+        }
+    } else if (parallel_blocks > 1) {
+        const dpct::dim3 block_dim_combine(DV, 1, 1);
+        const dpct::dim3 blocks_num_combine(Q->ne[1], Q->ne[2], Q->ne[3]);
+        const size_t     nbytes_shared_combine = parallel_blocks * sizeof(sycl::float2);
+        main_stream->submit([&](sycl::handler & cgh) {
+            sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(sycl::range<1>(nbytes_shared_combine), cgh);
+
+            auto dst_tmp_ptr_ct0      = dst_tmp.ptr;
+            auto dst_tmp_meta_ptr_ct1 = dst_tmp_meta.ptr;
+            auto KQV_data_ct2         = (float *) KQV->data;
+
+            cgh.parallel_for(sycl::nd_range<3>(blocks_num_combine * block_dim_combine, block_dim_combine),
+                             [=](sycl::nd_item<3> item_ct1) {
+                                 GGML_UNUSED(item_ct1);
+                                 flash_attn_combine_results<DV>(
+                                     dst_tmp_ptr_ct0, dst_tmp_meta_ptr_ct1, KQV_data_ct2, parallel_blocks,
+                                     dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>().get());
+                             });
+        });
+    }
+    SYCL_CHECK(0);
+}
diff --git a/ggml/src/ggml-sycl/fattn-tile.cpp b/ggml/src/ggml-sycl/fattn-tile.cpp
new file mode 100644 (file)
index 0000000..9d4f019
--- /dev/null
@@ -0,0 +1,55 @@
+#include <sycl/sycl.hpp>
+#include <sycl/ext/oneapi/work_group_static.hpp>
+#include "dpct/helper.hpp"
+#include "common.hpp"
+#include "fattn-common.hpp"
+#include "fattn-tile.hpp"
+#include <cmath>
+#include <float.h>
+namespace syclex = sycl::ext::oneapi::experimental;
+
+void ggml_sycl_flash_attn_ext_tile(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * K = dst->src[1];
+    const ggml_tensor * V = dst->src[2];
+    switch (K->ne[0]) {
+        case  40: {
+            GGML_ASSERT(V->ne[0] == K->ne[0]);
+            ggml_sycl_flash_attn_ext_tile_case< 40,  40>(ctx, dst);
+        } break;
+        case  64: {
+            GGML_ASSERT(V->ne[0] == K->ne[0]);
+            ggml_sycl_flash_attn_ext_tile_case< 64,  64>(ctx, dst);
+        } break;
+        case  72: {
+            GGML_ASSERT(V->ne[0] == K->ne[0]);
+            ggml_sycl_flash_attn_ext_tile_case< 72,  72>(ctx, dst);
+        } break;
+        case  80: {
+            GGML_ASSERT(V->ne[0] == K->ne[0]);
+            ggml_sycl_flash_attn_ext_tile_case< 80,  80>(ctx, dst);
+        } break;
+        case  96: {
+            GGML_ASSERT(V->ne[0] == K->ne[0]);
+            ggml_sycl_flash_attn_ext_tile_case< 96,  96>(ctx, dst);
+        } break;
+        case 112: {
+            GGML_ASSERT(V->ne[0] == K->ne[0]);
+            ggml_sycl_flash_attn_ext_tile_case<112, 112>(ctx, dst);
+        } break;
+        case 128: {
+            GGML_ASSERT(V->ne[0] == K->ne[0]);
+            ggml_sycl_flash_attn_ext_tile_case<128, 128>(ctx, dst);
+        } break;
+        case 256: {
+            GGML_ASSERT(V->ne[0] == K->ne[0]);
+            ggml_sycl_flash_attn_ext_tile_case<256, 256>(ctx, dst);
+        } break;
+        case 576: {
+            GGML_ASSERT(V->ne[0] == 512);
+            ggml_sycl_flash_attn_ext_tile_case<576, 512>(ctx, dst);
+        } break;
+        default: {
+            GGML_ABORT("Unsupported head size");
+        } break;
+    }
+}
diff --git a/ggml/src/ggml-sycl/fattn-tile.hpp b/ggml/src/ggml-sycl/fattn-tile.hpp
new file mode 100644 (file)
index 0000000..29fd0f8
--- /dev/null
@@ -0,0 +1,1338 @@
+#include <sycl/sycl.hpp>
+#include <sycl/ext/oneapi/work_group_static.hpp>
+#include "dpct/helper.hpp"
+#include "common.hpp"
+#include "fattn-common.hpp"
+
+#include <cmath>
+#include <float.h>
+
+namespace syclex = sycl::ext::oneapi::experimental;
+
+#define GGML_SYCL_FATTN_TILE_CONFIG_CASE(DKQ_, DV_, ncols_, nthreads, occupancy, nbatch_fa, nbatch_K) \
+    if (DKQ == (DKQ_) && DV == (DV_) && ncols == (ncols_)) {                                          \
+        static_assert((nthreads)          <= 512, "bad nthreads");                                    \
+        static_assert((occupancy)         <=   8, "bad occupancy");                                   \
+        static_assert((nbatch_fa)         <= 256, "bad nbatch_fa");                                   \
+        static_assert((nbatch_K)          <= 256, "bad nbatch_K");                                    \
+        return ((nthreads) << 0) | ((occupancy) << 10) | ((nbatch_fa) << 14) | ((nbatch_K) << 23);    \
+    }                                                                                                 \
+
+static constexpr uint32_t ggml_sycl_fattn_tile_get_config_fp16(const int DKQ, const int DV, const int ncols) {
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40,  2,  64, 2,  64,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40,  4, 128, 2,  64,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40,  8, 256, 2,  64,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40, 16, 256, 2,  64,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40, 32, 256, 2,  64,  40)
+
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64,  64,  2,  64, 2,  64,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64,  64,  4, 128, 2,  64,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64,  64,  8, 256, 2,  64,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64,  64, 16, 256, 2,  64,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64,  64, 32, 256, 2,  64,  64)
+
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72,  2,  64, 2,  64,  72)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72,  4, 128, 2,  64,  72)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72,  8, 256, 2,  64,  72)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72, 16, 256, 2,  64,  72)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72, 32, 256, 2,  64,  72)
+
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80,  2,  64, 2,  64,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80,  4, 128, 2,  64,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80,  8, 256, 2,  64,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80, 16, 256, 2,  64,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80, 32, 256, 2,  64,  40)
+
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96,  96,  2,  64, 2,  64,  48)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96,  96,  4, 128, 2,  64,  48)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96,  96,  8, 256, 2,  64,  48)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96,  96, 16, 256, 2,  64,  48)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96,  96, 32, 256, 2,  64,  48)
+
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112,  2,  64, 2,  64,  56)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112,  4, 128, 2,  64,  56)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112,  8, 256, 2,  64,  56)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2,  64,  56)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2,  64,  56)
+
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128,  2,  64, 2,  64,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128,  4, 128, 2,  64,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128,  8, 256, 2,  64,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 2,  64,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2,  64,  64)
+
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256,  2,  64, 2,  64,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256,  4, 128, 2,  64,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256,  8, 256, 2,  64,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2,  64,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2,  64,  64)
+
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512,  4, 128, 2,  64,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512,  8, 256, 2,  64,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2,  64,  64)
+
+    return 0;
+}
+
+static constexpr uint32_t ggml_sycl_fattn_tile_get_config_fp32(const int DKQ, const int DV, const int ncols) {
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40,  2,  64, 2,  32,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40,  4, 128, 2,  32,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40,  8, 256, 2,  32,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40, 16, 256, 2,  32,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40, 32, 256, 2,  32,  40)
+
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64,  64,  2, 128, 3,  64,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64,  64,  4, 128, 3,  32,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64,  64,  8, 128, 3,  32,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64,  64, 16, 128, 3,  64,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64,  64, 32, 256, 2,  64,  64)
+
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72,  2,  64, 2,  32,  72)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72,  4, 128, 2,  32,  72)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72,  8, 256, 2,  32,  72)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72, 16, 256, 2,  32,  72)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72, 32, 256, 2,  32,  72)
+
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80,  2,  64, 2,  32,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80,  4, 128, 2,  32,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80,  8, 256, 2,  32,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80, 16, 256, 2,  32,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80, 32, 256, 2,  32,  40)
+
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96,  96,  2,  64, 2,  32,  48)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96,  96,  4, 128, 2,  32,  48)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96,  96,  8, 256, 2,  32,  48)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96,  96, 16, 256, 2,  32,  48)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96,  96, 32, 256, 2,  32,  48)
+
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112,  2,  64, 2,  32,  56)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112,  4, 128, 2,  32,  56)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112,  8, 256, 2,  32,  56)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2,  32,  56)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2,  32,  56)
+
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128,  2, 128, 3,  64,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128,  4, 128, 3,  32, 128)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128,  8, 128, 3,  64, 128)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 16, 128, 3,  32, 128)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2,  64,  64)
+
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256,  2, 128, 3,  64,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256,  4, 128, 3,  32,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256,  8, 256, 2,  32, 256)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2,  32, 128)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2,  32,  64)
+
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512,  4, 128, 2,  32,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512,  8, 256, 2,  32,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2,  32,  64)
+
+    return 0;
+}
+
+static constexpr uint32_t ggml_sycl_fattn_tile_get_config_amd(const int DKQ, const int DV, const int ncols) {
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40,  2,  64, 2,  32,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40,  4, 128, 2,  32,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40,  8, 256, 2,  32,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40, 16, 256, 2,  32,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40, 32, 256, 2,  32,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40, 64, 256, 2,  32,  40)
+
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64,  64,  2,  64, 3,  32,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64,  64,  4, 128, 3,  64,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64,  64,  8, 128, 2,  32,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64,  64, 16, 256, 2, 128,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64,  64, 32, 256, 2,  64,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64,  64, 64, 256, 2,  64,  64)
+
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72,  2,  64, 2,  32,  72)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72,  4, 128, 2,  32,  72)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72,  8, 256, 2,  32,  72)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72, 16, 256, 2,  32,  72)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72, 32, 256, 2,  32,  72)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72, 64, 256, 2,  32,  72)
+
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80,  2,  64, 2,  32,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80,  4, 128, 2,  32,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80,  8, 256, 2,  32,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80, 16, 256, 2,  32,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80, 32, 256, 2,  32,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80, 64, 256, 2,  32,  40)
+
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96,  96,  2,  64, 2,  32,  48)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96,  96,  4, 128, 2,  32,  48)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96,  96,  8, 256, 2,  32,  48)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96,  96, 16, 256, 2,  32,  48)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96,  96, 32, 256, 2,  32,  48)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96,  96, 64, 256, 2,  32,  48)
+
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112,  2,  64, 2,  32,  56)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112,  4, 128, 2,  32,  56)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112,  8, 256, 2,  32,  56)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2,  32,  56)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2,  32,  56)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 64, 256, 2,  32,  56)
+
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128,  2, 256, 2, 128,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128,  4, 128, 2,  64, 128)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128,  8, 256, 2,  64, 128)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 2,  64, 128)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2,  64,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 2,  64,  32)
+
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256,  2, 256, 2, 128,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256,  4, 256, 2,  64, 128)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256,  8, 256, 2,  64, 128)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2,  32, 128)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2,  32, 128)
+
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512,  4, 128, 2,  64,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512,  8, 256, 2,  64,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2,  64,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128,  64)
+
+    return 0;
+}
+
+static constexpr uint32_t ggml_sycl_fattn_tile_get_config_amd_rdna(const int DKQ, const int DV, const int ncols) {
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40,  2,  64, 2,  32,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40,  4, 128, 2,  32,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40,  8, 256, 2,  32,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40, 16, 256, 2,  32,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40, 32, 256, 2,  32,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40,  40, 64, 256, 2,  32,  40)
+
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64,  64,  2,  64, 8,  32,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64,  64,  4,  64, 8,  32,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64,  64,  8, 128, 5, 128,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64,  64, 16, 128, 5, 128,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64,  64, 32, 128, 4,  64,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64,  64, 64, 128, 5,  64,  64)
+
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72,  2,  64, 2,  32,  72)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72,  4, 128, 2,  32,  72)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72,  8, 256, 2,  32,  72)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72, 16, 256, 2,  32,  72)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72, 32, 256, 2,  32,  72)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72,  72, 64, 256, 2,  32,  72)
+
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80,  2,  64, 2,  32,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80,  4, 128, 2,  32,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80,  8, 256, 2,  32,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80, 16, 256, 2,  32,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80, 32, 256, 2,  32,  40)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80,  80, 64, 256, 2,  32,  40)
+
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96,  96,  2,  64, 2,  32,  48)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96,  96,  4, 128, 2,  32,  48)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96,  96,  8, 256, 2,  32,  48)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96,  96, 16, 256, 2,  32,  48)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96,  96, 32, 256, 2,  32,  48)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96,  96, 64, 256, 2,  32,  48)
+
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112,  2,  64, 2,  32,  56)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112,  4, 128, 2,  32,  56)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112,  8, 256, 2,  32,  56)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2,  32,  56)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2,  32,  56)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 64, 256, 2,  32,  56)
+
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128,  2,  64, 8,  32,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128,  4, 128, 8,  64,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128,  8, 128, 8,  64,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 3, 128, 128)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 3, 128,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 3,  64,  64)
+
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256,  2,  64, 8,  32,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256,  4, 128, 6,  32, 256)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256,  8, 128, 6,  32, 256)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5,  32, 256)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3,  64, 128)
+
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512,  4, 128, 2,  64,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512,  8, 256, 2,  64,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4,  64,  64)
+    GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 128,  64)
+
+    return 0;
+}
+
+static constexpr uint32_t ggml_sycl_fattn_tile_get_config(const int DKQ, const int DV, const int ncols, const int cc) {
+    if(fast_fp16_available(cc))
+        return ggml_sycl_fattn_tile_get_config_fp16(DKQ, DV, ncols);
+    else
+        return ggml_sycl_fattn_tile_get_config_fp32(DKQ, DV, ncols);
+}
+
+static constexpr uint32_t ggml_sycl_fattn_tile_get_config(const int DKQ, const int DV, const int ncols) {
+#ifdef SYCL_FAST_FP16
+    return ggml_sycl_fattn_tile_get_config_fp16(DKQ, DV, ncols);
+#else
+    return ggml_sycl_fattn_tile_get_config_fp32(DKQ, DV, ncols);
+#endif // SYCL_FAST_FP16
+}
+
+static int ggml_sycl_fattn_tile_get_nthreads(const int DKQ, const int DV, const int ncols, const int cc) {
+    return (ggml_sycl_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 0) & ((1 << 10) - 1);
+}
+
+static constexpr int ggml_sycl_fattn_tile_get_nthreads(const int DKQ, const int DV, const int ncols) {
+    return (ggml_sycl_fattn_tile_get_config(DKQ, DV, ncols) >> 0) & ((1 << 10) - 1);
+}
+
+static int ggml_sycl_fattn_tile_get_occupancy(const int DKQ, const int DV, const int ncols, const int cc) {
+    return (ggml_sycl_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 10) & ((1 << 4) - 1);
+}
+
+static constexpr int ggml_sycl_fattn_tile_get_occupancy(const int DKQ, const int DV, const int ncols) {
+    return (ggml_sycl_fattn_tile_get_config(DKQ, DV, ncols) >> 10) & ((1 << 4) - 1);
+}
+
+static int ggml_sycl_fattn_tile_get_nbatch_fa(const int DKQ, const int DV, const int ncols, const int cc) {
+    return (ggml_sycl_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 14) & ((1 << 9) - 1);
+}
+
+static constexpr int ggml_sycl_fattn_tile_get_nbatch_fa(const int DKQ, const int DV, const int ncols) {
+    return (ggml_sycl_fattn_tile_get_config(DKQ, DV, ncols) >> 14) & ((1 << 9) - 1);
+}
+
+static int ggml_sycl_fattn_tile_get_nbatch_K(const int DKQ, const int DV, const int ncols, const int cc) {
+    return (ggml_sycl_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 23) & ((1 << 9) - 1);
+}
+
+static constexpr int ggml_sycl_fattn_tile_get_nbatch_K(const int DKQ, const int DV, const int ncols) {
+    return (ggml_sycl_fattn_tile_get_config(DKQ, DV, ncols) >> 23) & ((1 << 9) - 1);
+}
+
+template <int warp_size, int nwarps, int I, int J, int J_padding, bool oob_check>
+static __dpct_inline__ void flash_attn_tile_load_tile(const sycl::half2 * const __restrict__ KV,
+                                                      sycl::half2 * const __restrict__ tile_KV,
+                                                      const int stride_KV,
+                                                      const int i_sup) {
+    constexpr int cpy_nb = ggml_sycl_get_max_cpy_bytes();
+    constexpr int cpy_ne = cpy_nb / 4;
+
+    auto load = [&] (const int n) {
+        auto      item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+        const int stride_j = warp_size >> n;
+
+        if (stride_j == 0) {
+            return;
+        }
+
+        const int j0_start = stride_j == warp_size ? 0 : ((J/2)/cpy_ne) - ((J/2)/cpy_ne) % (2*stride_j);
+        const int j0_stop  =                             ((J/2)/cpy_ne) - ((J/2)/cpy_ne) % (1*stride_j);
+        const int stride_i = warp_size / stride_j;
+
+        if (j0_start == j0_stop) {
+            return;
+        }
+
+#pragma unroll
+        for (int i0 = 0; i0 < I; i0 += nwarps*stride_i) {
+            const int i = i0 + item_ct1.get_local_id(1) * stride_i +
+                          (stride_j == warp_size ? 0 : item_ct1.get_local_id(2) / stride_j);
+
+            if (i0 + nwarps*stride_i <= I || i < I) {
+#pragma unroll
+                for (int j0 = j0_start; j0 < j0_stop; j0 += stride_j) {
+                    const int j = j0 * cpy_ne + (stride_j == warp_size ? item_ct1.get_local_id(2) :
+                                                                         item_ct1.get_local_id(2) % stride_j) *
+                                                    cpy_ne;
+
+                    const __dpct_align__(16) sycl::half2 zero[cpy_ne] = {
+                        { 0.0f, 0.0f }
+                    };
+                    ggml_sycl_memcpy_1<cpy_nb>(
+                        tile_KV + i*(J/2 + J_padding) + j,
+                        !oob_check || i < i_sup ? KV + i*stride_KV + j : zero);
+                }
+            }
+        }
+    };
+    // 1: max 64*16=512 bytes, 512 half
+    // 2: max 32*16=512 bytes, 256 half
+    // 3: max 16*16=256 bytes, 128 half
+    // 4: max  8*16=128 bytes,  64 half
+    // 5: max  4*16= 64 bytes,  32 half
+    // 6: max  2*16= 32 bytes,  16 half
+    // 7: max  1*16= 16 bytes,   8 half
+    static_assert(J % 8 == 0, "bad J");
+    static_assert((J/2) % cpy_ne == 0, "bad J");
+    ggml_sycl_unroll<7>{}(load);
+}
+
+template <int warp_size, int nwarps, int I, int J, int J_padding, bool oob_check>
+static __dpct_inline__ void flash_attn_tile_load_tile(const sycl::half2 * const __restrict__ KV,
+                                                      float * const __restrict__ tile_KV,
+                                                      const int stride_KV,
+                                                      const int i_sup) {
+    constexpr int cpy_nb = ggml_sycl_get_max_cpy_bytes();
+    constexpr int cpy_ne = cpy_nb / 4;
+
+    auto load = [&] (const int n) {
+        auto      item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+        const int stride_j = warp_size >> n;
+
+        if (stride_j == 0) {
+            return;
+        }
+
+        const int j0_start = stride_j == warp_size ? 0 : (J/cpy_ne) - (J/cpy_ne) % (2*stride_j);
+        const int j0_stop  =                             (J/cpy_ne) - (J/cpy_ne) % (1*stride_j);
+        const int stride_i = warp_size / stride_j;
+
+        if (j0_start == j0_stop) {
+            return;
+        }
+
+#pragma unroll
+        for (int i0 = 0; i0 < I; i0 += nwarps*stride_i) {
+            const int i = i0 + item_ct1.get_local_id(1) * stride_i +
+                          (stride_j == warp_size ? 0 : item_ct1.get_local_id(2) / stride_j);
+
+            if (i0 + nwarps*stride_i <= I || i < I) {
+#pragma unroll
+                for (int j0 = j0_start; j0 < j0_stop; j0 += stride_j) {
+                    const int j = j0 * (cpy_ne / 2) + (stride_j == warp_size ? item_ct1.get_local_id(2) :
+                                                                               item_ct1.get_local_id(2) % stride_j) *
+                                                          (cpy_ne / 2);
+
+                    const sycl::half2 zero[cpy_ne / 2] = {
+                        { 0.0f, 0.0f }
+                    };
+                    __dpct_align__(16) sycl::half2 tmp_h2[cpy_ne / 2];
+                    ggml_sycl_memcpy_1<sizeof(tmp_h2)>(
+                        tmp_h2, !oob_check || i < i_sup ? KV + i*stride_KV + j : zero);
+
+                    __dpct_align__(16) sycl::float2 tmp_f2[cpy_ne / 2];
+#pragma unroll
+                    for (int l = 0; l < cpy_ne/2; ++l) {
+                        tmp_f2[l] = tmp_h2[l].template convert<float, sycl::rounding_mode::automatic>();
+                    }
+                    ggml_sycl_memcpy_1<sizeof(tmp_f2)>(tile_KV + i*(J + J_padding) + 2*j, tmp_f2);
+                }
+            }
+        }
+    };
+    // 1: max 32*16=512 bytes, 128 float
+    // 2: max 16*16=256 bytes,  64 float
+    // 3: max  8*16=128 bytes,  32 float
+    // 4: max  4*16= 64 bytes,  16 float
+    // 5: max  2*16= 32 bytes,   8 float
+    static_assert(J % 8 == 0, "bad J");
+    static_assert(J % cpy_ne == 0, "bad J");
+    ggml_sycl_unroll<5>{}(load);
+}
+
+// Function that performs a single iteration in for the KQ matrix multiplication:
+template <int  warp_size,
+          int  nwarps,
+          int  ncols1,
+          int  ncols2,
+          int  DKQ,
+          int  nbatch_fa,
+          int  nbatch_K,
+          bool use_logit_softcap,
+          bool oob_check,
+          typename T_vec_dot>
+static __dpct_inline__ void flash_attn_tile_iter_KQ(T_vec_dot * const Q_tmp,
+                                                    const sycl::half2 * const __restrict__ K_h2,
+                                                    T_vec_dot * const KV_tmp,
+                                                    const int         stride_K2,
+                                                    const int         k_VKQ_0,
+                                                    const int         k_VKQ_sup,
+                                                    const int         k_KQ_0,
+                                                    float *           KQ_acc) {
+    auto          item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+    constexpr int cpy_nb   = ggml_sycl_get_max_cpy_bytes();
+    constexpr int cpy_ne = cpy_nb / 4;
+
+    constexpr int ncols = ncols1*ncols2;
+    constexpr int cpw   = ncols > nwarps ? ncols/nwarps : 1; // Q columns per warp
+    constexpr int np    = nwarps > ncols ? nwarps/ncols : 1; // number of parallel warps per Q column
+
+    flash_attn_tile_load_tile<warp_size, nwarps, nbatch_fa, nbatch_K, cpy_ne, oob_check>
+        (K_h2 + int64_t(k_VKQ_0)*stride_K2 + k_KQ_0/2, KV_tmp, stride_K2, k_VKQ_sup);
+    item_ct1.barrier();
+
+#ifdef SYCL_FAST_FP16
+    static_assert((nbatch_K/2) % cpy_ne == 0, "bad nbatch_K");
+#pragma unroll
+    for (int k_KQ_1 = 0; k_KQ_1 < nbatch_K/2; k_KQ_1 += cpy_ne) {
+        __dpct_align__(16) sycl::half2 K_k[nbatch_fa / (np * warp_size)][cpy_ne];
+        __dpct_align__(16) sycl::half2 Q_k[cpw][cpy_ne];
+#else
+    static_assert(nbatch_K % cpy_ne == 0, "bad nbatch_K");
+#pragma unroll
+    for (int k_KQ_1 = 0; k_KQ_1 < nbatch_K; k_KQ_1 += cpy_ne) {
+        __dpct_align__(16) float K_k[nbatch_fa/(np*warp_size)][cpy_ne];
+        __dpct_align__(16) float Q_k[cpw][cpy_ne];
+#endif // SYCL_FAST_FP16
+
+#pragma unroll
+        for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) {
+            const int i_KQ = i_KQ_0 + (item_ct1.get_local_id(1) % np) * warp_size + item_ct1.get_local_id(2);
+
+#ifdef SYCL_FAST_FP16
+            ggml_sycl_memcpy_1<cpy_nb>(&K_k[i_KQ_0/(np*warp_size)], &KV_tmp[i_KQ*(nbatch_K/2 + cpy_ne) + k_KQ_1]);
+#else
+            ggml_sycl_memcpy_1<cpy_nb>(&K_k[i_KQ_0/(np*warp_size)], &KV_tmp[i_KQ*(nbatch_K   + cpy_ne) + k_KQ_1]);
+#endif // SYCL_FAST_FP16
+        }
+#pragma unroll
+        for (int jc0 = 0; jc0 < cpw; ++jc0) {
+            const int jc = jc0 + (item_ct1.get_local_id(1) / np) * cpw;
+
+#ifdef SYCL_FAST_FP16
+            ggml_sycl_memcpy_1<cpy_nb>(&Q_k[jc0], &Q_tmp[jc*(DKQ/2) + k_KQ_0/2 + k_KQ_1]);
+#else
+            ggml_sycl_memcpy_1<cpy_nb>(&Q_k[jc0], &Q_tmp[jc* DKQ    + k_KQ_0   + k_KQ_1]);
+#endif // SYCL_FAST_FP16
+        }
+
+#pragma unroll
+        for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) {
+#pragma unroll
+            for (int jc0 = 0; jc0 < cpw; ++jc0) {
+#pragma unroll
+                for (int k = 0; k < cpy_ne; ++k) {
+                    ggml_sycl_mad(KQ_acc[i_KQ_0/(np*warp_size)*cpw + jc0], K_k[i_KQ_0/(np*warp_size)][k], Q_k[jc0][k]);
+                }
+            }
+        }
+    }
+
+    if (k_KQ_0 + nbatch_K < DKQ) {
+        item_ct1.barrier();  // Sync not needed on last iteration.
+    }
+}
+
+// Function that performs a single iteration of the main loop over up to nbatch_fa tokens.
+template <int  warp_size,
+          int  nwarps,
+          int  ncols1,
+          int  ncols2,
+          int  DKQ,
+          int  DV,
+          int  nbatch_fa,
+          int  nbatch_K,
+          bool use_logit_softcap,
+          bool oob_check,
+          typename T_vec_dot,
+          typename T_KQ,
+          typename T_acc>
+/*
+The total declared local variable size in device function flash_attn_tile_iter exceeds 128 bytes and may cause high register pressure. Consult with your hardware vendor to find the total register size available and adjust the code, or use smaller sub-group size to avoid high register pressure.
+*/
+static __dpct_inline__ void flash_attn_tile_iter(T_vec_dot * const Q_tmp,
+                                                 const sycl::half2 * const __restrict__ K_h2,
+                                                 const sycl::half2 * const __restrict__ V_h2,
+                                                 const sycl::half * const __restrict__ mask,
+                                                 const sycl::uint3 ne01,
+                                                 const float       logit_softcap,
+                                                 const float       slope,
+                                                 T_KQ * const      KQ,
+                                                 T_vec_dot * const KV_tmp,
+                                                 const int         stride_K2,
+                                                 const int         stride_V2,
+                                                 const int         stride_mask,
+                                                 float * const     KQ_max,
+                                                 float * const     KQ_sum,
+                                                 T_acc * const     VKQ,
+                                                 const int         k_VKQ_0,
+                                                 const int         k_VKQ_max,
+                                                 const int         col_Q_0,
+                                                 float *           KQ_max_new_shared) {
+    auto          item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+    constexpr int cpy_nb   = ggml_sycl_get_max_cpy_bytes();
+    constexpr int cpy_ne = cpy_nb / 4;
+
+    constexpr int ncols = ncols1*ncols2;
+    constexpr int cpw   = ncols > nwarps ? ncols/nwarps : 1; // Q columns per warp
+    constexpr int np    = nwarps > ncols ? nwarps/ncols : 1; // number of parallel warps per Q column
+
+    constexpr int DVp = (DV + 2*warp_size - 1) & ~(2*warp_size - 1); // DV padded to multiple of 2*warp_size.
+
+#ifdef SYCL_FAST_FP16
+    constexpr int KQ_cs = cpw < 2*cpy_ne ? cpw : 2*cpy_ne;
+#else
+    constexpr int KQ_cs = cpw < 1*cpy_ne ? cpw : 1*cpy_ne;
+#endif // SYCL_FAST_FP16
+    static_assert(cpw % KQ_cs == 0, "bad KQ_cs");
+    const int k_VKQ_sup = k_VKQ_max - k_VKQ_0; // k supremum, only smaller k values have valid KV data
+
+    float KQ_max_new[cpw];
+#pragma unroll
+    for (int jc0 = 0; jc0 < cpw; ++jc0) {
+        KQ_max_new[jc0] = KQ_max[jc0];
+    }
+
+    float KQ_acc[nbatch_fa/(np*warp_size) * cpw] = {0.0f}; // Accumulators for KQ matrix multiplication.
+
+    // KQ = K @ Q matrix multiplication:
+    constexpr int nbatch_K_last = DKQ % nbatch_K;
+#pragma unroll
+    for (int k_KQ_0 = 0; k_KQ_0 < DKQ - nbatch_K_last; k_KQ_0 += nbatch_K) {
+        flash_attn_tile_iter_KQ<warp_size, nwarps, ncols1, ncols2, DKQ, nbatch_fa, nbatch_K, use_logit_softcap, oob_check>(
+            Q_tmp, K_h2, KV_tmp, stride_K2, k_VKQ_0, k_VKQ_sup, k_KQ_0, KQ_acc);
+    }
+    if (nbatch_K_last > 0) {
+        constexpr int k_KQ_0 = DKQ - nbatch_K_last;
+        flash_attn_tile_iter_KQ<warp_size, nwarps, ncols1, ncols2, DKQ, nbatch_fa, nbatch_K_last, use_logit_softcap, oob_check>(
+            Q_tmp, K_h2, KV_tmp, stride_K2, k_VKQ_0, k_VKQ_sup, k_KQ_0, KQ_acc);
+    }
+
+    // Apply logit softcap + mask, update KQ_max:
+#pragma unroll
+    for (int jc0 = 0; jc0 < cpw; ++jc0) {
+        const int j = fastmodulo(col_Q_0 + (jc0 + (item_ct1.get_local_id(1) / np) * cpw) / ncols2, ne01);
+
+#pragma unroll
+        for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) {
+            const int i_KQ = i_KQ_0 + (item_ct1.get_local_id(1) % np) * warp_size + item_ct1.get_local_id(2);
+
+#if defined(SYCL_FAST_FP16) && !defined(GGML_SYCL_F16)
+            // Without the v_dot2_f32_f16 instruction there is a higher risk of numerical overflow in the KQ calculation.
+            // Therefore, scale down Q values and apply the inverse scale the FP32 KQ values afterwards again.
+            KQ_acc[i_KQ_0/(np*warp_size)*cpw + jc0] *= 4.0f;
+#endif // defined(SYCL_FAST_FP16) && !defined(GGML_SYCL_F16)
+
+            if (use_logit_softcap) {
+                KQ_acc[(i_KQ_0 / (np * warp_size)) * cpw + jc0] =
+                    logit_softcap * sycl::tanh((float) KQ_acc[(i_KQ_0 / (np * warp_size)) * cpw + jc0]);
+            }
+
+            if (!oob_check || i_KQ < k_VKQ_sup) {
+                KQ_acc[(i_KQ_0 / (np * warp_size)) * cpw + jc0] +=
+                    (ncols2 > 1 || mask) ? slope * sycl::vec<sycl::half, 1>(mask[j * stride_mask + k_VKQ_0 + i_KQ])
+                                                       .convert<float, sycl::rounding_mode::automatic>()[0] :
+                                           0.0f;
+
+                KQ_max_new[jc0] =
+                    sycl::fmax((float) KQ_max_new[jc0],
+                               (float) (KQ_acc[(i_KQ_0 / (np * warp_size)) * cpw + jc0] + FATTN_KQ_MAX_OFFSET));
+            }
+        }
+
+        KQ_max_new[jc0] = warp_reduce_max<warp_size>(KQ_max_new[jc0]);
+    }
+
+    if constexpr (np == 1) {
+        item_ct1.barrier();
+    } else {
+        static_assert(cpw == 1, "bad cpw");
+
+        if (item_ct1.get_local_id(2) == 0) {
+            KQ_max_new_shared[item_ct1.get_local_id(1)] = KQ_max_new[0];
+        }
+        item_ct1.barrier();
+        KQ_max_new[0] = KQ_max_new_shared[(item_ct1.get_local_id(1) & ~(np - 1)) + item_ct1.get_local_id(2) % np];
+        KQ_max_new[0] = warp_reduce_max<np>(KQ_max_new[0]);
+    }
+
+    // Calculate KQ softmax, write to shared KQ buffer, re-scale VKQ accumulators:
+#pragma unroll
+    for (int jc0 = 0; jc0 < cpw; jc0 += KQ_cs) {
+#ifdef SYCL_FAST_FP16
+        __dpct_align__(16) sycl::half tmp[nbatch_fa / (np * warp_size)][KQ_cs];
+#else
+        __dpct_align__(16) float tmp[nbatch_fa/(np*warp_size)][KQ_cs];
+#endif // SYCL_FAST_FP16
+
+#pragma unroll
+        for (int jc1 = 0; jc1 < KQ_cs; ++jc1) {
+            const int jc = jc0 + jc1;
+
+            const float KQ_max_scale = sycl::native::exp((float) (KQ_max[jc] - KQ_max_new[jc]));
+            KQ_max[jc] = KQ_max_new[jc];
+
+            float KQ_sum_add = 0.0f;
+#pragma unroll
+            for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) {
+                const float val =
+                    !oob_check || i0 + (item_ct1.get_local_id(1) % np) * warp_size + item_ct1.get_local_id(2) <
+                                      static_cast<uint32_t>(k_VKQ_sup) ?
+                        sycl::native::exp((float) (KQ_acc[(i0 / (np * warp_size)) * cpw + jc] - KQ_max[jc])) :
+                        0.0f;
+                KQ_sum_add += val;
+                tmp[i0/(np*warp_size)][jc1] = val;
+            }
+            KQ_sum[jc] = KQ_sum[jc]*KQ_max_scale + KQ_sum_add;
+
+#ifdef SYCL_FAST_FP16
+            const sycl::half2 KQ_max_scale_h2 = sycl::half2(KQ_max_scale, KQ_max_scale);
+#pragma unroll
+            for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {
+                VKQ[jc*((DVp/2)/warp_size) + i0/warp_size].x() *= KQ_max_scale_h2.x();
+                VKQ[jc*((DVp/2)/warp_size) + i0/warp_size].y() *= KQ_max_scale_h2.y();
+            }
+#else
+#pragma unroll
+            for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {
+                VKQ[jc*((DVp/2)/warp_size) + i0/warp_size].x() *= KQ_max_scale;
+                VKQ[jc*((DVp/2)/warp_size) + i0/warp_size].y() *= KQ_max_scale;
+            }
+#endif // SYCL_FAST_FP16
+        }
+
+#pragma unroll
+        for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) {
+            const int i = i0 + (item_ct1.get_local_id(1) % np) * warp_size + item_ct1.get_local_id(2);
+
+            ggml_sycl_memcpy_1<sizeof(tmp[0])>(
+                KQ + (jc0 / KQ_cs + (item_ct1.get_local_id(1) / np) * (cpw / KQ_cs)) * (nbatch_fa * KQ_cs) + i * KQ_cs,
+                tmp[i0 / (np * warp_size)]);
+        }
+    }
+
+    // VKQ = V @ KQ matrix multiplication:
+    static_assert(DV <= DKQ, "bad DV");
+    static_assert(DV % nbatch_K == 0 || (nbatch_K % 3 == 0 && DV % (nbatch_K*2/3) == 0), "bad nbatch_K");
+    constexpr int nbatch_V = (DV % nbatch_K == 0 ? nbatch_K : nbatch_K*2/3) * nbatch_fa / DV; // Number of V columns that fit in SRAM for K.
+    static_assert(nbatch_fa % nbatch_V == 0, "bad nbatch_V");
+    static_assert(nbatch_V % np == 0, "bad nbatch_V");
+#pragma unroll
+    for (int k0 = 0; k0 < nbatch_fa; k0 += nbatch_V) {
+        flash_attn_tile_load_tile<warp_size, nwarps, nbatch_V, DV, 0, oob_check>
+            (V_h2 + int64_t(k_VKQ_0 + k0)*stride_V2, KV_tmp, stride_V2, k_VKQ_sup - k0);
+        item_ct1.barrier();
+
+#ifdef SYCL_FAST_FP16
+#pragma unroll
+        for (int k1 = 0; k1 < nbatch_V; k1 += np) {
+            __dpct_align__(16) sycl::half2 V_k[(DVp / 2) / warp_size];
+            __dpct_align__(16) sycl::half2 KQ_k[cpw];
+
+            constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? cpy_ne/2 : (DVp/2)/warp_size;
+#pragma unroll
+            for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) {
+                ggml_sycl_memcpy_1<cpy_ne_D * 4>(&V_k[i0 / warp_size],
+                                                 &KV_tmp[(k1 + item_ct1.get_local_id(1) % np) * (DV / 2) + i0 +
+                                                         item_ct1.get_local_id(2) * cpy_ne_D]);
+            }
+#pragma unroll
+            for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; jc_VKQ_0 += KQ_cs) {
+                const int jc_KQ = jc_VKQ_0 / KQ_cs + (item_ct1.get_local_id(1) / np) * (cpw / KQ_cs);
+
+                __dpct_align__(16) sycl::half tmp[KQ_cs];
+                ggml_sycl_memcpy_1<KQ_cs * sizeof(sycl::half)>(
+                    &tmp, KQ + jc_KQ * (nbatch_fa * KQ_cs) + (k0 + k1 + item_ct1.get_local_id(1) % np) * KQ_cs);
+#pragma unroll
+                for (int jc_VKQ_1 = 0; jc_VKQ_1 < KQ_cs; ++jc_VKQ_1) {
+                    KQ_k[jc_VKQ_0 + jc_VKQ_1] = sycl::half2(tmp[jc_VKQ_1]);
+                }
+            }
+
+#pragma unroll
+            for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {
+#pragma unroll
+                for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; ++jc_VKQ_0) {
+                    VKQ[jc_VKQ_0*((DVp/2)/warp_size) + i0/warp_size].x() +=
+                        V_k[i0/warp_size].x()*KQ_k[jc_VKQ_0].x();
+                    VKQ[jc_VKQ_0*((DVp/2)/warp_size) + i0/warp_size].y() +=
+                        V_k[i0/warp_size].y()*KQ_k[jc_VKQ_0].y();
+                }
+            }
+        }
+#else
+#pragma unroll
+        for (int k1 = 0; k1 < nbatch_V; k1 += np) {
+            __dpct_align__(16) sycl::float2 V_k[(DVp/2)/warp_size];
+            __dpct_align__(16) float  KQ_k[cpw];
+
+            constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size;
+#pragma unroll
+            for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) {
+                ggml_sycl_memcpy_1<cpy_ne_D*4>(&V_k[i0/(2*warp_size)], &KV_tmp[(k1 + item_ct1.get_local_id(1) % np)*DV + i0 + item_ct1.get_local_id(2)*cpy_ne_D]);
+            }
+#pragma unroll
+            for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; jc_VKQ_0 += KQ_cs) {
+                const int jc_KQ = jc_VKQ_0/KQ_cs + (item_ct1.get_local_id(1) / np)*(cpw/KQ_cs);
+
+                ggml_sycl_memcpy_1<KQ_cs*sizeof(float)>(
+                    &KQ_k[jc_VKQ_0], KQ + jc_KQ*(nbatch_fa*KQ_cs) + (k0 + k1 + item_ct1.get_local_id(1) % np)*KQ_cs);
+            }
+
+#pragma unroll
+            for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {
+#pragma unroll
+                for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; ++jc_VKQ_0) {
+                    VKQ[jc_VKQ_0*((DVp/2)/warp_size) + i0/warp_size].x() += V_k[i0/warp_size].x()*KQ_k[jc_VKQ_0];
+                    VKQ[jc_VKQ_0*((DVp/2)/warp_size) + i0/warp_size].y() += V_k[i0/warp_size].y()*KQ_k[jc_VKQ_0];
+                }
+            }
+        }
+#endif // SYCL_FAST_FP16
+        item_ct1.barrier();
+    }
+}
+
+template <int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap, int warp_size>  // D == head size
+/*
+The total declared local variable size in device function flash_attn_tile exceeds 128 bytes and may cause high register pressure. Consult with your hardware vendor to find the total register size available and adjust the code, or use smaller sub-group size to avoid high register pressure.
+*/
+static void flash_attn_tile(const char *  Q,
+                            const char *  K,
+                            const char *  V,
+                            const char *  mask,
+                            const char *  sinks,
+                            const int *  KV_max,
+                            float *  dst,
+                            sycl::float2 *  dst_meta,
+                            const float          scale,
+                            const float          max_bias,
+                            const float          m0,
+                            const float          m1,
+                            const uint32_t       n_head_log2,
+                            const float          logit_softcap,
+                            const int32_t        ne00,
+                            const sycl::uint3    ne01,
+                            const int32_t        ne02,
+                            const int32_t        ne03,
+                            const int32_t        nb01,
+                            const int32_t        nb02,
+                            const int32_t        nb03,
+                            const int32_t        ne10,
+                            const int32_t        ne11,
+                            const int32_t        ne12,
+                            const int32_t        ne13,
+                            const int32_t        nb11,
+                            const int32_t        nb12,
+                            const int64_t        nb13,
+                            const int32_t        nb21,
+                            const int32_t        nb22,
+                            const int64_t        nb23,
+                            const int32_t        ne31,
+                            const int32_t        ne32,
+                            const int32_t        ne33,
+                            const int32_t        nb31,
+                            const int32_t        nb32,
+                            const int64_t        nb33) {
+#ifdef SYCL_FLASH_ATTN
+    // Skip unused kernel variants for faster compilation:
+    auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+    if ((use_logit_softcap && !(DV == 128 || DV == 256))) {
+        GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
+            max_bias, m0, m1, n_head_log2, logit_softcap,
+            ne00, ne01, ne02, ne03,
+                  nb01, nb02, nb03,
+            ne10, ne11, ne12, ne13,
+                  nb11, nb12, nb13,
+                  nb21, nb22, nb23,
+                  ne31, ne32, ne33,
+                  nb31, nb32, nb33);
+        return;
+    }
+
+    static_assert(ggml_sycl_fattn_tile_get_config(DKQ, DV, ncols1*ncols2) != 0, "kernel config not defined");
+
+    constexpr int ncols     = ncols1*ncols2;
+
+    constexpr int nwarps    = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, ncols1*ncols2) / warp_size;
+    constexpr int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, ncols1*ncols2);
+    constexpr int nbatch_K  = ggml_sycl_fattn_tile_get_nbatch_K (DKQ, DV, ncols1*ncols2);
+
+    // In this kernel Q, K, V are matrices while i, j, k are matrix indices.
+
+    const int col_Q_0 = item_ct1.get_group(2) * ncols1;  // Index of the first Q column for this SYCL block to work on.
+
+    const int           sequence  = item_ct1.get_group(0) / (ne02 / ncols2);
+    const int           head0     = item_ct1.get_group(0) * ncols2 - sequence * ne02;  // == item_ct1.get_group(0) % (ne02/ncols2)
+    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
+    const float * Q_f  = (const float *) (Q + nb03*sequence + nb02* head0);
+    const sycl::half2 * K_h2      = (const sycl::half2 *) (K + nb13 * sequence + nb12 * (head0 / gqa_ratio));
+    const sycl::half2 * V_h2 =
+        (const sycl::half2 *) (V + nb23 * sequence + nb22 * (head0 / gqa_ratio));  // K and V have same shape
+
+    const sycl::half * maskh = mask ? (const sycl::half *) (mask + nb33 * (sequence % ne33)) : nullptr;
+
+    const int stride_K2   = nb11 / sizeof(sycl::half2);
+    const int stride_V2   = nb21 / sizeof(sycl::half2);
+    const int stride_mask = nb31 / sizeof(sycl::half);
+
+    const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
+
+    constexpr int cpy_nb = ggml_sycl_get_max_cpy_bytes();
+    constexpr int cpy_ne = cpy_nb / 4;
+
+    constexpr int cpw = ncols > nwarps ? ncols/nwarps : 1; // Q columns per warp.
+    constexpr int np  = nwarps > ncols ? nwarps/ncols : 1; // Number of parallel warps per Q column.
+
+    static_assert(cpw == 1 || np == 1, "bad cpw / np");
+    static_assert(nbatch_fa % (np*warp_size) == 0, "nbatch_fa % (np*warp_size) != 0");
+
+    constexpr int DKQp = (DKQ + 2*warp_size - 1) & ~(2*warp_size - 1); // DKQ padded to multiple of 2*warp_size.
+    constexpr int DVp  = (DV  + 2*warp_size - 1) & ~(2*warp_size - 1); // DV  padded to multiple of 2*warp_size.
+
+    // Q_tmp == SRAM buffer to hold Q data for the entire lifetime of the kernel.
+    // KV_tmp == SRAM buffer to hold fragments of K/V data while iterating over ne11.
+    //     KV_tmp is padded to avoid memory conflicts for K (cpy_ne) and OOB accesses for V (DVp-DV).
+    // KQ == SRAM buffer to hold KQ fragments between KQ and VKQ matrix multiplications.
+    // VKQ == Accumulators in registers for the final VKQ result.
+
+
+#ifdef SYCL_FAST_FP16
+    constexpr size_t lsm_size1 = ncols * DKQ/2 ;
+    constexpr size_t lsm_size2 = nbatch_fa * (nbatch_K/2 + cpy_ne) + DVp-DV ;
+    constexpr size_t lsm_size3 = ncols * nbatch_fa;
+    constexpr size_t lsm_size4 = nwarps;
+
+    constexpr size_t local_share_mem_size = lsm_size1 * sizeof(sycl::half2) +
+                                            lsm_size2 * sizeof(sycl::half2) +
+                                            lsm_size3 * sizeof(sycl::half) +
+                                            lsm_size4 * sizeof(float);
+
+    syclex::work_group_static<char[local_share_mem_size]> lsm;
+
+    sycl::half2 *Q_tmp = (sycl::half2 *)&lsm;
+    sycl::half2 *KV_tmp = (sycl::half2*)(Q_tmp +lsm_size1);
+    sycl::half *KQ = (sycl::half *)(KV_tmp+lsm_size2);
+    float *KQ_max_new_shared = (float *)(KQ+lsm_size3);
+
+    __dpct_align__(16) sycl::half2 VKQ[cpw * ((DVp / 2) / warp_size)] = {
+        { 0.0f, 0.0f }
+    };
+#else
+    constexpr size_t lsm_size1 = ncols * DKQ ;
+    constexpr size_t lsm_size2 = nbatch_fa * (nbatch_K + cpy_ne) + DVp-DV;
+    constexpr size_t lsm_size3 = ncols * nbatch_fa;
+    constexpr size_t lsm_size4 = nwarps;
+
+    constexpr size_t local_share_mem_size = (lsm_size1 + lsm_size2 +lsm_size3 + lsm_size4) * sizeof(float);
+
+    syclex::work_group_static<char[local_share_mem_size]> lsm;
+
+    float *Q_tmp = (float *)&lsm;
+    float *KV_tmp = Q_tmp +lsm_size1;
+    float *KQ = KV_tmp+lsm_size2;
+    float *KQ_max_new_shared = KQ+lsm_size3;
+
+    __dpct_align__(16) sycl::float2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}};
+
+
+#endif // SYCL_FAST_FP16
+
+    float KQ_max[cpw] = {};
+
+#pragma unroll
+    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+        KQ_max[j0/nwarps] = -FLT_MAX/2.0f;
+    }
+    float KQ_sum[cpw] = {0.0f};
+
+    // Load Q data, convert to FP16 if fast:
+#pragma unroll
+    for (int jc0 = 0; jc0 < cpw; ++jc0) {
+        const int jc = jc0 + (item_ct1.get_local_id(1) / np) * cpw;
+
+        const int j = jc / ncols2;
+        const int c = jc % ncols2;
+
+        constexpr int cpy_ne_D = cpy_ne < DKQp/warp_size ? cpy_ne : DKQp/warp_size;
+
+#pragma unroll
+        for (int i0 = 0; i0 < DKQp; i0 += np*warp_size*cpy_ne_D) {
+            if (i0 + np * warp_size * cpy_ne_D <= DKQ ||
+                i0 + (item_ct1.get_local_id(1) % np) * (warp_size * cpy_ne_D) + item_ct1.get_local_id(2) * cpy_ne_D <
+                    DKQ) {
+                __dpct_align__(16) float tmp_f[cpy_ne_D] = { 0.0f };
+                ggml_sycl_memcpy_1<sizeof(tmp_f)>(
+                    tmp_f, &Q_f[c * (nb02 / sizeof(float)) + fastmodulo(col_Q_0 + j, ne01) * (nb01 / sizeof(float)) +
+                                i0 + (item_ct1.get_local_id(1) % np) * (warp_size * cpy_ne_D) +
+                                item_ct1.get_local_id(2) * cpy_ne_D]);
+
+#pragma unroll
+                for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
+                    tmp_f[i1] *= scale;
+                }
+
+#ifdef SYCL_FAST_FP16
+                __dpct_align__(16) sycl::half2 tmp_h2[cpy_ne_D / 2];
+#pragma unroll
+                for (int i1 = 0; i1 < cpy_ne_D; i1 += 2) {
+                    tmp_h2[i1/2] = make_half2(tmp_f[i1 + 0], tmp_f[i1 + 1]);
+#if defined(SYCL_FAST_FP16) && !defined(GGML_SYCL_F16)
+                    // Without the v_dot2_f32_f16 instruction there is a higher risk of numerical overflow in the KQ calculation.
+                    // Therefore, scale down Q values and apply the inverse scale the FP32 KQ values afterwards again.
+                    tmp_h2[i1 / 2] *= sycl::half2(0.25f, 0.25f);
+#endif // defined(SYCL_FAST_FP16) && !defined(GGML_SYCL_F16)
+                }
+                ggml_sycl_memcpy_1<sizeof(tmp_h2)>(
+                    &Q_tmp[jc * (DKQ / 2) + i0 / 2 + (item_ct1.get_local_id(1) % np) * (warp_size * cpy_ne_D / 2) +
+                           item_ct1.get_local_id(2) * (cpy_ne_D / 2)],
+                    tmp_h2);
+#else
+                ggml_sycl_memcpy_1<sizeof(tmp_f)>(
+                    &Q_tmp[jc* DKQ    + i0   + (item_ct1.get_local_id(1) % np)*(warp_size*cpy_ne_D)   + item_ct1.get_local_id(2)* cpy_ne_D],
+                    tmp_f);
+#endif // SYCL_FAST_FP16
+            }
+        }
+    }
+
+    item_ct1.barrier();
+
+    // Main loop over KV cache:
+    const int k_VKQ_max = KV_max ? KV_max[sequence * item_ct1.get_group_range(2) + item_ct1.get_group(2)] : ne11;
+    if (ncols2 == 1) {
+        // Branch with out-of-bounds checks.
+        int k_VKQ_0 = item_ct1.get_group(1) * nbatch_fa;
+        while (k_VKQ_0 < k_VKQ_max - nbatch_fa) {
+            constexpr bool oob_check = false;
+            flash_attn_tile_iter<warp_size, nwarps, ncols1, ncols2, DKQ, DV, nbatch_fa, nbatch_K, use_logit_softcap,
+                                 oob_check>(Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp, stride_K2,
+                                            stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0,
+                                            KQ_max_new_shared);
+            k_VKQ_0 += item_ct1.get_group_range(1) * nbatch_fa;
+        }
+        if (k_VKQ_0 < k_VKQ_max) {
+            constexpr bool oob_check = true;
+            flash_attn_tile_iter<warp_size, nwarps, ncols1, ncols2, DKQ, DV, nbatch_fa, nbatch_K, use_logit_softcap,
+                                 oob_check>(Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp, stride_K2,
+                                            stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0,
+                                            KQ_max_new_shared);
+        }
+    } else {
+        // Branch without out-of-bounds checks.
+        for (int k_VKQ_0 = item_ct1.get_group(1) * nbatch_fa; k_VKQ_0 < k_VKQ_max;
+             k_VKQ_0 += item_ct1.get_group_range(1) * nbatch_fa) {
+
+            constexpr bool oob_check = false;
+            flash_attn_tile_iter<warp_size, nwarps, ncols1, ncols2, DKQ, DV, nbatch_fa, nbatch_K, use_logit_softcap,
+                                 oob_check>(Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp, stride_K2,
+                                            stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0,
+                                            KQ_max_new_shared);
+        }
+    }
+
+#pragma unroll
+    for (int jc0 = 0; jc0 < cpw; ++jc0) {
+        KQ_sum[jc0] = warp_reduce_sum<warp_size>(KQ_sum[jc0]);
+    }
+
+    if constexpr (np > 1) {
+        static_assert(cpw == 1, "bad cpw");
+        static_assert(nbatch_fa*nbatch_K >= nwarps*DVp, "KV_tmp too small");
+
+#ifdef SYCL_FAST_FP16
+        sycl::half2 * VKQ_combine = (sycl::half2 *) KV_tmp;
+#else
+        float * VKQ_combine    = (float *) KV_tmp;
+#endif // SYCL_FAST_FP16
+
+        float * KQ_sum_combine = (float *) Q_tmp;
+
+        if (item_ct1.get_local_id(1) % np != 0) {
+
+#ifdef SYCL_FAST_FP16
+            constexpr int cpy_ne_D = cpy_ne < (DVp/2)/warp_size ? cpy_ne : (DVp/2)/warp_size;
+#pragma unroll
+            for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) {
+                ggml_sycl_memcpy_1<cpy_ne_D * 4>(
+                    &VKQ_combine[item_ct1.get_local_id(1) * (DVp / 2) + i0 + item_ct1.get_local_id(2) * cpy_ne_D],
+                    &VKQ[i0 / warp_size]);
+            }
+#else
+
+            constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size;
+
+#pragma unroll
+            for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) {
+                ggml_sycl_memcpy_1<cpy_ne_D*4>(
+                    &VKQ_combine[item_ct1.get_local_id(1)*DVp + i0 + item_ct1.get_local_id(2)*cpy_ne_D], ((const float *) VKQ) + i0/warp_size);
+            }
+#endif // SYCL_FAST_FP16
+
+            if (item_ct1.get_local_id(2) == 0) {
+                KQ_sum_combine[item_ct1.get_local_id(1)] = KQ_sum[0];
+            }
+            return;
+        }
+
+        item_ct1.barrier();
+
+#pragma unroll
+        for (int ip = 1; ip < np; ++ip) {
+#ifdef SYCL_FAST_FP16
+            constexpr int cpy_ne_D = cpy_ne < (DVp/2)/warp_size ? cpy_ne : (DVp/2)/warp_size;
+#pragma unroll
+            for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) {
+                __dpct_align__(16) sycl::half2 tmp[cpy_ne_D];
+                ggml_sycl_memcpy_1<cpy_ne_D * 4>(tmp, &VKQ_combine[(item_ct1.get_local_id(1) + ip) * (DVp / 2) + i0 +
+                                                                   item_ct1.get_local_id(2) * cpy_ne_D]);
+#pragma unroll
+                for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
+                    VKQ[i0/warp_size + i1] += tmp[i1];
+                }
+            }
+#else
+            constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size;
+#pragma unroll
+            for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) {
+                __dpct_align__(16) float tmp[cpy_ne_D];
+                ggml_sycl_memcpy_1<cpy_ne_D*4>(tmp, &VKQ_combine[(item_ct1.get_local_id(1) + ip)*DVp + i0 + item_ct1.get_local_id(2)*cpy_ne_D]);
+#pragma unroll
+                for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
+                    ((float *)VKQ)[i0/warp_size + i1] += tmp[i1];
+                }
+            }
+#endif // SYCL_FAST_FP16
+
+            KQ_sum[0] += KQ_sum_combine[item_ct1.get_local_id(1) + ip];
+        }
+    }
+
+    // Attention sink: adjust KQ max and sum only for the first of all parallel blocks:
+    if (sinks && item_ct1.get_group(1) == 0) {
+#pragma unroll
+        for (int jc0 = 0; jc0 < cpw; ++jc0) {
+            const int   jc   = jc0 + (item_ct1.get_local_id(1) / np) * cpw;
+            const float sink = ((const float *) sinks)[head0 + jc % ncols2];
+
+            float       KQ_max_new_j = sycl::fmax((float) KQ_max[jc0], sink);
+            const float KQ_max_scale = sycl::native::exp((float) (KQ_max[jc0] - KQ_max_new_j));
+            KQ_max[jc0] = KQ_max_new_j;
+
+            const float val = sycl::native::exp((float) (sink - KQ_max[jc0]));
+            KQ_sum[jc0] = KQ_sum[jc0]*KQ_max_scale + val;
+
+#ifdef SYCL_FAST_FP16
+            const sycl::half2 KQ_max_scale_h2 = sycl::half2(KQ_max_scale, KQ_max_scale);
+#pragma unroll
+            for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {
+                VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size] *= KQ_max_scale_h2;
+            }
+#else
+#pragma unroll
+            for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {
+                VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size].x() *= KQ_max_scale;
+                VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size].y() *= KQ_max_scale;
+            }
+#endif // SYCL_FAST_FP16
+        }
+    }
+
+    // Write back results:
+#pragma unroll
+    for (int jc0 = 0; jc0 < cpw; ++jc0) {
+        const int jc = jc0 + (item_ct1.get_local_id(1) / np) * cpw;
+
+        const int j = jc / ncols2;
+        const int c = jc % ncols2;
+
+        if (ncols1 > 1 && col_Q_0 + j >= int(ne01.z())) {
+            return;
+        }
+
+        const float scale = item_ct1.get_group_range(1) == 1 ? 1.0f / KQ_sum[jc0] : 1.0f;
+
+        const int j_dst_unrolled =
+            ((sequence * int(ne01.z()) + col_Q_0 + j) * ne02 + head0 + c) * item_ct1.get_group_range(1) +
+            item_ct1.get_group(1);
+
+#ifdef SYCL_FAST_FP16
+        constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? cpy_ne/2 : (DVp/2)/warp_size;
+#pragma unroll
+        for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) {
+            __dpct_align__(16) sycl::float2 tmp[cpy_ne_D];
+#pragma unroll
+            for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
+                tmp[i1] = VKQ[jc0 * ((DVp / 2) / warp_size) + i0 / warp_size + i1]
+                              .template convert<float, sycl::rounding_mode::automatic>();
+                tmp[i1].x() *= scale;
+                tmp[i1].y() *= scale;
+            }
+            if (i0 + warp_size * cpy_ne_D <= DV / 2 || i0 + item_ct1.get_local_id(2) * cpy_ne_D < DV / 2) {
+                ggml_sycl_memcpy_1<sizeof(tmp)>(
+                    &dst[j_dst_unrolled * DV + 2 * i0 + item_ct1.get_local_id(2) * (2 * cpy_ne_D)], tmp);
+            }
+        }
+#else
+        constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size;
+#pragma unroll
+        for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) {
+            if (i0 + warp_size*cpy_ne_D <= DV || i0 + item_ct1.get_local_id(2)*cpy_ne_D < DV) {
+#pragma unroll
+                for (int i1 = 0; i1 < cpy_ne_D/2; ++i1) {
+                    VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size) + i1].x() *= scale;
+                    VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size) + i1].y() *= scale;
+                }
+                ggml_sycl_memcpy_1<cpy_ne_D*4>(
+                    &dst[j_dst_unrolled*DV + i0 + item_ct1.get_local_id(2)*cpy_ne_D],
+                    &VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size)]);
+            }
+        }
+#endif // SYCL_FAST_FP16
+
+        if (item_ct1.get_group_range(1) != 1 && item_ct1.get_local_id(2) == 0) {
+            dst_meta[j_dst_unrolled] = make_float2(KQ_max[jc0], KQ_sum[jc0]);
+        }
+    }
+#else
+    GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
+        max_bias, m0, m1, n_head_log2, logit_softcap,
+        ne00, ne01, ne02, ne03,
+              nb01, nb02, nb03,
+        ne10, ne11, ne12, ne13,
+              nb11, nb12, nb13,
+              nb21, nb22, nb23,
+              ne31, ne32, ne33,
+              nb31, nb32, nb33);
+#endif // SYCL_FLASH_ATTN
+}
+
+template <int DKQ, int DV, int ncols2, bool use_logit_softcap>
+static void launch_fattn_tile_switch_ncols1(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * Q = dst->src[0];
+
+    const int id        = ggml_sycl_get_device();
+    const int cc        = ggml_sycl_info().devices[id].cc;
+    const int warp_size = WARP_32_SIZE; //can't support WARP_16_SIZE
+
+    constexpr size_t nbytes_shared = 0;
+
+    if constexpr (DV <= 256) {
+        if (Q->ne[1] > 16/ncols2) {
+            constexpr int cols_per_block = 32;
+            const int nwarps    = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
+            const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
+            launch_fattn<DV, cols_per_block/ncols2, ncols2,
+                flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
+                (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
+            return;
+        }
+    }
+
+    if (Q->ne[1] > 8/ncols2) {
+        constexpr int cols_per_block = 16;
+        const int nwarps    = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
+        const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
+        launch_fattn<DV, cols_per_block/ncols2, ncols2,
+            flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
+            (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
+        return;
+    }
+
+    if constexpr (ncols2 <= 8) {
+        if (Q->ne[1] > 4/ncols2) {
+            constexpr int cols_per_block = 8;
+            const int nwarps    = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
+            const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
+            launch_fattn<DV, cols_per_block/ncols2, ncols2,
+                flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
+                (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
+            return;
+        }
+    }
+
+    if constexpr (ncols2 <= 4) {
+        if (Q->ne[1] > 2/ncols2) {
+            constexpr int cols_per_block = 4;
+            const int nwarps    = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
+            const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
+            launch_fattn<DV, cols_per_block/ncols2, ncols2,
+                flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
+                (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
+            return;
+        }
+    }
+
+    if constexpr (ncols2 <= 2) {
+        constexpr int cols_per_block = 2;
+        const int nwarps    = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
+        const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
+        launch_fattn<DV, cols_per_block/ncols2, ncols2,
+            flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
+            (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
+        return;
+    }
+
+    GGML_ABORT("fatal error");
+}
+
+template <int DKQ, int DV, bool use_logit_softcap>
+static void launch_fattn_tile_switch_ncols2(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * KQV  = dst;
+    const ggml_tensor * Q    = dst->src[0];
+    const ggml_tensor * K    = dst->src[1];
+    const ggml_tensor * mask = dst->src[3];
+
+    float max_bias = 0.0f;
+    memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
+
+    GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
+    const int gqa_ratio = Q->ne[2] / K->ne[2];
+
+    // On NVIDIA (Pascal and older) the GQA optimizations seem to be detrimental in some cases.
+    // However, for DKQ == 576, DV == 512 only the kernel variant with GQA optimizations is implemented.
+    //const bool nvidia = GGML_SYCL_CC_IS_NVIDIA(ggml_sycl_info().devices[ggml_sycl_get_device()].cc);
+    const int gqa_limit = gqa_ratio <= 4 && DV <= 256 ? 16 : INT_MAX;
+    const bool use_gqa_opt = mask && max_bias == 0.0f && Q->ne[1] <= gqa_limit && K->ne[1] % FATTN_KQ_STRIDE == 0;
+
+    if constexpr (DV == 512) {
+        if (use_gqa_opt && gqa_ratio % 16 == 0) {
+            launch_fattn_tile_switch_ncols1<DKQ, DV, 16, use_logit_softcap>(ctx, dst);
+            return;
+        }
+        if (use_gqa_opt && gqa_ratio % 4 == 0) {
+            launch_fattn_tile_switch_ncols1<DKQ, DV, 4, use_logit_softcap>(ctx, dst);
+            return;
+        }
+    }
+
+    if constexpr (DV <= 256) {
+        if (use_gqa_opt && gqa_ratio % 8 == 0) {
+            launch_fattn_tile_switch_ncols1<DKQ, DV, 8, use_logit_softcap>(ctx, dst);
+            return;
+        }
+
+        if (use_gqa_opt && gqa_ratio % 4 == 0) {
+            launch_fattn_tile_switch_ncols1<DKQ, DV, 4, use_logit_softcap>(ctx, dst);
+            return;
+        }
+
+        if (use_gqa_opt && gqa_ratio % 2 == 0) {
+            launch_fattn_tile_switch_ncols1<DKQ, DV, 2, use_logit_softcap>(ctx, dst);
+            return;
+        }
+
+        launch_fattn_tile_switch_ncols1<DKQ, DV, 1, use_logit_softcap>(ctx, dst);
+        return;
+    }
+    GGML_ABORT("fatal error");
+}
+
+template <int DKQ, int DV>
+void ggml_sycl_flash_attn_ext_tile_case(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * KQV = dst;
+
+    float logit_softcap;
+    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
+    if (logit_softcap == 0.0f) {
+        constexpr bool use_logit_softcap = false;
+        launch_fattn_tile_switch_ncols2<DKQ, DV, use_logit_softcap>(ctx, dst);
+    } else {
+        constexpr bool use_logit_softcap = true;
+        launch_fattn_tile_switch_ncols2<DKQ, DV, use_logit_softcap>(ctx, dst);
+    }
+}
+
+void ggml_sycl_flash_attn_ext_tile(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+
+#define DECL_FATTN_TILE_CASE(DKQ, DV)                             \
+    template void ggml_sycl_flash_attn_ext_tile_case              \
+    <DKQ, DV>(ggml_backend_sycl_context & ctx, ggml_tensor * dst) \
+
+extern DECL_FATTN_TILE_CASE( 40,  40);
+extern DECL_FATTN_TILE_CASE( 64,  64);
+extern DECL_FATTN_TILE_CASE( 72,  72);
+extern DECL_FATTN_TILE_CASE( 80,  80);
+extern DECL_FATTN_TILE_CASE( 96,  96);
+extern DECL_FATTN_TILE_CASE(112, 112);
+extern DECL_FATTN_TILE_CASE(128, 128);
+extern DECL_FATTN_TILE_CASE(256, 256);
+extern DECL_FATTN_TILE_CASE(576, 512);
+
diff --git a/ggml/src/ggml-sycl/fattn-vec.hpp b/ggml/src/ggml-sycl/fattn-vec.hpp
new file mode 100644 (file)
index 0000000..48c3890
--- /dev/null
@@ -0,0 +1,667 @@
+#ifndef GGML_SYCL_FATTN_VEC_HPP
+#define GGML_SYCL_FATTN_VEC_HPP
+
+#include <sycl/sycl.hpp>
+#include <sycl/ext/oneapi/work_group_static.hpp>
+#include <iostream>
+#include <iomanip>
+
+#include "dpct/helper.hpp"
+#include "common.hpp"
+#include "ggml.h"
+#include "fattn-common.hpp"
+#include <cmath>
+#include <float.h>
+
+namespace syclex = sycl::ext::oneapi::experimental;
+
+static int ggml_sycl_fattn_vec_get_nthreads_host(const int cc) {
+    return 128;
+    GGML_UNUSED(cc);
+}
+
+static constexpr int ggml_sycl_fattn_vec_get_nthreads_device() {
+    return 128;
+}
+
+// Currenlty llvm with the amdgcn target dose not support unrolling loops
+// that contain a break that can not be resolved at compile time.
+#ifdef __clang__
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wpass-failed"
+#endif // __clang__
+
+template <int D,
+          int ncols,
+          int type_K,
+          int type_V,
+          bool use_logit_softcap,
+          int warp_size>  // D == head size
+static void flash_attn_ext_vec(const char* __restrict__ Q,
+                        const char* __restrict__ K,
+                        const char* __restrict__ V,
+                        const char* __restrict__ mask,
+                        const char* __restrict__ sinks,
+                        const int* __restrict__ KV_max,
+                        float* __restrict__ dst,
+                        sycl::float2* __restrict__ dst_meta,
+                        const float scale,
+                        const float max_bias,
+                        const float m0,
+                        const float m1,
+                        const uint32_t n_head_log2,
+                        const float logit_softcap,
+                        const int32_t ne00,
+                        const sycl::uint3 ne01,
+                        const int32_t ne02,
+                        const int32_t ne03,
+                        const int32_t nb01,
+                        const int32_t nb02,
+                        const int32_t nb03,
+                        const int32_t ne10,
+                        const int32_t ne11,
+                        const int32_t ne12,
+                        const int32_t ne13,
+                        const int32_t nb11,
+                        const int32_t nb12,
+                        const int64_t nb13,
+                        const int32_t nb21,
+                        const int32_t nb22,
+                        const int64_t nb23,
+                        const int32_t ne31,
+                        const int32_t ne32,
+                        const int32_t ne33,
+                        const int32_t nb31,
+                        const int32_t nb32,
+                        const int64_t nb33) {
+#ifdef SYCL_FLASH_ATTN
+    // Skip unused kernel variants for faster compilation:
+
+    auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+    if (use_logit_softcap && !(D == 128 || D == 256)) {
+        GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
+            max_bias, m0, m1, n_head_log2, logit_softcap,
+            ne00, ne01, ne02, ne03,
+                  nb01, nb02, nb03,
+            ne10, ne11, ne12, ne13,
+                  nb11, nb12, nb13,
+                  nb21, nb22, nb23,
+                  ne31, ne32, ne33,
+                  nb31, nb32, nb33);
+        return;
+    }
+
+    //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
+
+    constexpr int cpy_nb = ggml_sycl_get_max_cpy_bytes();
+    constexpr int cpy_ne = cpy_nb / 4;
+
+    constexpr int nthreads_KQ_q = (D/4 < warp_size ? D/4 : warp_size);
+    constexpr int nthreads_V_q  = (D/4 < warp_size ? D/4 : warp_size);
+
+    constexpr int nthreads    = ggml_sycl_fattn_vec_get_nthreads_device();
+    constexpr int nthreads_KQ = type_K == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_KQ_q;
+    constexpr int nthreads_V  = type_V == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_V_q;
+
+    static_assert(warp_size % nthreads_KQ == 0, "bad nthreads_K");
+    static_assert(warp_size % nthreads_V  == 0, "bad nthreads_V");
+
+    constexpr int V_rows_per_thread = type_V == GGML_TYPE_F16 ? 2*cpy_ne : 4;
+    constexpr int V_cols_per_iter   = warp_size / nthreads_V;
+
+    constexpr vec_dot_KQ_t vec_dot_KQ = get_vec_dot_KQ<type_K, D, nthreads_KQ, warp_size>();
+    constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
+#ifdef GGML_SYCL_F16
+    constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, sycl::half, V_rows_per_thread>();
+#else
+    constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, float, V_rows_per_thread>();
+#endif // GGML_SYCL_F16
+
+    const int ic0 = item_ct1.get_group(2) * ncols;  // Index of the Q/QKV column to work on.
+
+    const int sequence  = item_ct1.get_group(0) / ne02;
+    const int head      = item_ct1.get_group(0) - sequence * ne02;
+    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
+    Q += nb03*sequence + nb02* head              + nb01*ic0;
+    K += nb13*sequence + nb12*(head / gqa_ratio);
+    V += nb23*sequence + nb22*(head / gqa_ratio);
+
+    const sycl::half * maskh = (const sycl::half *) (mask + nb33 * (sequence % ne33) + nb31 * ic0);
+
+    const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
+
+    static_assert(D % (2*warp_size) == 0, "D not divisible by 2*warp_size == 64.");
+    constexpr int nwarps = nthreads / warp_size;
+    const int     tid    = warp_size * item_ct1.get_local_id(1) + item_ct1.get_local_id(2);
+    __builtin_assume(tid < nthreads);
+
+    constexpr int ne_KQ      = ncols*D;
+    constexpr int ne_combine = nwarps*V_cols_per_iter*D;
+
+    constexpr size_t lsm_size1 = ncols * warp_size;
+    constexpr size_t lsm_size2 = ncols * warp_size;
+#ifdef GGML_SYCL_F16
+    sycl::half2 VKQ[ncols][(D / 2) / nthreads_V] = { { { 0.0f, 0.0f } } };
+    constexpr size_t lsm_size3 = (ne_KQ > ne_combine ? ne_KQ : ne_combine);
+    constexpr size_t local_share_mem_size = (lsm_size1 + lsm_size2)*sizeof(float) + lsm_size3*sizeof(sycl::half);
+
+    syclex::work_group_static<char[local_share_mem_size]> lsm;
+
+    float *KQ_max_shared = (float *)&lsm;
+    float *KQ_sum_shared = KQ_max_shared+lsm_size1;
+    sycl::half* KQ = (sycl::half*)(KQ_sum_shared + lsm_size2);
+
+
+#else
+    sycl::float2 VKQ[ncols][(D/2)/nthreads_V] = {{{0.0f, 0.0f}}};
+
+    constexpr size_t lsm_size3 = (ne_KQ > ne_combine ? ne_KQ : ne_combine);
+    constexpr size_t local_share_mem_size = (lsm_size1 + lsm_size2 + lsm_size3)*sizeof(float);
+
+
+    syclex::work_group_static<char[local_share_mem_size]> lsm;
+    float *KQ_max_shared = (float *)&lsm;
+    float *KQ_sum_shared = KQ_max_shared+lsm_size1;
+    float* KQ = KQ_sum_shared + lsm_size2;
+
+#endif // GGML_SYCL_F16
+
+    float KQ_max[ncols];
+    float KQ_sum[ncols];
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        KQ_max[j] = -FLT_MAX/2.0f;
+        KQ_sum[j] = 0.0f;
+    }
+
+    // Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers:
+#ifdef GGML_SYCL_F16
+    sycl::half2 Q_reg[ncols][(D / 2) / nthreads_KQ] = {{{0.0f, 0.0f}}};  // Will be initialized completely.
+#else
+    sycl::float2 Q_reg[ncols][(D/2)/nthreads_KQ] = {{{0.0f, 0.0f}}}; // May be only partially initialized.
+#endif // GGML_SYCL_F16
+    int    Q_i32[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)];
+    sycl::float2 Q_ds[ncols][1 > D / (sizeof(int) * nthreads_KQ) ? 1 : D / (sizeof(int) * nthreads_KQ)];
+    if constexpr (Q_q8_1) {
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+            const int j = j0 + item_ct1.get_local_id(1);
+
+            if (j0 + nwarps > ncols && j >= ncols) {
+                break;
+            }
+
+            // Reuse KQ as temporary storage for converting Q to q8_1:
+            int    * tmp_q_i32 = (int    *) &KQ[j*D];
+            sycl::float2 * tmp_q_ds  = (sycl::float2 *) (tmp_q_i32 + D / sizeof(int));
+
+            // Set memory to zero if out of bounds:
+            if (ncols > 1 && ic0 + j >= int(ne01.z())) {
+#pragma unroll
+                for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += warp_size) {
+                    const int i = i0 + item_ct1.get_local_id(2);
+
+                    if (i0 + warp_size <= int(D/sizeof(int)) || i < int(D/sizeof(int))) {
+                        tmp_q_i32[i] = 0;
+                    }
+                }
+                if (item_ct1.get_local_id(2) < D/QK8_1) {
+                    tmp_q_ds[item_ct1.get_local_id(2)] = sycl::float2(0.0f, 0.0f);
+                }
+            } else {
+                const float * Q_f = (const float *) (Q + j*nb01);
+                constexpr int nthreads_quantize = D/sizeof(int) < warp_size ? D/sizeof(int) : warp_size;
+#pragma unroll
+                for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += nthreads_quantize) {
+                    quantize_q8_1_to_shared<sycl::float2, nthreads_quantize, warp_size>
+                        (Q_f + i0*sizeof(int), scale, tmp_q_i32 + i0, tmp_q_ds + i0/QI8_1);
+                }
+            }
+        }
+
+
+        item_ct1.barrier(sycl::access::fence_space::local_space);
+
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            int    * tmp_q_i32 = (int    *) &KQ[j*D];
+            sycl::float2 * tmp_q_ds  = (sycl::float2 *) (tmp_q_i32 + D / sizeof(int));
+
+#pragma unroll
+            for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += nthreads_KQ) {
+                const int i =
+                    i0 + (nthreads_KQ == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads_KQ);
+
+                Q_i32[j][i0/nthreads_KQ] = tmp_q_i32[i];
+                Q_ds[j][i0/nthreads_KQ]  = tmp_q_ds[i/QI8_1];
+            }
+        }
+
+        item_ct1.barrier(sycl::access::fence_space::local_space);
+
+    } else {
+#ifdef GGML_SYCL_F16
+        const sycl::half2 scale_h2 = sycl::half2(scale, scale);
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            const sycl::float2 * Q_j = (const sycl::float2 *) (Q + j * nb01);
+#pragma unroll
+            for (int i0 = 0; i0 < D/2; i0 += nthreads_KQ*cpy_ne) {
+                const int i = i0 + (nthreads_KQ == warp_size ? item_ct1.get_local_id(2) :
+                                                               item_ct1.get_local_id(2) % nthreads_KQ) *
+                                       cpy_ne;
+
+                sycl::float2 tmp[cpy_ne] = {
+                    { 0.0f, 0.0f }
+                };
+                if (ncols == 1 || ic0 + j < int(ne01.z())) {
+                    ggml_sycl_memcpy_1<cpy_nb>(tmp,            &Q_j[i]);
+                    ggml_sycl_memcpy_1<cpy_nb>(tmp + cpy_ne/2, &Q_j[i + cpy_ne/2]);
+                }
+#pragma unroll
+                for (int i1 = 0; i1 < cpy_ne; ++i1) {
+                    Q_reg[j][i0 / nthreads_KQ + i1] = sycl::half2(tmp[i1].x(), tmp[i1].y());
+                }
+            }
+#pragma unroll
+            for (int k = 0; k < (D/2)/nthreads_KQ; ++k) {
+                Q_reg[j][k] *= scale_h2;
+            }
+        }
+#else
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            const sycl::float2 * Q_j = (const sycl::float2 *) (Q + j*nb01);
+#pragma unroll
+            for (int i0 = 0; i0 < D/2; i0 += nthreads_KQ*cpy_ne) {
+                const int i = i0 + (nthreads_KQ == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads_KQ)*cpy_ne;
+                if (ncols == 1 || ic0 + j < int(ne01.z())) {
+                    ggml_sycl_memcpy_1<cpy_nb>(&Q_reg[j][i0/nthreads_KQ],            &Q_j[i]);
+                    ggml_sycl_memcpy_1<cpy_nb>(&Q_reg[j][i0/nthreads_KQ + cpy_ne/2], &Q_j[i + cpy_ne/2]);
+                }
+            }
+#pragma unroll
+            for (int k = 0; k < (D/2)/nthreads_KQ; ++k) {
+                Q_reg[j][k].x() *= scale;
+                Q_reg[j][k].y() *= scale;
+            }
+        }
+#endif // GGML_SYCL_F16
+    }
+
+    const int k_VKQ_max = KV_max ? KV_max[sequence * item_ct1.get_group_range(2) + item_ct1.get_group(2)] : ne11;
+    K += item_ct1.get_group(1) * nthreads * nb11;
+    V += item_ct1.get_group(1) * nthreads * nb21;
+    maskh += item_ct1.get_group(1) * nthreads;
+    for (int k_VKQ_0 = item_ct1.get_group(1) * nthreads; k_VKQ_0 < k_VKQ_max;
+         k_VKQ_0 += item_ct1.get_group_range(1) * nthreads,
+             // Increment pointers after each loop:
+         K += item_ct1.get_group_range(1) * nthreads * nb11, V += item_ct1.get_group_range(1) * nthreads * nb21,
+             maskh += item_ct1.get_group_range(1) * nthreads) {
+        // Calculate KQ tile and keep track of new maximum KQ values:
+        float KQ_reg[ncols]={}; // KQ in registers.
+        float KQ_max_new[ncols]={};
+
+
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            KQ_max_new[j] = KQ_max[j];
+        }
+
+#pragma unroll
+        for (int i_KQ_0 = 0; i_KQ_0 < nthreads_KQ; ++i_KQ_0) {
+            const int i_KQ = item_ct1.get_local_id(1) * warp_size +
+                             (nthreads_KQ == warp_size ? 0 : (item_ct1.get_local_id(2) & ~(nthreads_KQ - 1))) + i_KQ_0;
+
+#pragma unroll
+            for (int j = 0; j < ncols; ++j) {
+                float sum = vec_dot_KQ(K + i_KQ*nb11, Q_reg[j], Q_i32[j], Q_ds[j]);
+                sum = warp_reduce_sum<nthreads_KQ>(sum);
+
+                if (use_logit_softcap) {
+                    sum = logit_softcap * sycl::tanh(sum);
+                }
+                if (mask) {
+                    sum += slope * sycl::vec<sycl::half, 1>(maskh[j * ne11 + i_KQ])
+                                       .convert<float, sycl::rounding_mode::automatic>()[0];
+                }
+
+                KQ_max_new[j] = sycl::fmax((float) KQ_max_new[j], sum);
+
+                if (int(nthreads_KQ == warp_size ? item_ct1.get_local_id(2)
+                                                 : item_ct1.get_local_id(2) %
+                                                       nthreads_KQ) == i_KQ_0) {
+                  KQ_reg[j] = sum;
+                }
+            }
+        }
+
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+#pragma unroll
+            for (int offset = nthreads_KQ; offset < warp_size; offset <<= 1) {
+               KQ_max_new[j] = sycl::fmax(
+                  (float)KQ_max_new[j],
+                  (float)dpct::permute_sub_group_by_xor(
+                      sycl::ext::oneapi::this_work_item::get_sub_group(),
+                      KQ_max_new[j],
+                      offset,
+                      warp_size));
+            }
+            const float KQ_max_scale = sycl::native::exp((float) (KQ_max[j] - KQ_max_new[j]));
+            KQ_max[j] = KQ_max_new[j];
+
+            KQ_reg[j]            = sycl::native::exp((float) (KQ_reg[j] - KQ_max[j]));
+            KQ_sum[j] = KQ_sum[j]*KQ_max_scale + KQ_reg[j];
+            KQ[j*nthreads + tid] = KQ_reg[j];
+
+#ifdef GGML_SYCL_F16
+            const sycl::half2 KQ_max_scale_h2 = sycl::half2(KQ_max_scale, KQ_max_scale);
+#pragma unroll
+            for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
+                VKQ[j][i_VKQ_0/nthreads_V] *= KQ_max_scale_h2;
+            }
+#else
+#pragma unroll
+            for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
+                VKQ[j][i_VKQ_0/nthreads_V].x() *= KQ_max_scale;
+                VKQ[j][i_VKQ_0/nthreads_V].y() *= KQ_max_scale;
+            }
+#endif // GGML_SYCL_F16
+        }
+
+        sycl::group_barrier(sycl::ext::oneapi::this_work_item::get_sub_group());
+
+#pragma unroll
+        for (int k0 = 0; k0 < warp_size; k0 += V_cols_per_iter) {
+            const int k = item_ct1.get_local_id(1) * warp_size + k0 +
+                          (nthreads_V == warp_size ? 0 : item_ct1.get_local_id(2) / nthreads_V);
+
+#ifdef GGML_SYCL_F16
+            sycl::half2 KQ_k[ncols];
+#pragma unroll
+            for (int j = 0; j < ncols; ++j) {
+                KQ_k[j] = sycl::half2(KQ[j * nthreads + k]);
+            }
+#pragma unroll
+            for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
+                sycl::half2 tmp[V_rows_per_thread / 2];
+                dequantize_V(V + k * nb21, tmp,
+                             2 * i_VKQ_0 + (nthreads_V == warp_size ? item_ct1.get_local_id(2) :
+                                                                      item_ct1.get_local_id(2) % nthreads_V) *
+                                               V_rows_per_thread);
+#pragma unroll
+                for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) {
+#pragma unroll
+                    for (int j = 0; j < ncols; ++j) {
+                        VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1] += tmp[i_VKQ_1]*KQ_k[j];
+                    }
+                }
+            }
+#else
+            float KQ_k[ncols];
+#pragma unroll
+            for (int j = 0; j < ncols; ++j) {
+                KQ_k[j] = KQ[j*nthreads + k];
+            }
+#pragma unroll
+            for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
+                sycl::float2 tmp[V_rows_per_thread/2];
+                dequantize_V(V + k*nb21, tmp,
+                    2*i_VKQ_0 + (nthreads_V == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads_V)*V_rows_per_thread);
+#pragma unroll
+                for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) {
+#pragma unroll
+                    for (int j = 0; j < ncols; ++j) {
+                        VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1].x() += tmp[i_VKQ_1].x()*KQ_k[j];
+                        VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1].y() += tmp[i_VKQ_1].y()*KQ_k[j];
+                    }
+                }
+            }
+#endif // GGML_SYCL_F16
+        }
+    }
+
+    if (sinks && item_ct1.get_group(1) == 0) {
+        const float sink = ((const float *) sinks)[head];
+
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+            const int j = j0 + item_ct1.get_local_id(1);
+
+            if (j0 + nwarps > ncols && j >= ncols) {
+                break;
+            }
+            const float kqmax_new_j  = sycl::fmax(sink, (float) KQ_max[j]);
+            const float KQ_max_scale = sycl::native::exp((float) (KQ_max[j] - kqmax_new_j));
+            KQ_max[j] = kqmax_new_j;
+
+            KQ_sum[j] = KQ_sum[j] * KQ_max_scale +
+                        (item_ct1.get_local_id(2) == 0 ? sycl::native::exp((float) (sink - KQ_max[j])) : 0.0f);
+#ifdef GGML_SYCL_F16
+            const sycl::half2 KQ_max_scale_h2 = sycl::half2(KQ_max_scale, KQ_max_scale);
+#pragma unroll
+            for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
+                VKQ[j][i_VKQ_0/nthreads_V] *= KQ_max_scale_h2;
+            }
+#else
+#pragma unroll
+            for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
+                VKQ[j][i_VKQ_0/nthreads_V].x() *= KQ_max_scale;
+                VKQ[j][i_VKQ_0/nthreads_V].y() *= KQ_max_scale;
+            }
+#endif // GGML_SYCL_F16
+        }
+    }
+
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        if (item_ct1.get_local_id(1) == 0) {
+            KQ_max_shared[j*warp_size+item_ct1.get_local_id(2)] = -FLT_MAX / 2.0f;
+            KQ_sum_shared[j*warp_size+item_ct1.get_local_id(2)] = 0.0f;
+        }
+    }
+
+    item_ct1.barrier(sycl::access::fence_space::local_space);
+
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        if (item_ct1.get_local_id(2) == 0) {
+            KQ_max_shared[j*warp_size+item_ct1.get_local_id(1)] = KQ_max[j];
+        }
+    }
+
+
+    item_ct1.barrier(sycl::access::fence_space::local_space);
+
+#pragma unroll
+    for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
+        if (ncols > 1 && ic0 + j_VKQ >= int(ne01.z())) {
+            break;
+        }
+
+        float kqmax_new         = KQ_max_shared[j_VKQ*warp_size+item_ct1.get_local_id(2)];
+        kqmax_new = warp_reduce_max<warp_size>(kqmax_new);
+        const float kqmax_scale = sycl::native::exp((float) (KQ_max[j_VKQ] - kqmax_new));
+        KQ_max[j_VKQ] = kqmax_new;
+
+#ifdef GGML_SYCL_F16
+        sycl::half2 * VKQ_tmp = (sycl::half2 *) KQ + item_ct1.get_local_id(1) * (V_cols_per_iter * D / 2) +
+                                (nthreads_V == warp_size ? 0 : item_ct1.get_local_id(2) / nthreads_V) * (D / 2);
+
+        const sycl::half2 kqmax_scale_h2 = sycl::half2(kqmax_scale, kqmax_scale);
+#pragma unroll
+        for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
+            VKQ[j_VKQ][i_VKQ_0/nthreads_V] *= kqmax_scale_h2;
+        }
+#pragma unroll
+        for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
+            const int i_VKQ =
+                i_VKQ_0 + (nthreads_V == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads_V) *
+                              (V_rows_per_thread / 2);
+
+            ggml_sycl_memcpy_1<V_rows_per_thread * sizeof(sycl::half)>(VKQ_tmp + i_VKQ,
+                                                                       &VKQ[j_VKQ][i_VKQ_0 / nthreads_V]);
+        }
+#else
+        sycl::float2 * VKQ_tmp = (sycl::float2 *) KQ + item_ct1.get_local_id(1)*(V_cols_per_iter*D/2)
+            + (nthreads_V == warp_size ? 0 : item_ct1.get_local_id(2) / nthreads_V)*(D/2);
+#pragma unroll
+        for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
+            VKQ[j_VKQ][i_VKQ_0/nthreads_V].x() *= kqmax_scale;
+            VKQ[j_VKQ][i_VKQ_0/nthreads_V].y() *= kqmax_scale;
+        }
+#pragma unroll
+        for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
+            const int i_VKQ = i_VKQ_0 + (nthreads_V == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads_V)*(V_rows_per_thread/2);
+
+            ggml_sycl_memcpy_1<V_rows_per_thread/2*sizeof(float)>(VKQ_tmp + i_VKQ,                       &VKQ[j_VKQ][i_VKQ_0/nthreads_V]);
+            ggml_sycl_memcpy_1<V_rows_per_thread/2*sizeof(float)>(VKQ_tmp + i_VKQ + V_rows_per_thread/4, &VKQ[j_VKQ][i_VKQ_0/nthreads_V + V_rows_per_thread/4]);
+        }
+#endif // GGML_SYCL_F16
+
+        KQ_sum[j_VKQ] *= kqmax_scale;
+        KQ_sum[j_VKQ] = warp_reduce_sum<warp_size>(KQ_sum[j_VKQ]);
+        if (item_ct1.get_local_id(2) == 0) {
+            KQ_sum_shared[j_VKQ*warp_size+item_ct1.get_local_id(1)] = KQ_sum[j_VKQ];
+        }
+
+        item_ct1.barrier(sycl::access::fence_space::local_space);
+
+
+        if (nthreads <= D || tid < D) {
+            KQ_sum[j_VKQ] = KQ_sum_shared[j_VKQ*warp_size+item_ct1.get_local_id(2)];
+            KQ_sum[j_VKQ] = warp_reduce_sum<warp_size>(KQ_sum[j_VKQ]);
+
+#pragma unroll
+            for (int i0 = 0; i0 < D; i0 += nthreads) {
+                float dst_val = 0;
+#pragma unroll
+                for (int w = 0; w < nwarps; ++w) {
+#pragma unroll
+                    for (int v = 0; v < V_cols_per_iter; ++v) {
+                        dst_val += float(KQ[w*V_cols_per_iter*D + v*D + i0 + tid]);
+                    }
+                }
+                if (item_ct1.get_group_range(1) == 1) {
+                    dst_val /= KQ_sum[j_VKQ];
+                }
+                dst[(((sequence * int(ne01.z()) + ic0 + j_VKQ) * ne02 + head) * item_ct1.get_group_range(1) +
+                     item_ct1.get_group(1)) *
+                        D +
+                    i0 + tid] = dst_val;
+            }
+        }
+
+        if (j_VKQ < ncols-1) {
+            item_ct1.barrier(sycl::access::fence_space::local_space);
+        }
+
+    }
+
+    if (item_ct1.get_group_range(1) != 1 && tid < ncols && (ncols == 1 || ic0 + tid < int(ne01.z()))) {
+        dst_meta[((sequence * int(ne01.z()) + ic0 + tid) * ne02 + head) * item_ct1.get_group_range(1) +
+                 item_ct1.get_group(1)] = make_float2(KQ_max[tid], KQ_sum[tid]);
+    }
+#else
+    GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
+        max_bias, m0, m1, n_head_log2, logit_softcap,
+        ne00, ne01, ne02, ne03,
+              nb01, nb02, nb03,
+        ne10, ne11, ne12, ne13,
+              nb11, nb12, nb13,
+              nb21, nb22, nb23,
+              ne31, ne32, ne33,
+              nb31, nb32, nb33);
+
+#endif // SYCL_FLASH_ATTN
+}
+#ifdef __clang__
+#pragma clang diagnostic pop
+#endif // __clang__
+
+
+template <int D, int cols_per_block, int type_K, int type_V, bool use_logit_softcap>
+void ggml_sycl_flash_attn_ext_vec_case_impl(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+
+    const int warp_size = WARP_16_SIZE; //better performance than WARP_32_SIZE
+
+    const int cc = ggml_sycl_info().devices[ggml_sycl_get_device()].cc;
+
+    const int nthreads = ggml_sycl_fattn_vec_get_nthreads_host(cc);
+    const int nwarps   = nthreads / warp_size;
+
+    const bool need_f16_K = type_K == GGML_TYPE_F16;
+    const bool need_f16_V = type_V == GGML_TYPE_F16;
+    constexpr size_t nbytes_shared = 0;
+
+    launch_fattn<D, cols_per_block, 1,
+                 flash_attn_ext_vec<D, cols_per_block, type_K, type_V,
+                                    use_logit_softcap, warp_size>, warp_size>(
+        ctx, dst, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
+}
+
+template <int D, int type_K, int type_V>
+void ggml_sycl_flash_attn_ext_vec_case(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * KQV = dst;
+    const ggml_tensor * Q   = dst->src[0];
+
+    float logit_softcap;
+    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
+    if (Q->ne[1] == 1) {
+        constexpr int cols_per_block = 1;
+        if (logit_softcap == 0.0f) {
+            constexpr bool use_logit_softcap = false;
+            ggml_sycl_flash_attn_ext_vec_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
+        } else {
+            constexpr bool use_logit_softcap = true;
+            ggml_sycl_flash_attn_ext_vec_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
+        }
+        return;
+    }
+
+    constexpr int cols_per_block = 2;
+    if (logit_softcap == 0.0f) {
+        constexpr bool use_logit_softcap = false;
+        ggml_sycl_flash_attn_ext_vec_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
+    } else {
+        constexpr bool use_logit_softcap = true;
+        ggml_sycl_flash_attn_ext_vec_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
+    }
+}
+
+#define DECL_FATTN_VEC_CASE(D, type_K, type_V)                              \
+    template void ggml_sycl_flash_attn_ext_vec_case                         \
+    <D, type_K, type_V>(ggml_backend_sycl_context & ctx, ggml_tensor * dst) \
+
+#define EXTERN_DECL_FATTN_VEC_CASES(D, type_K)             \
+    extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_F16);  \
+    extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q4_0); \
+    extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q4_1); \
+    extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_0); \
+    extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_1); \
+    extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q8_0); \
+
+EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_F16)
+EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_0)
+EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_1)
+EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_0)
+EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_1)
+EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q8_0)
+
+EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_F16)
+EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_0)
+EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_1)
+EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_0)
+EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_1)
+EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q8_0)
+
+EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_F16)
+EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_0)
+EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_1)
+EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_0)
+EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_1)
+EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q8_0)
+
+#endif // GGML_SYCL_FATTN_VEC_HPP
diff --git a/ggml/src/ggml-sycl/fattn.cpp b/ggml/src/ggml-sycl/fattn.cpp
new file mode 100644 (file)
index 0000000..c276ed8
--- /dev/null
@@ -0,0 +1,225 @@
+//
+// MIT license
+// Copyright (C) 2025 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+
+#include <sycl/sycl.hpp>
+#include "dpct/helper.hpp"
+#include "common.hpp"
+#include "fattn-common.hpp"
+#include "fattn-tile.hpp"
+#include "fattn-vec.hpp"
+#include "fattn.hpp"
+
+
+#define FATTN_VEC_CASE(D, type_K, type_V)                                                                        \
+    {                                                                                                            \
+        const bool type_K_okay = K->type == (type_K) || (K->type == GGML_TYPE_F32 && (type_K) == GGML_TYPE_F16); \
+        const bool type_V_okay = V->type == (type_V) || (V->type == GGML_TYPE_F32 && (type_V) == GGML_TYPE_F16); \
+        if (Q->ne[0] == (D) && type_K_okay && type_V_okay) {                                                     \
+            ggml_sycl_flash_attn_ext_vec_case<D, type_K, type_V>(ctx, dst);                                      \
+            return;                                                                                              \
+        }                                                                                                        \
+    }                                                                    \
+
+#define FATTN_VEC_CASES_ALL_D(type_K, type_V) \
+    FATTN_VEC_CASE( 64, type_K, type_V)       \
+    FATTN_VEC_CASE(128, type_K, type_V)       \
+    FATTN_VEC_CASE(256, type_K, type_V)       \
+
+static void ggml_sycl_flash_attn_ext_vec(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_tensor * Q = dst->src[0];
+    ggml_tensor * K = dst->src[1];
+    ggml_tensor * V = dst->src[2];
+
+#ifdef GGML_SYCL_FA_ALL_QUANTS
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_F16)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_F16)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_F16)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_F16)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_F16)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_F16)
+
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_Q4_0)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q4_0)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)
+
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_Q4_1)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q4_1)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)
+
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_Q5_0)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
+
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_Q5_1)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
+
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_Q8_0)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q8_0)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q8_0)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
+#else
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16,  GGML_TYPE_F16)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
+    FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
+#endif // GGML_SYCL_FA_ALL_QUANTS
+
+    GGML_ABORT("Not match KV type in vec");
+}
+
+// Best FlashAttention kernel for a specific GPU:
+enum best_fattn_kernel {
+    BEST_FATTN_KERNEL_NONE     =   0,
+    BEST_FATTN_KERNEL_VEC      = 100,
+    BEST_FATTN_KERNEL_TILE     = 200,
+};
+
+static best_fattn_kernel ggml_sycl_get_best_fattn_kernel(const int device, const ggml_tensor * dst) {
+    GGML_UNUSED(device);
+#ifndef SYCL_FLASH_ATTN
+    GGML_UNUSED(dst);
+    return BEST_FATTN_KERNEL_NONE;
+#endif// SYCL_FLASH_ATTN
+
+    if(!g_ggml_sycl_enable_flash_attention) return BEST_FATTN_KERNEL_NONE;
+
+    const ggml_tensor * KQV   = dst;
+    const ggml_tensor * Q     = dst->src[0];
+    const ggml_tensor * K     = dst->src[1];
+    const ggml_tensor * V     = dst->src[2];
+    const ggml_tensor * mask  = dst->src[3];
+
+    const int gqa_ratio = Q->ne[2] / K->ne[2];
+    GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
+
+    float max_bias = 0.0f;
+    memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
+
+    bool gqa_opt_applies = gqa_ratio >= 2 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
+    for (const ggml_tensor * t : {Q, K, V, mask}) {
+        if (t == nullptr || ggml_is_quantized(t->type)) {
+            continue;
+        }
+        for (size_t i = 1; i < GGML_MAX_DIMS; ++i) {
+            if (t->nb[i] % 16 != 0) {
+                gqa_opt_applies = false;
+                break;
+            }
+        }
+    }
+
+    switch (K->ne[0]) {
+        case  40:
+        case  64:
+        case  72:
+        case  80:
+        case  96:
+        case 128:
+        case 112:
+        case 256:
+            if (V->ne[0] != K->ne[0]) {
+                return BEST_FATTN_KERNEL_NONE;
+            }
+            break;
+        case 576:
+            if (V->ne[0] != 512) {
+                return BEST_FATTN_KERNEL_NONE;
+            }
+            if (!gqa_opt_applies) {
+                return BEST_FATTN_KERNEL_NONE;
+            }
+            break;
+        default:
+            return BEST_FATTN_KERNEL_NONE;
+    }
+
+#ifndef GGML_SYCL_FA_ALL_QUANTS
+    if (K->type != V->type) {
+        return BEST_FATTN_KERNEL_NONE;
+    }
+#endif // GGML_SYCL_FA_ALL_QUANTS
+
+    switch (K->type) {
+        case GGML_TYPE_F32:
+        case GGML_TYPE_F16:
+            break;
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+#ifndef GGML_SYCL_FA_ALL_QUANTS
+            return BEST_FATTN_KERNEL_NONE;
+#endif // GGML_SYCL_FA_ALL_QUANTS
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q8_0:
+            break;
+        default:
+            return BEST_FATTN_KERNEL_NONE;
+    }
+
+    if (mask && mask->ne[2] != 1) {
+        return BEST_FATTN_KERNEL_NONE;
+    }
+
+    // For small batch sizes the vector kernel may be preferable over the kernels optimized for large batch sizes:
+    const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0;
+
+    // Todo: Use the XMX kernel if possible:
+
+    // If there are no tensor cores available, use the generic tile kernel:
+    if (can_use_vector_kernel) {
+        if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
+            if (Q->ne[1] == 1) {
+                if (!gqa_opt_applies) {
+                    return BEST_FATTN_KERNEL_VEC;
+                }
+            }
+        } else {
+            if (Q->ne[1] <= 2) {
+                return BEST_FATTN_KERNEL_VEC;
+            }
+        }
+    }
+    return BEST_FATTN_KERNEL_TILE;
+}
+
+void ggml_sycl_flash_attn_ext(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    ggml_sycl_set_device(ctx.device);
+    switch (ggml_sycl_get_best_fattn_kernel(ggml_sycl_get_device(), dst)) {
+        case BEST_FATTN_KERNEL_NONE:
+            GGML_ABORT("Not support Flash-Attention");
+        case BEST_FATTN_KERNEL_TILE:
+            ggml_sycl_flash_attn_ext_tile(ctx, dst);
+            break;
+        case BEST_FATTN_KERNEL_VEC:
+            ggml_sycl_flash_attn_ext_vec(ctx, dst);
+            break;
+    }
+}
+
+bool ggml_sycl_flash_attn_ext_supported(int device, const ggml_tensor * dst) {
+    return ggml_sycl_get_best_fattn_kernel(device, dst) != BEST_FATTN_KERNEL_NONE;
+}
diff --git a/ggml/src/ggml-sycl/fattn.hpp b/ggml/src/ggml-sycl/fattn.hpp
new file mode 100644 (file)
index 0000000..f2a8ffc
--- /dev/null
@@ -0,0 +1,22 @@
+//
+// MIT license
+// Copyright (C) 2025 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+#ifndef GGML_SYCL_FATTN_HPP
+#define GGML_SYCL_FATTN_HPP
+
+#include "common.hpp"
+
+void ggml_sycl_flash_attn_ext(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+
+bool ggml_sycl_flash_attn_ext_supported(int device, const ggml_tensor * dst);
+
+#endif // GGML_SYCL_FATTN_HPP
index 0614d7e8f3ad660c201d3b1a863db89c4c14a32b..dfacde0af33f446c139378120ea9ca4c465e12fe 100644 (file)
@@ -62,6 +62,8 @@ int g_ggml_sycl_disable_graph = 0;
 int g_ggml_sycl_disable_dnn = 0;
 int g_ggml_sycl_prioritize_dmmv = 0;
 int g_ggml_sycl_use_async_mem_op = 0;
+int g_ggml_sycl_enable_flash_attention = 1;
+
 
 static ggml_sycl_device_info ggml_sycl_init() {
     ggml_sycl_device_info info = {};
@@ -94,11 +96,12 @@ static ggml_sycl_device_info ggml_sycl_init() {
 
         info.devices[i].cc =
             100 * prop.get_major_version() + 10 * prop.get_minor_version();
-        info.devices[i].nsm = prop.get_max_compute_units();
+        info.devices[i].nsm = prop.get_max_compute_units() / 16; //16: Number of Xe Cores
         info.devices[i].opt_feature.reorder = device.ext_oneapi_architecture_is(syclex::arch_category::intel_gpu);
         info.devices[i].smpbo = prop.get_local_mem_size();
-
         info.max_work_group_sizes[i] = prop.get_max_work_group_size();
+        info.devices[i].max_wg_per_cu = info.max_work_group_sizes[i] / prop.get_max_compute_units();
+
     }
 
     for (int id = 0; id < info.device_count; ++id) {
@@ -211,7 +214,37 @@ static void ggml_check_sycl() try {
         g_ggml_sycl_disable_graph = get_sycl_env("GGML_SYCL_DISABLE_GRAPH", 1);
         g_ggml_sycl_disable_dnn = get_sycl_env("GGML_SYCL_DISABLE_DNN", 0);
         g_ggml_sycl_prioritize_dmmv = get_sycl_env("GGML_SYCL_PRIORITIZE_DMMV", 0);
+
+#ifdef SYCL_FLASH_ATTN
+        g_ggml_sycl_enable_flash_attention = get_sycl_env("GGML_SYCL_ENABLE_FLASH_ATTN", 1);
+#else
+        g_ggml_sycl_enable_flash_attention = 0;
+#endif
+
         GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n");
+
+        GGML_LOG_INFO("Build with Macros:\n");
+#if defined(GGML_SYCL_FORCE_MMQ)
+        GGML_LOG_INFO("  GGML_SYCL_FORCE_MMQ: yes\n");
+#else
+        GGML_LOG_INFO("  GGML_SYCL_FORCE_MMQ: no\n");
+#endif
+#if defined(GGML_SYCL_F16)
+        GGML_LOG_INFO("  GGML_SYCL_F16: yes\n");
+#else
+        GGML_LOG_INFO("  GGML_SYCL_F16: no\n");
+#endif
+#if defined(GGML_SYCL_GRAPH)
+        GGML_LOG_INFO("  GGML_SYCL_GRAPH: yes\n");
+#else
+        GGML_LOG_INFO("  GGML_SYCL_GRAPH: no\n");
+#endif
+#if defined(GGML_SYCL_DNNL)
+        GGML_LOG_INFO("  GGML_SYCL_DNNL: yes\n");
+#else
+        GGML_LOG_INFO("  GGML_SYCL_DNNL: no\n");
+#endif
+
         GGML_LOG_INFO("Running with Environment Variables:\n");
         GGML_LOG_INFO("  GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug);
         GGML_LOG_INFO("  GGML_SYCL_DISABLE_OPT: %d\n", g_ggml_sycl_disable_optimize);
@@ -226,16 +259,12 @@ static void ggml_check_sycl() try {
         GGML_LOG_INFO("  GGML_SYCL_DISABLE_DNN: DNN disabled by compile flag\n");
 #endif
         GGML_LOG_INFO("  GGML_SYCL_PRIORITIZE_DMMV: %d\n", g_ggml_sycl_prioritize_dmmv);
-        GGML_LOG_INFO("Build with Macros:\n");
-#if defined(GGML_SYCL_FORCE_MMQ)
-        GGML_LOG_INFO("  GGML_SYCL_FORCE_MMQ: yes\n");
-#else
-        GGML_LOG_INFO("  GGML_SYCL_FORCE_MMQ: no\n");
-#endif
-#if defined(GGML_SYCL_F16)
-        GGML_LOG_INFO("  GGML_SYCL_F16: yes\n");
+
+#ifdef SYCL_FLASH_ATTN
+        GGML_LOG_INFO("  GGML_SYCL_ENABLE_FLASH_ATTN: %d\n", g_ggml_sycl_enable_flash_attention);
 #else
-        GGML_LOG_INFO("  GGML_SYCL_F16: no\n");
+        GGML_LOG_INFO("  GGML_SYCL_ENABLE_FLASH_ATTN: %d disabled by compile flag\n",
+            g_ggml_sycl_enable_flash_attention);
 #endif
 
 /* NOT REMOVE, keep it for next optimize for XMX.
@@ -3012,7 +3041,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
 
         }
 #if GGML_SYCL_DNNL
-        // oneDNN handles strided data and does not need overhead of get_to_fp16_nc_sycl
+        // oneDNN handles strided data and does not need overhead of ggml_get_to_fp16_nc_sycl
         const int64_t ne_src1 = src1->nb[last_str] * src1->ne[last_dim] / type_size_src1;
         src1_f16_alloc.alloc(ne_src1);
         const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst);
@@ -3021,7 +3050,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
 # else
         const int64_t ne_src1 = ggml_nelements(src1);
         src1_f16_alloc.alloc(ne_src1);
-        const to_fp16_nc_sycl_t to_fp16_nc_sycl = get_to_fp16_nc_sycl(src1->type);
+        const to_fp16_nc_sycl_t to_fp16_nc_sycl = ggml_get_to_fp16_nc_sycl(src1->type);
         GGML_ASSERT(to_fp16_nc_sycl != nullptr);
         to_fp16_nc_sycl(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, queue);
 #endif
@@ -4158,6 +4187,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
         case GGML_OP_ARANGE:
             ggml_sycl_arange(ctx, dst);
             break;
+        case GGML_OP_FLASH_ATTN_EXT:
+            ggml_sycl_flash_attn_ext(ctx, dst);
+            break;
         default:
             return false;
     }
@@ -4862,6 +4894,8 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
             return op->type == GGML_TYPE_F32;
         case GGML_OP_ARANGE:
             return op->type == GGML_TYPE_F32;
+        case GGML_OP_FLASH_ATTN_EXT:
+            return ggml_sycl_flash_attn_ext_supported(device, op);
         default:
             return false;
     }
index b6517374230a8ce2807edfca7ee53aa488bc3285..dc4dad1d37a8bc2b538a272455d420e02fe92d92 100644 (file)
@@ -73,4 +73,7 @@ static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUA
 #define MUL_MAT_SRC1_COL_STRIDE 128
 
 #define QK_WARP_SIZE 32
+#define WARP_32_SIZE 32
+#define WARP_16_SIZE 16
+
 #endif // GGML_SYCL_PRESETS_HPP
index 15d92e5e04cd459288b3ea61a2cd333a40a56b0f..fdf9b843e01510f705deb2d06c33d2131f8d55aa 100644 (file)
@@ -102,7 +102,7 @@ static void soft_max_f32(const float *         x,
         max_val   = sycl::max(max_val, val);
     }
     // find the max value in the block
-    max_val = warp_reduce_max(max_val);
+    max_val = warp_reduce_max<WARP_SIZE>(max_val);
 
     if (block_size > WARP_SIZE) {
         if (warp_id == 0) {
@@ -116,7 +116,7 @@ static void soft_max_f32(const float *         x,
         item_ct1.barrier();
 
         max_val = buf_iw[lane_id];
-        max_val = warp_reduce_max(max_val);
+        max_val = warp_reduce_max<WARP_SIZE>(max_val);
     }
     float tmp = 0.0f; // partial sum
 
@@ -133,7 +133,7 @@ static void soft_max_f32(const float *         x,
         vals[col] = val;
     }
     // find the sum of exps in the block
-    tmp = warp_reduce_sum(tmp);
+    tmp = warp_reduce_sum<WARP_SIZE>(tmp);
     if (block_size > WARP_SIZE) {
         item_ct1.barrier();
         if (warp_id == 0) {
@@ -153,7 +153,7 @@ static void soft_max_f32(const float *         x,
         for (size_t i = 1; i < nreduce; i += 1) {
             tmp += buf_iw[lane_id + i * WARP_SIZE];
         }
-        tmp = warp_reduce_sum(tmp);
+        tmp = warp_reduce_sum<WARP_SIZE>(tmp);
     }
     if (sinks) {
         tmp += sycl::native::exp(sinks[i02] - max_val);
@@ -191,7 +191,7 @@ static void soft_max_back_f32(const float *grad, const float *dstf, float *dst,
         dgf_dot += dstf[col]*grad[col];
     }
 
-    dgf_dot = warp_reduce_sum(dgf_dot);
+    dgf_dot = warp_reduce_sum<WARP_SIZE>(dgf_dot);
 
     for (int col = tid; col < ncols; col += WARP_SIZE) {
         dst[col] = scale * (grad[col] - dgf_dot) * dstf[col];
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp
new file mode 100644 (file)
index 0000000..5c06d42
--- /dev/null
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-tile.hpp"
+
+DECL_FATTN_TILE_CASE(112, 112);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp
new file mode 100644 (file)
index 0000000..f74e120
--- /dev/null
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-tile.hpp"
+
+DECL_FATTN_TILE_CASE(128, 128);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp
new file mode 100644 (file)
index 0000000..b574fe9
--- /dev/null
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-tile.hpp"
+
+DECL_FATTN_TILE_CASE(256, 256);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp
new file mode 100644 (file)
index 0000000..8c8fb69
--- /dev/null
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-tile.hpp"
+
+DECL_FATTN_TILE_CASE(40, 40);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp
new file mode 100644 (file)
index 0000000..f218552
--- /dev/null
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-tile.hpp"
+
+DECL_FATTN_TILE_CASE(576, 512);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp
new file mode 100644 (file)
index 0000000..99303a5
--- /dev/null
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-tile.hpp"
+
+DECL_FATTN_TILE_CASE(64, 64);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp
new file mode 100644 (file)
index 0000000..5059276
--- /dev/null
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-tile.hpp"
+
+DECL_FATTN_TILE_CASE(72, 72);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp
new file mode 100644 (file)
index 0000000..74f1ea5
--- /dev/null
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-tile.hpp"
+
+DECL_FATTN_TILE_CASE(80, 80);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp b/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp
new file mode 100644 (file)
index 0000000..cefb46d
--- /dev/null
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-tile.hpp"
+
+DECL_FATTN_TILE_CASE(96, 96);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp
new file mode 100644 (file)
index 0000000..32cf4f2
--- /dev/null
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp
new file mode 100644 (file)
index 0000000..a61a190
--- /dev/null
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q4_0);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp
new file mode 100644 (file)
index 0000000..63b74fb
--- /dev/null
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q4_1);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp
new file mode 100644 (file)
index 0000000..46e2d98
--- /dev/null
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q5_0);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp
new file mode 100644 (file)
index 0000000..7aabb6f
--- /dev/null
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q5_1);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp
new file mode 100644 (file)
index 0000000..148ea21
--- /dev/null
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp
new file mode 100644 (file)
index 0000000..4b169db
--- /dev/null
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_F16);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_F16);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp
new file mode 100644 (file)
index 0000000..79f530b
--- /dev/null
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp
new file mode 100644 (file)
index 0000000..2f7db51
--- /dev/null
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp
new file mode 100644 (file)
index 0000000..9e3bf0b
--- /dev/null
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp
new file mode 100644 (file)
index 0000000..1808187
--- /dev/null
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp
new file mode 100644 (file)
index 0000000..1c387b0
--- /dev/null
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp
new file mode 100644 (file)
index 0000000..f005b37
--- /dev/null
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_F16);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_F16);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp
new file mode 100644 (file)
index 0000000..3553b1c
--- /dev/null
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp
new file mode 100644 (file)
index 0000000..687ec56
--- /dev/null
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp
new file mode 100644 (file)
index 0000000..2663bfe
--- /dev/null
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp
new file mode 100644 (file)
index 0000000..641b7c7
--- /dev/null
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp
new file mode 100644 (file)
index 0000000..3d3181d
--- /dev/null
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp
new file mode 100644 (file)
index 0000000..85d5026
--- /dev/null
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_F16);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_F16);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp
new file mode 100644 (file)
index 0000000..1e81401
--- /dev/null
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp
new file mode 100644 (file)
index 0000000..5425147
--- /dev/null
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp
new file mode 100644 (file)
index 0000000..d418c1f
--- /dev/null
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp
new file mode 100644 (file)
index 0000000..0f26cfa
--- /dev/null
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp
new file mode 100644 (file)
index 0000000..4fb9872
--- /dev/null
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp
new file mode 100644 (file)
index 0000000..85b79cd
--- /dev/null
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_F16);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_F16);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp
new file mode 100644 (file)
index 0000000..7348323
--- /dev/null
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp
new file mode 100644 (file)
index 0000000..f19af2a
--- /dev/null
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp
new file mode 100644 (file)
index 0000000..d7075ba
--- /dev/null
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp
new file mode 100644 (file)
index 0000000..627f9a5
--- /dev/null
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp
new file mode 100644 (file)
index 0000000..23304ee
--- /dev/null
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp
new file mode 100644 (file)
index 0000000..95acb5d
--- /dev/null
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_F16);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_F16);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp
new file mode 100644 (file)
index 0000000..5e88f4b
--- /dev/null
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp
new file mode 100644 (file)
index 0000000..69f297f
--- /dev/null
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp
new file mode 100644 (file)
index 0000000..455842a
--- /dev/null
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp
new file mode 100644 (file)
index 0000000..f7ef739
--- /dev/null
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1);
diff --git a/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp b/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp
new file mode 100644 (file)
index 0000000..1c633bd
--- /dev/null
@@ -0,0 +1,7 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
index 43482b3672c1505efb2ccf43bd1928d1108b286e..9a267d85a0cc917d818e5222c401be3b0a955481 100644 (file)
@@ -650,6 +650,19 @@ static __dpct_inline__ float vec_dot_q8_0_q8_1_impl(const int *v, const int *u,
     return d8_0*d8_1 * sumi;
 }
 
+template <typename T, int vdr>
+static __dpct_inline__ T vec_dot_q8_0_q8_1_impl(const int * v, const int * u, const T & d8_0, const T & d8_1) {
+    int sumi = 0;
+
+#pragma unroll
+    for (int i = 0; i < vdr; ++i) {
+        // SIMD dot product of quantized values
+        sumi = ggml_sycl_dp4a(v[i], u[i], sumi);
+    }
+
+    return d8_0*d8_1 * ((T) sumi);
+}
+
 template <int vdr>
 static __dpct_inline__ float vec_dot_q8_1_q8_1_impl(const int *v, const int *u,
                                                     const sycl::half2 &dm8,