file(GLOB GGML_HEADERS_SYCL "*.hpp")
file(GLOB GGML_SOURCES_SYCL "*.cpp")
+file(GLOB SRCS "template-instances/fattn-tile*.cpp")
+list(APPEND GGML_SOURCES_SYCL ${SRCS})
+file(GLOB SRCS "template-instances/fattn-vec*.cpp")
+list(APPEND GGML_SOURCES_SYCL ${SRCS})
+
target_sources(ggml-sycl PRIVATE ${GGML_HEADERS_SYCL} ${GGML_SOURCES_SYCL})
if (WIN32)
endif()
if (GGML_SYCL_GRAPH)
+    message(STATUS "GGML_SYCL_GRAPH is enabled")
target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_GRAPH)
endif()
#include "dequantize.hpp"
#include "dmmv.hpp"
#include "element_wise.hpp"
+#include "fattn.hpp"
#include "gla.hpp"
#include "im2col.hpp"
#include "mmq.hpp"
#include <string>
#include "dpct/helper.hpp"
+#include "ggml.h"
+#include "ggml-impl.h"
#include "ggml-sycl.h"
#include "presets.hpp"
#include "sycl_hw.hpp"
+namespace syclexp = sycl::ext::oneapi::experimental;
#if GGML_SYCL_DNNL
#include "dnnl.hpp"
#define GGML_COMMON_DECL_SYCL
#define GGML_COMMON_IMPL_SYCL
+#define SYCL_FLASH_ATTN  // Remove to disable FLASH_ATTENTION at build time.
+#define SYCL_FAST_FP16   // Do not remove: fattn-tile.hpp fails to build without it.
+
/* suppress warning spam */
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wnested-anon-types"
extern int g_ggml_sycl_debug;
extern int g_ggml_sycl_disable_optimize;
extern int g_ggml_sycl_prioritize_dmmv;
+extern int g_ggml_sycl_enable_flash_attention;
+
#if defined(__clang__) && __has_builtin(__builtin_expect)
// Hint the optimizer to pipeline the more likely following instruction in branches
int get_current_device_id();
+inline int ggml_sycl_get_device() {
+ return get_current_device_id();
+}
+
inline dpct::err0 ggml_sycl_set_device(const int device) try {
int current_device_id;
SYCL_CHECK(CHECK_TRY_ERROR(current_device_id = get_current_device_id()));
};
struct sycl_device_info {
- int cc; // compute capability
+ int cc; // compute capability
int nsm; // number of streaming multiprocessors (CUDA) maps to the maximum
// number of compute units on a SYCL device.
// size_t smpb; // max. shared memory per block
size_t smpbo; // max. shared memory per block (with opt-in)
+    int warp_size;     // maximum sub-group size of the SYCL device
+ int max_wg_per_cu; // max work groups per compute unit - refer to
+ // cudaOccupancyMaxActiveBlocksPerMultiprocessor
bool vmm; // virtual memory support
size_t total_vram;
//sycl_hw_info hw_info; // device id and arch, currently not used
return a;
}
-template <int width = WARP_SIZE>
+/* use WARP_SIZE or WARP_32_SIZE */
+template <int width>
static __dpct_inline__ int warp_reduce_sum(int x) {
return sycl::reduce_over_group(
sycl::ext::oneapi::this_work_item::get_sub_group(), x, sycl::plus<>());
}
-template <int width = WARP_SIZE>
+/* use WARP_SIZE or WARP_32_SIZE */
+template <int width>
static __dpct_inline__ float warp_reduce_sum(float x) {
#pragma unroll
for (int offset = width / 2; offset > 0; offset >>= 1) {
return x;
}
-template <int width = WARP_SIZE>
+/* use WARP_SIZE or WARP_32_SIZE */
+template <int width>
+static __dpct_inline__ float warp_reduce_sum(float x, const sycl::nd_item<3>& item_ct1) {
+#pragma unroll
+ for (int offset = width / 2; offset > 0; offset >>= 1) {
+ x += dpct::permute_sub_group_by_xor(
+ item_ct1.get_sub_group(), x, offset);
+ }
+ return x;
+}
+
+/* use WARP_SIZE or WARP_32_SIZE */
+template <int width>
static __dpct_inline__ sycl::float2 warp_reduce_sum(sycl::float2 a) {
#pragma unroll
for (int offset = width / 2; offset > 0; offset >>= 1) {
return a;
}
-template <int width = WARP_SIZE>
+/* use WARP_SIZE or WARP_32_SIZE */
+template <int width>
static __dpct_inline__ sycl::half2 warp_reduce_sum(sycl::half2 a) {
#pragma unroll
for (int offset = width / 2; offset > 0; offset >>= 1) {
return WARP_SIZE;
}
-template <int width = WARP_SIZE>
+/* use WARP_SIZE or WARP_32_SIZE */
+template <int width>
+static __dpct_inline__ int warp_reduce_all(int x) {
+ if (width == ggml_sycl_get_physical_warp_size()) {
+ return sycl::all_of_group(
+ sycl::ext::oneapi::this_work_item::get_sub_group(),
+ (~0xffffffff &
+ (0x1 << sycl::ext::oneapi::this_work_item::get_sub_group()
+ .get_local_linear_id())) ||
+ x);
+ } else {
+#pragma unroll
+ for (int offset = width / 2; offset > 0; offset >>= 1) {
+ x = dpct::permute_sub_group_by_xor(
+ sycl::ext::oneapi::this_work_item::get_sub_group(), x,
+ offset, width) &&
+ x;
+ }
+ return x;
+ }
+}
+
+/* use WARP_SIZE or WARP_32_SIZE */
+template <int width>
+static __dpct_inline__ int warp_reduce_any(int x) {
+ if (width == ggml_sycl_get_physical_warp_size()) {
+ return sycl::any_of_group(
+ sycl::ext::oneapi::this_work_item::get_sub_group(),
+ (0xffffffff &
+ (0x1 << sycl::ext::oneapi::this_work_item::get_sub_group()
+ .get_local_linear_id())) &&
+ x);
+ } else {
+#pragma unroll
+ for (int offset = width / 2; offset > 0; offset >>= 1) {
+ x = dpct::permute_sub_group_by_xor(
+ sycl::ext::oneapi::this_work_item::get_sub_group(), x,
+ offset, width) ||
+ x;
+ }
+ return x;
+ }
+}
+
+/* use WARP_SIZE or WARP_32_SIZE */
+template <int width>
static __dpct_inline__ float warp_reduce_max(float x) {
#pragma unroll
for (int offset = width / 2; offset > 0; offset >>= 1) {
return sycl::uint3(mp, L, d);
}
+// Maximum number of bytes that can be copied in a single instruction.
+// The value was determined empirically by benchmarking.
+static constexpr int ggml_sycl_get_max_cpy_bytes() {
+ return 16;
+}
+
+// Aligned memory transfers of 8/16 bytes can be faster than two 4-byte transfers.
+template <int nbytes, int alignment = 0>
+static __dpct_inline__ void ggml_sycl_memcpy_1(void * dst, const void * src) {
+ if constexpr (alignment != 0) {
+ static_assert(nbytes % alignment == 0, "bad alignment");
+ }
+ constexpr int nb_per_cpy = alignment == 0 ? nbytes : alignment;
+
+#pragma unroll
+ for (int i = 0; i < nbytes/nb_per_cpy; ++i) {
+ if constexpr (nb_per_cpy == 1) {
+ ((char *) dst)[i] = ((const char *) src)[i];
+ } else if constexpr (nb_per_cpy == 2) {
+ ((short *) dst)[i] = ((const short *) src)[i];
+ } else if constexpr (nb_per_cpy == 4) {
+ ((int *) dst)[i] = ((const int *) src)[i];
+ } else if constexpr (nb_per_cpy == 8) {
+ ((sycl::int2 *) dst)[i] = ((const sycl::int2 *) src)[i];
+ } else if constexpr (nb_per_cpy == 16) {
+ ((sycl::int4 *) dst)[i] = ((const sycl::int4 *) src)[i];
+ } else {
+ static_assert(nbytes == 0 && nbytes == -1, "bad nbytes");
+ }
+ }
+}
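+// Illustrative usage (not exercised here): ggml_sycl_memcpy_1<16, 8>(dst, src)
+// issues two aligned 8-byte transfers, while ggml_sycl_memcpy_1<16>(dst, src)
+// issues a single 16-byte transfer.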
+template <typename T>
+sycl::half2 __dpct_inline__ make_half2(T x, T y) {
+    sycl::half2 res(static_cast<sycl::half>(x), static_cast<sycl::half>(y));
+    return res;
+}
static __dpct_inline__ uint32_t fastdiv(uint32_t n, const sycl::uint3 fastdiv_values) {
const uint32_t hi = sycl::mul_hi<unsigned>(n, fastdiv_values.x());
}
+template <typename T>
+sycl::float2 __dpct_inline__ make_float2(T x, T y) {
+    sycl::float2 res(static_cast<float>(x), static_cast<float>(y));
+    return res;
+}
+
static __dpct_inline__ sycl::uint2 fast_div_modulo(uint32_t n, const sycl::uint3 fastdiv_values) {
const uint32_t div_val = fastdiv(n, fastdiv_values);
const uint32_t mod_val = n - div_val * fastdiv_values.z();
return result;
}
+sycl::float2 __dpct_inline__ __half22float2(const sycl::half2 &H) {
+ sycl::float2 float2_value(static_cast<float>(H.x()), static_cast<float>(H.y()));
+ return float2_value;
+}
+
+float __dpct_inline__ __half2float(sycl::half H) {
+ return static_cast<float>(H);
+}
+
+static __dpct_inline__ void ggml_sycl_mad(float & acc, const float v, const float u) {
+ acc += v*u;
+}
+
+static __dpct_inline__ void ggml_sycl_mad(float & acc, const sycl::float2 v, const sycl::float2 u) {
+ acc += v.x() * u.x();
+ acc += v.y() * u.y();
+}
+
+static __dpct_inline__ void ggml_sycl_mad(float & acc, const sycl::half2 v, const sycl::half2 u) {
+#ifdef GGML_SYCL_F16
+ const sycl::float2 tmp = (v * u).template convert<float, sycl::rounding_mode::automatic>();
+ acc += tmp.x() + tmp.y();
+#else
+ const sycl::float2 tmpv = __half22float2(v);
+ const sycl::float2 tmpu = __half22float2(u);
+ acc += tmpv.x() * tmpu.x();
+ acc += tmpv.y() * tmpu.y();
+#endif // GGML_SYCL_F16
+}
+
+static __dpct_inline__ void ggml_sycl_mad(sycl::half2 & acc, const sycl::half2 v, const sycl::half2 u) {
+#ifdef GGML_SYCL_F16
+ acc += v*u;
+#else
+ const sycl::float2 tmpv = __half22float2(v);
+ const sycl::float2 tmpu = __half22float2(u);
+ sycl::float2 tmpacc = __half22float2(acc);
+ sycl::float2 tmp1(tmpacc.x() + tmpv.x() * tmpu.x(), tmpacc.y() + tmpv.y() * tmpu.y());
+ acc = make_half2(tmp1.x(), tmp1.y());
+#endif // GGML_SYCL_F16
+}
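+// The ggml_sycl_mad overloads above let kernels accumulate products generically:
+// e.g. ggml_sycl_mad(sum, v, u) adds v*u for floats, or both lane-wise products
+// of a float2/half2 pair, into the accumulator sum.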
+
+template <int n>
+struct ggml_sycl_unroll {
+ template <typename Func, typename... Args>
+ void operator()(const Func & f, Args... args) const {
+ f(n - 1, args...);
+ ggml_sycl_unroll<n - 1>{}(f, args...);
+ }
+};
+
+template <>
+struct ggml_sycl_unroll<1> {
+ template <typename Func, typename... Args>
+ void operator()(const Func & f, Args... args) const {
+ f(0, args...);
+ }
+};
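+// Example (illustrative): ggml_sycl_unroll<4>{}(f, args...) expands at compile
+// time to f(3, args...), f(2, args...), f(1, args...), f(0, args...).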
+
+static __dpct_inline__ sycl::half2 ggml_sycl_hmax2(const sycl::half2 a, const sycl::half2 b) {
+ sycl::half2 ret;
+ reinterpret_cast<sycl::half &>(ret.x()) =
+ sycl::vec<float, 1>(sycl::fmax(a[0], b[0])).convert<sycl::half, sycl::rounding_mode::automatic>()[0];
+ reinterpret_cast<sycl::half &>(ret.y()) =
+ sycl::vec<float, 1>(sycl::fmax(a[1], b[1])).convert<sycl::half, sycl::rounding_mode::automatic>()[0];
+ return ret;
+}
+
+static __dpct_inline__ sycl::half ggml_sycl_hmax(const sycl::half a, const sycl::half b) {
+ return sycl::vec<float, 1>(
+ sycl::fmax(sycl::vec<sycl::half, 1>(a).convert<float, sycl::rounding_mode::automatic>()[0],
+ sycl::vec<sycl::half, 1>(b).convert<float, sycl::rounding_mode::automatic>()[0]))
+ .convert<sycl::half, sycl::rounding_mode::automatic>()[0];
+}
+
+static __dpct_inline__ uint32_t __hgt2_mask(const sycl::half2 a, const sycl::half2 b) {
+ const uint32_t mask_low = 0x0000FFFF * (float(a[0]) > float(b[0]));
+ const uint32_t mask_high = 0xFFFF0000 * (float(a[1]) > float(b[1]));
+ return mask_low | mask_high;
+}
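+// Worked example: __hgt2_mask(sycl::half2(1, 0), sycl::half2(0, 1)) returns
+// 0x0000FFFF, since only the low 16-bit lane satisfies a > b (mirroring the
+// semantics of CUDA's __hgt2_mask).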
+
+static __dpct_inline__ uint32_t fastmodulo(uint32_t n, const sycl::uint3 fastdiv_values) {
+ // expects fastdiv_values to contain <mp, L, divisor> in <x, y, z> (see init_fastdiv_values)
+ return n - fastdiv(n, fastdiv_values) * fastdiv_values.z();
+}
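+// Worked example (assuming fastdiv_values was produced by init_fastdiv_values(3)):
+// fastmodulo(7, fastdiv_values) == 7 - fastdiv(7, fastdiv_values)*3 == 7 - 2*3 == 1.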
+
+static bool fast_fp16_available(const int cc) {
+ GGML_UNUSED(cc);
+    return true;  // Intel GPUs always support FP16.
+}
#endif // GGML_SYCL_COMMON_HPP
});
}
+template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
+static void dequantize_block_nc(const void * __restrict__ vx, dst_t * __restrict__ y,
+ const int64_t ne00, const int64_t ne01, const int64_t ne02,
+ const int64_t s01, const int64_t s02, const int64_t s03) {
+ auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+ const int64_t i00 = 2 * (int64_t(item_ct1.get_local_range(2)) * item_ct1.get_group(2) + item_ct1.get_local_id(2));
+
+ if (i00 >= ne00) {
+ return;
+ }
+
+ const int64_t i01 = item_ct1.get_group(1);
+ const int64_t i02 = item_ct1.get_group(0) % ne02;
+ const int64_t i03 = item_ct1.get_group(0) / ne02;
+
+ const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01;
+
+ const int64_t ib = ibx0 + i00/qk; // block index
+ const int64_t iqs = (i00%qk)/qr; // quant index
+ const int64_t iybs = i00 - i00%qk; // y block start index
+ const int64_t y_offset = qr == 1 ? 1 : qk/2;
+
+ // dequantize
+ #ifdef GGML_SYCL_F16
+ sycl::half2 v;
+ #else
+ sycl::float2 v;
+ #endif
+
+ dequantize_kernel(vx, ib, iqs, v);
+
+ const int64_t iy0 = ((i03*ne02 + i02)*ne01 + i01)*ne00 + iybs + iqs;
+ y[iy0 + 0] = ggml_sycl_cast<dst_t>(v.x());
+ y[iy0 + y_offset] = ggml_sycl_cast<dst_t>(v.y());
+}
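+// Note: each work item above dequantizes one pair of values (stored qk/2
+// elements apart when qr > 1, adjacent when qr == 1), which is why the
+// launcher below divides ne00 by 2 * SYCL_DEQUANTIZE_BLOCK_SIZE.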
+
+
+template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
+static void dequantize_block_nc_sycl(const void * vx,
+ dst_t * y,
+ const int64_t ne00,
+ const int64_t ne01,
+ const int64_t ne02,
+ const int64_t ne03,
+ const int64_t s01,
+ const int64_t s02,
+ const int64_t s03,
+ dpct::queue_ptr stream) {
+ const dpct::dim3 num_blocks((ne00 + 2 * SYCL_DEQUANTIZE_BLOCK_SIZE - 1) / (2 * SYCL_DEQUANTIZE_BLOCK_SIZE), ne01,
+ ne02 * ne03);
+ stream->parallel_for(sycl::nd_range<3>(num_blocks * sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE),
+ sycl::range<3>(1, 1, SYCL_DEQUANTIZE_BLOCK_SIZE)),
+ [=](sycl::nd_item<3> item_ct1) {
+ GGML_UNUSED(item_ct1);
+ dequantize_block_nc<qk, qr, dequantize_kernel>(vx, y, ne00, ne01, ne02, s01, s02, s03);
+ });
+}
template <typename src_t, typename dst_t>
static void convert_unary_nc(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01,
const int64_t ne02, const int64_t s01, const int64_t s02, const int64_t s03,
}
}
-to_fp16_nc_sycl_t get_to_fp16_nc_sycl(ggml_type type) {
+
+to_fp16_nc_sycl_t ggml_get_to_fp16_nc_sycl(ggml_type type) {
switch (type) {
case GGML_TYPE_F32:
return convert_unary_nc_sycl<float>;
case GGML_TYPE_BF16:
return convert_unary_nc_sycl<sycl::ext::oneapi::bfloat16>;
#endif
+ case GGML_TYPE_Q4_0:
+ return dequantize_block_nc_sycl<QK4_0, QR4_0, dequantize_q4_0>;
+ case GGML_TYPE_Q4_1:
+ return dequantize_block_nc_sycl<QK4_1, QR4_1, dequantize_q4_1>;
+ case GGML_TYPE_Q5_0:
+ return dequantize_block_nc_sycl<QK5_0, QR5_0, dequantize_q5_0>;
+ case GGML_TYPE_Q5_1:
+ return dequantize_block_nc_sycl<QK5_1, QR5_1, dequantize_q5_1>;
+ case GGML_TYPE_Q8_0:
+ return dequantize_block_nc_sycl<QK8_0, QR8_0, dequantize_q8_0>;
default:
return nullptr;
}
int64_t s01, int64_t s02, int64_t s03, dpct::queue_ptr queue);
typedef to_t_nc_sycl_t<sycl::half> to_fp16_nc_sycl_t;
-to_fp16_nc_sycl_t get_to_fp16_nc_sycl(ggml_type type);
+to_fp16_nc_sycl_t ggml_get_to_fp16_nc_sycl(ggml_type type);
+
+template <typename dst_t, typename src_t>
+inline dst_t ggml_sycl_cast(src_t x) {
+    if constexpr (std::is_same_v<dst_t, src_t>) {
+        return x;
+    } else if constexpr (std::is_same_v<dst_t, sycl::ext::oneapi::bfloat16>) {
+        return sycl::ext::oneapi::bfloat16(float(x));
+    } else if constexpr (std::is_same_v<src_t, sycl::ext::oneapi::bfloat16>) {
+        return static_cast<float>(x);
+    } else if constexpr (std::is_same_v<dst_t, int32_t>) {
+        return int32_t(x);
+    } else {
+        return float(x);
+    }
+}
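+// Illustrative: ggml_sycl_cast<sycl::ext::oneapi::bfloat16>(x) narrows through
+// float, ggml_sycl_cast<float>(bf16_val) widens, and identical source and
+// destination types pass through unchanged.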
#endif // GGML_SYCL_CONVERT_HPP
nequal += xi == yi;
}
- nequal = warp_reduce_sum(nequal);
+ nequal = warp_reduce_sum<WARP_SIZE>(nequal);
if (item_ct1.get_local_id(2) != 0) {
return;
return 0;
}
+ template <int n_nondefault_params, int n_default_params, typename T>
+ class args_selector;
+
+ /// args_selector is a helper class for extracting arguments from an
+ /// array of pointers to arguments or buffer of arguments to pass to a
+ /// kernel function.
+ ///
+ /// \param R(Ts...) The type of the kernel
+    /// \param n_nondefault_params The number of nondefault parameters of the
+    /// kernel (excluding parameters like sycl::nd_item, etc.)
+    /// \param n_default_params The number of default parameters of the kernel
+ ///
+ /// Example usage:
+ /// With the following kernel:
+ /// void foo(sycl::float2 *x, int n, sycl::nd_item<3> item_ct1, float
+ /// f=.1) {}
+ /// and with the declaration:
+ /// args_selector<2, 1, decltype(foo)> selector(kernelParams, extra);
+ /// we have:
+    /// selector.get<0>() returns a reference to sycl::float2 *,
+ /// selector.get<1>() returns a reference to int,
+ /// selector.get<2>() returns a reference to float
+ template <int n_nondefault_params, int n_default_params, typename R,
+ typename... Ts>
+ class args_selector<n_nondefault_params, n_default_params, R(Ts...)> {
+ private:
+ void **kernel_params;
+ char *args_buffer;
+
+ template <int i> static constexpr int account_for_default_params() {
+ constexpr int n_total_params = sizeof...(Ts);
+ if constexpr (i >= n_nondefault_params) {
+ return n_total_params - n_default_params +
+ (i - n_nondefault_params);
+ } else {
+ return i;
+ }
+ }
+
+ public:
+ /// Get the type of the ith argument of R(Ts...)
+ /// \param [in] i Index of parameter to get
+ /// \returns Type of ith parameter
+ template <int i>
+ using arg_type = std::tuple_element_t<account_for_default_params<i>(),
+ std::tuple<Ts...>>;
+ static constexpr int params_num = sizeof...(Ts);
+
+ private:
+ template <int i> static constexpr int get_offset() {
+ if constexpr (i == 0) {
+ // we can assume args_buffer is properly aligned to the
+ // first argument
+ return 0;
+ } else {
+ constexpr int prev_off = get_offset<i - 1>();
+ constexpr int prev_past_end =
+ prev_off + sizeof(arg_type<i - 1>);
+ using T = arg_type<i>;
+ // is the past-the-end of the i-1st element properly aligned
+ // with the ith element's alignment?
+ if constexpr (prev_past_end % alignof(T) == 0) {
+ return prev_past_end;
+ }
+ // otherwise bump prev_past_end to match alignment
+ else {
+ return prev_past_end +
+ (alignof(T) - (prev_past_end % alignof(T)));
+ }
+ }
+ }
+
+ static char *get_args_buffer(void **extra) {
+ if (!extra)
+ return nullptr;
+ for (; (std::size_t)*extra != 0; ++extra) {
+ if ((std::size_t)*extra == 1) {
+ return static_cast<char *>(*(extra + 1));
+ }
+ }
+ return nullptr;
+ }
+
+ public:
+ /// If kernel_params is nonnull, then args_selector will
+ /// extract arguments from kernel_params. Otherwise, it
+ /// will extract them from extra.
+    /// \param [in] kernel_params Array of pointers to arguments,
+    /// or a null pointer.
+ /// \param [in] extra Array containing pointer to argument buffer.
+ args_selector(void **kernel_params, void **extra)
+ : kernel_params(kernel_params),
+ args_buffer(get_args_buffer(extra)) {}
+
+ /// Get a reference to the ith argument extracted from kernel_params
+ /// or extra.
+ /// \param [in] i Index of argument to get
+ /// \returns Reference to the ith argument
+ template <int i> arg_type<i> &get() {
+ if (kernel_params) {
+ return *static_cast<arg_type<i> *>(kernel_params[i]);
+ } else {
+ return *reinterpret_cast<arg_type<i> *>(args_buffer +
+ get_offset<i>());
+ }
+ }
+    }; // COPY from DPCT header file
+ // /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/util.hpp
+
+ /// Utility class for launching SYCL kernels through kernel
+ /// function wrapper.
+ /// For example:
+ /// A SYCL kernel function:
+ /// void kernel_func(int *ptr, sycl::nd_item<3> item);
+ /// Kernel function wrapper:
+ /// void kernel_func_wrapper(int *ptr) {
+ /// sycl::queue queue = *dpct::kernel_launcher::_que;
+ /// unsigned int localMemSize = dpct::kernel_launcher::_local_mem_size;
+ /// sycl::nd_range<3> nr = dpct::kernel_launcher::_nr;
+ /// queue.parallel_for(
+ /// nr,
+ /// [=](sycl::nd_item<3> item_ct1) {
+ /// kernel_func(ptr, item_ct1);
+ /// });
+ /// }
+ /// Then launch the kernel through wrapper like:
+ /// typedef void(*fpt)(int *);
+ /// fpt fp = kernel_func_wrapper;
+ /// dpct::kernel_launcher::launch(fp, dpct::dim3(1), dpct::dim3(1), 0, 0,
+ /// device_ptr);
+    /// If the original function type is erased, register it first:
+ /// void *fp = (void *)wrapper_register(&kernel_func_wrapper).get();
+ /// dpct::kernel_launcher::launch(fp, dpct::dim3(1), dpct::dim3(1), args,
+ /// 0, 0);
+ class kernel_launcher {
+ template <typename FuncT, typename ArgSelector, std::size_t... Index>
+ static void launch_helper(FuncT &&func, ArgSelector &selector,
+ std::index_sequence<Index...>) {
+ func(selector.template get<Index>()...);
+ }
+ static void set_execution_config(dim3 group_range, dim3 local_range,
+ unsigned int local_mem_size,
+ queue_ptr que) {
+ if (que) {
+ _que = que;
+ } else {
+ _que = &get_default_queue();
+ }
+ _nr = sycl::nd_range<3>(
+ static_cast<sycl::range<3>>(group_range * local_range),
+ static_cast<sycl::range<3>>(local_range));
+            _local_mem_size = local_mem_size;
+        }
+ static inline std::mutex kernel_function_ptr_map_mutex;
+
+ public:
+ /// Variables for storing execution configuration.
+ static inline thread_local sycl::queue *_que = nullptr;
+ static inline thread_local sycl::nd_range<3> _nr = sycl::nd_range<3>();
+ static inline thread_local unsigned int _local_mem_size = 0;
+ /// Map for retrieving launchable functor from a raw pointer.
+ static inline std::map<
+ const void *,
+ std::function<void(dim3, dim3, void **, unsigned int, queue_ptr)>>
+ kernel_function_ptr_map = {};
+
+ /// Registers a kernel function pointer with a corresponding launchable
+ /// functor.
+ /// \param [in] func Pointer to the kernel function.
+ /// \param [in] launcher Functor to handle kernel invocation.
+ static void register_kernel_ptr(
+ const void *func,
+ std::function<void(dim3, dim3, void **, unsigned int, queue_ptr)>
+ launcher) {
+ std::lock_guard<std::mutex> lock(kernel_function_ptr_map_mutex);
+ kernel_function_ptr_map[func] = std::move(launcher);
+ }
+ /// Launches a kernel function with arguments provided directly through
+ /// kernel function wrapper.
+ /// \tparam FuncT Type of the kernel function wrapper.
+ /// \tparam ArgsT Types of kernel arguments.
+ /// \param [in] func Pointer to the kernel function wrapper.
+ /// \param [in] group_range SYCL group range.
+ /// \param [in] local_range SYCL local range.
+ /// \param [in] local_mem_size The size of local memory required by the
+    /// kernel function.
+    /// \param [in] que SYCL queue used to execute kernel.
+ /// \param [in] args Kernel arguments.
+ template <typename FuncT, typename... ArgsT>
+ static std::enable_if_t<std::is_invocable_v<FuncT *, ArgsT...>, void>
+ launch(FuncT *func, dim3 group_range, dim3 local_range,
+ unsigned int local_mem_size, queue_ptr que, ArgsT... args) {
+ set_execution_config(group_range, local_range, local_mem_size, que);
+ func(args...);
+ }
+    /// Launches a kernel function through a registered kernel function wrapper.
+    /// \param [in] func Pointer to the registered kernel function wrapper.
+    /// \param [in] group_range SYCL group range.
+    /// \param [in] local_range SYCL local range.
+    /// \param [in] args Array of pointers to kernel arguments.
+    /// \param [in] local_mem_size The size of local memory required by the kernel function.
+    /// \param [in] que SYCL queue used to execute kernel.
+ static void launch(const void *func, dim3 group_range, dim3 local_range,
+ void **args, unsigned int local_mem_size,
+ queue_ptr que) {
+ std::lock_guard<std::mutex> lock(kernel_function_ptr_map_mutex);
+ auto Iter = kernel_function_ptr_map.find(func);
+ if (Iter == kernel_function_ptr_map.end()) {
+ throw std::runtime_error("dpct::launch() : no registered "
+ "kernel function wrapper found.");
+ }
+ (Iter->second)(group_range, local_range, args, local_mem_size, que);
+ }
+ /// Launches a kernel function with packed arguments through kernel
+ /// function wrapper.
+ /// \tparam FuncT Type of the kernel function wrapper.
+ /// \param [in] func Pointer to the kernel function wrapper.
+ /// \param [in] group_range SYCL group range.
+ /// \param [in] local_range SYCL local range.
+ /// \param [in] args Array of pointers to kernel arguments.
+ /// \param [in] local_mem_size The size of local memory required by the
+    /// kernel function.
+    /// \param [in] que SYCL queue used to execute kernel.
+ template <typename FuncT>
+ static std::enable_if_t<std::is_function_v<FuncT>, void>
+ launch(FuncT *func, dim3 group_range, dim3 local_range, void **args,
+ unsigned int local_mem_size, queue_ptr que) {
+ constexpr size_t p_num = args_selector<0, 0, FuncT>::params_num;
+ set_execution_config(group_range, local_range, local_mem_size, que);
+ args_selector<p_num, p_num, FuncT> selector(args, nullptr);
+ launch_helper(func, selector, std::make_index_sequence<p_num>{});
+ }
+    }; // COPY from DPCT header file
+ // /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/kernel.hpp
+
+ // /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/util.hpp
+ template <typename T>
+ T select_from_sub_group(
+ sycl::sub_group g,
+ T x,
+ int remote_local_id,
+ int logical_sub_group_size = 32) {
+ unsigned int start_index = g.get_local_linear_id() /
+ logical_sub_group_size *
+ logical_sub_group_size;
+ return sycl::select_from_group(
+ g, x, start_index + remote_local_id % logical_sub_group_size);
+ }
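+    // Worked example: with logical_sub_group_size == 8, the work item with
+    // local id 10 belongs to the logical group starting at lane 8, so
+    // select_from_sub_group(g, x, 3, 8) returns the x held by lane 8 + 3 = 11.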
+
+ // /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/math.hpp
+ template <typename T>
+ void ldmatrix(uintptr_t addr, T* m, bool trans = false, unsigned mat = 0) {
+ auto sg = sycl::ext::oneapi::this_work_item::get_sub_group();
+ int lane = sg.get_local_linear_id();
+
+ int lane_group8_row = lane / 8;
+ int lane_group8_col = lane % 8;
+
+ if (!trans) {
+ // calculate the source lane
+ int src_lane = 2 * lane_group8_row;
+ if (lane_group8_col >= 4)
+ src_lane += 1;
+
+ // Broadcast the address from the source lane
+ auto recv_addr_uintp =
+ dpct::select_from_sub_group(sg, addr, mat * 8 + src_lane);
+
+ // Cast the received address from uintptr_t to the type of 'm'
+ auto recv_addr = reinterpret_cast<T*>(recv_addr_uintp);
+
+ // Non-transposed load
+ *m = recv_addr[lane_group8_col % 4];
+ } else {
+ // calculate the source lane
+ int src_lane = (lane % 4) * 2;
+
+ // Broadcast the address from the source lane
+ auto recv_addr_uintp_1 =
+ dpct::select_from_sub_group(sg, addr, mat * 8 + src_lane);
+ auto recv_addr_uintp_2 =
+ dpct::select_from_sub_group(sg, addr, mat * 8 + src_lane + 1);
+
+ // Cast the received address from uintptr_t to 'half *'
+ auto recv_addr_1 = reinterpret_cast<sycl::half*>(recv_addr_uintp_1);
+ auto recv_addr_2 = reinterpret_cast<sycl::half*>(recv_addr_uintp_2);
+
+ // Transposed load
+ int index = lane / 4;
+ sycl::half val0 = recv_addr_1[index];
+ sycl::half val1 = recv_addr_2[index];
+
+            // Combine the two 16-bit values into one 32-bit value
+ sycl::half2 val = sycl::half2(val0, val1);
+ *m = *reinterpret_cast<T*>(&val);
+ }
+ }
+
+ template <typename T>
+ void ldmatrix(uintptr_t addr, T* m1, T* m2, bool trans = false) {
+ // Load 1st matrix
+ ldmatrix(addr, m1, trans, 0);
+ // Load 2nd matrix
+ ldmatrix(addr, m2, trans, 1);
+ }
+
+ template <typename T>
+ void ldmatrix(
+ uintptr_t addr, T* m1, T* m2, T* m3, T* m4, bool trans = false) {
+ // Load 1st matrix
+ ldmatrix(addr, m1, trans, 0);
+ // Load 2nd matrix
+ ldmatrix(addr, m2, trans, 1);
+ // Load 3rd matrix
+ ldmatrix(addr, m3, trans, 2);
+ // Load 4th matrix
+ ldmatrix(addr, m4, trans, 3);
+ }
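+    // Illustrative: ldmatrix(addr, &m1, &m2) loads the two 8x8 half
+    // sub-matrices (CUDA ldmatrix.x2 semantics), one 32-bit fragment per
+    // lane; the overloads above cover one and four matrices respectively.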
+
+ // /opt/intel/oneapi/dpcpp-ct/latest/include/dpct/math.hpp
+
+    /// A helper struct that defines the pack type for the input matrix
+    /// fragments of the mma() function based on the type of input matrix fragments.
+    /// The MMAType struct is specialized for different types of input matrices.
+    /// Currently, the specialization for f16, bf16 and s8 types is defined below.
+    /// \tparam [in] T The type of the input matrix fragments
+ template <typename T>
+ struct MMAType {
+ using PackType = uint32_t;
+ };
+
+ /// Each work item of a sub-group (limited to size 32) calling this function
+ /// calculates a subset fragment for the output matrix D using MAD operation
+ /// on A, B & C matrix fragments (D = A * B + C). Current supported shapes &
+ /// types:
+ /// - m8n8k4 (f32.f16.f16.f32)
+ /// - m8n8k16 (s32.s8.s8.s32)
+ /// - m16n8k8 (f32.f16.f16.f32 & f32.bf16.bf16.f32)
+ /// - m16n8k16 (f32.f16.f16.f32 & s32.s8.s8.s32)
+ /// - m16n8k32 (s32.s8.s8.s32)
+ /// Here, m, n & k define the shapes of A, B & C matrices respectively
+ /// (A = [m x k], B = [k x n], C = [m x n]).
+ /// \tparam [in] M The rows of A, C & D matrices
+ /// \tparam [in] N The columns of B, C, D matrices
+ /// \tparam [in] K The columns & rows of A & B matrices respectively
+ /// \tparam [in] ABType The type of the input matrix (A & B) fragment
+ /// \tparam [in] CDType The type of the output matrix (C & D) fragment
+ /// \param [out] d_mat_frag The fragment of the output matrix D to store the
+ /// result of A * B + C
+    /// \param [in] a_mat_frag The fragment of the input matrix A to be multiplied with the B matrix fragment
+    /// \param [in] b_mat_frag The fragment of the input matrix B to be multiplied with the A matrix fragment
+    /// \param [in] c_mat_frag The fragment of the input matrix C to be added to the result of the A * B fragments
+ template <int M, int N, int K, typename ABType, typename CDType>
+ void mma(
+ volatile void** d_mat_frag,
+ void* a_mat_frag,
+ void* b_mat_frag,
+ void* c_mat_frag) {
+ auto d = reinterpret_cast<volatile CDType**>(d_mat_frag);
+ auto a =
+ reinterpret_cast<typename MMAType<ABType>::PackType*>(a_mat_frag);
+ auto b =
+ reinterpret_cast<typename MMAType<ABType>::PackType*>(b_mat_frag);
+ auto c = reinterpret_cast<CDType*>(c_mat_frag);
+
+ auto sg = sycl::ext::oneapi::this_work_item::get_sub_group();
+ int lane = sg.get_local_linear_id();
+
+ static_assert(
+ (M == 8 && N == 8 && K == 4) || (M == 8 && N == 8 && K == 16) ||
+ (M == 16 && N == 8 && K == 8) || (M == 16 && N == 8 && K == 16) ||
+ (M == 16 && N == 8 && K == 32),
+ "Unsupported MMA shape!");
+
+ short row_load_offset = 4 * (lane >> 2);
+ short col_load_offset = 8 * (lane % 4);
+
+ if constexpr (M == 8 && N == 8 && K == 4) {
+ if constexpr (std::is_floating_point_v<CDType>) {
+ col_load_offset = row_load_offset % 16;
+
+ // Init D matrix with fragments of C matrix
+ *d[0] = c[0];
+ *d[1] = c[1];
+ *d[2] = c[2];
+ *d[3] = c[3];
+ *d[4] = c[4];
+ *d[5] = c[5];
+ *d[6] = c[6];
+ *d[7] = c[7];
+
+ // Calculate the row and col offset indices to iterate through the row
+ // & col fragments of A & B matrices
+ int r_ind = (lane % 2) ? 1 : 0;
+ int c_ind = ((lane % 4) / 2) ? 2 : 0;
+
+ // Each sub-group is responsible for computing a fragment size of 8*8
+ // elements of matrix D for each of 4 MMA computations.
+ // Each work item computes 8 elements of matrix D by gathering
+ // their corresponding col & row matrix fragments of length k (4)
+ // from A & B matrices respectively using below mapping logic:
+ // row0 = (i % 4) if (lane < 16) else (i % 4) + 4
+ // col0 = (lane % 4)
+ // As each row & col fragment of A & B matrices is distributed across
+ // 4 work items, each iteration of below loop loads a partial fragment
+ // of matrix A (row) and matrix B (col) using the row & col offsets.
+ typename MMAType<ABType>::PackType recv_a[2], recv_b[2];
+
+ for (int i = 0; i < 4; i++) {
+ // Load partial fragment from col0 of matrix A ({a0, a1})
+ recv_a[0] =
+ dpct::select_from_sub_group(sg, a[0], row_load_offset + i);
+ // Load partial fragment from col0 of matrix A ({a2, a3})
+ recv_a[1] =
+ dpct::select_from_sub_group(sg, a[1], row_load_offset + i);
+
+ // Load partial fragment from row0 of matrix B ({b0, b1})
+ recv_b[0] =
+ dpct::select_from_sub_group(sg, b[0], col_load_offset + i);
+ // Load partial fragment from row0 of matrix B ({b2, b3})
+ recv_b[1] =
+ dpct::select_from_sub_group(sg, b[1], col_load_offset + i);
+
+ auto ra = reinterpret_cast<ABType*>(recv_a);
+ auto rb = reinterpret_cast<ABType*>(recv_b);
+
+        // Each work item calculates a partial product of A & B matrix
+        // fragments and adds it to the corresponding D matrix fragment.
+        // (for even work item indices)
+        // d0 += col0{ a0 } * row0{ b0 }
+        // d1 += col0{ a0 } * row0{ b1 }
+        // d2 += col1{ a2 } * row0{ b0 }
+        // d3 += col1{ a2 } * row0{ b1 }
+        // (for odd work item indices)
+        // d0 += col0{ a1 } * row0{ b2 }
+        // d1 += col0{ a1 } * row0{ b3 }
+        // d2 += col1{ a3 } * row0{ b2 }
+        // d3 += col1{ a3 } * row0{ b3 }
+ *d[0] +=
+ static_cast<float>(ra[r_ind]) * static_cast<float>(rb[c_ind]);
+ *d[1] += static_cast<float>(ra[r_ind]) *
+ static_cast<float>(rb[c_ind + 1]);
+ *d[2] += static_cast<float>(ra[r_ind + 2]) *
+ static_cast<float>(rb[c_ind]);
+ *d[3] += static_cast<float>(ra[r_ind + 2]) *
+ static_cast<float>(rb[c_ind + 1]);
+
+ // Load partial fragment from row1 of matrix B ({b0, b1})
+ recv_b[0] =
+ dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 16);
+ // Load partial fragment from row1 of matrix B ({b2, b3})
+ recv_b[1] =
+ dpct::select_from_sub_group(sg, b[1], col_load_offset + i + 16);
+
+ // (for even work item indices)
+ // d0 += col0{ a0 } * row1{ b0 }
+ // d1 += col0{ a0 } * row1{ b1 }
+ // d2 += col1{ a2 } * row1{ b0 }
+ // d3 += col1{ a2 } * row1{ b1 }
+ // (for odd work item indices)
+ // d0 += col0{ a1 } * row1{ b2 }
+ // d1 += col0{ a1 } * row1{ b3 }
+ // d2 += col1{ a3 } * row1{ b2 }
+ // d3 += col1{ a3 } * row1{ b3 }
+ *d[4] +=
+ static_cast<float>(ra[r_ind]) * static_cast<float>(rb[c_ind]);
+ *d[5] += static_cast<float>(ra[r_ind]) *
+ static_cast<float>(rb[c_ind + 1]);
+ *d[6] += static_cast<float>(ra[r_ind + 2]) *
+ static_cast<float>(rb[c_ind]);
+ *d[7] += static_cast<float>(ra[r_ind + 2]) *
+ static_cast<float>(rb[c_ind + 1]);
+ }
+ }
+ } else if constexpr (M == 8 && N == 8 && K == 16) {
+ if constexpr (std::is_integral_v<ABType>) {
+ // Init D matrix with fragments of C matrix
+ *d[0] = c[0];
+ *d[1] = c[1];
+
+        // Each sub-group is responsible for computing a fragment size of 8*8
+        // elements of matrix D.
+ // Each work item computes 2 elements of matrix D by gathering
+ // their corresponding row & col matrix fragments of length k (16)
+ // from A & B matrices respectively using below mapping logic:
+ // row0 = ((lane % 4) * 4) + i
+ // col0 = (lane >> 2)
+ // As each row & col fragment of A & B matrices is distributed across
+ // 4 work items, each iteration of below loop loads a partial fragment
+ // of matrix A (row) and matrix B (col) using the row & col offsets.
+ for (int i = 0; i < 4; i++) {
+ typename MMAType<ABType>::PackType recv_a, recv_b[2];
+
+ // Load partial fragment from row0 of matrix A ({a0, a1, a2, a3})
+ recv_a = dpct::select_from_sub_group(sg, a[0], row_load_offset + i);
+ // Load partial fragment from col0 of matrix B ({b0, b1, b2, b3})
+ recv_b[0] =
+ dpct::select_from_sub_group(sg, b[0], col_load_offset + i);
+ // Load partial fragment from col1 of matrix B ({b0, b1, b2, b3})
+ recv_b[1] =
+ dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 4);
+
+ auto a = reinterpret_cast<ABType*>(&recv_a);
+ auto b = reinterpret_cast<ABType*>(recv_b);
+
+          // Each work item calculates a partial product of A & B matrix
+          // fragments and adds it to the corresponding D matrix fragment.
+          // d0 += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 }
+          // d1 += row0{ a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 }
+          // d2 += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 }
+          // d3 += row0{ a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 }
+ for (int j = 0; j < 4; j++) {
+ *d[0] += a[j] * b[j];
+ *d[1] += a[j] * b[j + 4];
+ }
+ }
+ }
+ } else if constexpr (M == 16 && N == 8 && K == 8) {
+ if constexpr (std::is_floating_point_v<CDType>) {
+ // Init D matrix fragment with C matrix fragment
+ *d[0] = c[0];
+ *d[1] = c[1];
+ *d[2] = c[2];
+ *d[3] = c[3];
+
+ // Each sub-group is responsible for computing a fragment size of 16*8
+ // elements of matrix D.
+ // Each work item computes 4 elements of matrix D by gathering
+ // their corresponding row & col matrix fragments of length k (8)
+ // from A & B matrices respectively using below mapping logic:
+ // row0 = (lane >> 2) & row1 = (lane >> 2) + 8
+ // col0 = (lane % 4) * 2 + (i & 0x1)
+ // As each row & col fragment of A & B matrices is distributed across
+ // 4 work items, each iteration of below loop loads a partial fragment
+ // of matrix A (row) and matrix B (col) using the row & col offsets.
+ for (int i = 0; i < 4; i++) {
+ typename MMAType<ABType>::PackType recv_a[2], recv_b[2];
+
+ // Load partial fragment from row0 of matrix A ({a0, a1})
+ recv_a[0] =
+ dpct::select_from_sub_group(sg, a[0], row_load_offset + i);
+ // Load partial fragment from row1 of matrix A ({a2, a3})
+ recv_a[1] =
+ dpct::select_from_sub_group(sg, a[1], row_load_offset + i);
+ // Load partial fragment from col0 of matrix B ({b0, b1})
+ recv_b[0] =
+ dpct::select_from_sub_group(sg, b[0], col_load_offset + i);
+ // Load partial fragment from col1 of matrix B ({b0, b1})
+ recv_b[1] =
+ dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 4);
+
+ auto ra = reinterpret_cast<ABType*>(recv_a);
+ auto rb = reinterpret_cast<ABType*>(recv_b);
+
+          // Each work item calculates a partial product of A & B matrix
+          // fragments and adds it to the corresponding D matrix fragment.
+          // d0 += row0{ a0, a1 } * col0{ b0, b1 }
+          // d1 += row0{ a0, a1 } * col1{ b0, b1 }
+          // d2 += row1{ a2, a3 } * col0{ b0, b1 }
+          // d3 += row1{ a2, a3 } * col1{ b0, b1 }
+ for (int j = 0; j < 2; j++) {
+ *d[0] += static_cast<float>(ra[j]) * static_cast<float>(rb[j]);
+ *d[1] +=
+ static_cast<float>(ra[j]) * static_cast<float>(rb[j + 2]);
+ *d[2] +=
+ static_cast<float>(ra[j + 2]) * static_cast<float>(rb[j]);
+ *d[3] +=
+ static_cast<float>(ra[j + 2]) * static_cast<float>(rb[j + 2]);
+ }
+ }
+ }
+ } else if constexpr (M == 16 && N == 8 && K == 16) {
+ if constexpr (std::is_floating_point_v<CDType>) {
+ // Init D matrix fragment with C matrix fragment
+ *d[0] = c[0];
+ *d[1] = c[1];
+ *d[2] = c[2];
+ *d[3] = c[3];
+
+ // Each sub-group is responsible for computing a fragment size of 16*8
+ // elements of matrix D.
+ // Each work item computes 4 elements of matrix D by gathering
+        // their corresponding row & col matrix fragments of length k (16)
+ // from A & B matrices respectively using below mapping logic:
+ // row0 = (lane >> 2) & row1 = (lane >> 2) + 8
+ // col0 = (lane % 4) * 2 & col1 = (lane % 4) * 2 + 1
+ // As each row & col fragment of A & B matrices is distributed across
+ // 4 work items, each iteration of below loop loads a partial fragment
+ // of matrix A (row) and matrix B (col) using the row & col offsets.
+ for (int i = 0; i < 4; i++) {
+ typename MMAType<ABType>::PackType recv_a[4], recv_b[4];
+
+ // Load partial fragment from row0 of matrix A ({a0, a1})
+ recv_a[0] =
+ dpct::select_from_sub_group(sg, a[0], row_load_offset + i);
+ // Load partial fragment from row0 of matrix A ({a2, a3})
+ recv_a[1] =
+ dpct::select_from_sub_group(sg, a[2], row_load_offset + i);
+ // Load partial fragment from row1 of matrix A ({a0, a1})
+ recv_a[2] =
+ dpct::select_from_sub_group(sg, a[1], row_load_offset + i);
+ // Load partial fragment from row1 of matrix A ({a2, a3})
+ recv_a[3] =
+ dpct::select_from_sub_group(sg, a[3], row_load_offset + i);
+
+ // Load partial fragment from col0 of matrix B ({b0, b1})
+ recv_b[0] =
+ dpct::select_from_sub_group(sg, b[0], col_load_offset + i);
+ // Load partial fragment from col0 of matrix B ({b2, b3})
+ recv_b[1] =
+ dpct::select_from_sub_group(sg, b[1], col_load_offset + i);
+ // Load partial fragment from col1 of matrix B ({b0, b1})
+ recv_b[2] =
+ dpct::select_from_sub_group(sg, b[0], col_load_offset + 4 + i);
+ // Load partial fragment from col1 of matrix B ({b2, b3})
+ recv_b[3] =
+ dpct::select_from_sub_group(sg, b[1], col_load_offset + 4 + i);
+
+ auto ra = reinterpret_cast<ABType*>(recv_a);
+ auto rb = reinterpret_cast<ABType*>(recv_b);
+
+          // Each work item calculates a partial product of A & B matrix
+          // fragments and adds it to the corresponding D matrix fragment.
+          // d0 += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 }
+          // d1 += row0{ a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 }
+          // d2 += row1{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 }
+          // d3 += row1{ a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 }
+ for (int j = 0; j < 4; j++) {
+ *d[0] += static_cast<CDType>(ra[j]) * static_cast<CDType>(rb[j]);
+ *d[1] +=
+ static_cast<CDType>(ra[j]) * static_cast<CDType>(rb[j + 4]);
+ *d[2] +=
+ static_cast<CDType>(ra[j + 4]) * static_cast<CDType>(rb[j]);
+ *d[3] += static_cast<CDType>(ra[j + 4]) *
+ static_cast<CDType>(rb[j + 4]);
+ }
+ }
+ } else if constexpr (std::is_integral_v<ABType>) {
+ // Init D matrix with fragments of C matrix
+ *d[0] = c[0];
+ *d[1] = c[1];
+ *d[2] = c[2];
+ *d[3] = c[3];
+
+ // Each sub-group is responsible for computing a fragment size of 16*8
+ // elements of matrix D.
+ // Each work item computes 4 elements of matrix D by gathering
+        // their corresponding row & col matrix fragments of length k (16)
+ // from A & B matrices respectively using below mapping logic:
+ // row0 = (lane >> 2) & row1 = (lane >> 2) + 8
+ // col0 = (lane % 4) * 2 & col1 = (lane % 4) * 2 + 1
+ // As each row & col fragment of A & B matrices is distributed across
+ // 4 work items, each iteration of below loop loads a partial fragment
+ // of matrix A (row) and matrix B (col) using the row & col offsets.
+ for (int i = 0; i < 4; i++) {
+ typename MMAType<ABType>::PackType recv_a[2], recv_b[2];
+
+ // Load partial fragment from row0 of matrix A ({a0, a1, a2, a3})
+ recv_a[0] =
+ dpct::select_from_sub_group(sg, a[0], row_load_offset + i);
+ // Load partial fragment from row1 of matrix A ({a4, a5, a6, a7})
+ recv_a[1] =
+ dpct::select_from_sub_group(sg, a[1], row_load_offset + i);
+ // Load partial fragment from col0 of matrix B ({b0, b1, b2, b3})
+ recv_b[0] =
+ dpct::select_from_sub_group(sg, b[0], col_load_offset + i);
+ // Load partial fragment from col1 of matrix B ({b4, b5, b6, b7})
+ recv_b[1] =
+ dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 4);
+
+ auto ra = reinterpret_cast<ABType*>(recv_a);
+ auto rb = reinterpret_cast<ABType*>(recv_b);
+
+          // Each work item calculates a partial product of A & B matrix
+          // fragments and adds it to the corresponding D matrix fragment.
+          // d0 += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 }
+          // d1 += row0{ a0, a1, a2, a3 } * col1{ b4, b5, b6, b7 }
+          // d2 += row1{ a4, a5, a6, a7 } * col0{ b0, b1, b2, b3 }
+          // d3 += row1{ a4, a5, a6, a7 } * col1{ b4, b5, b6, b7 }
+          for (int j = 0; j < 4; j++) {
+            *d[0] += ra[j] * rb[j];
+            *d[1] += ra[j] * rb[j + 4];
+            *d[2] += ra[j + 4] * rb[j];
+            *d[3] += ra[j + 4] * rb[j + 4];
+          }
+ }
+ }
+ } else if constexpr (M == 16 && N == 8 && K == 32) {
+ if constexpr (std::is_integral_v<ABType>) {
+ // Init D matrix with fragments of C matrix
+ *d[0] = c[0];
+ *d[1] = c[1];
+ *d[2] = c[2];
+ *d[3] = c[3];
+
+ // Each sub-group is responsible for computing a fragment size of 16*8
+ // elements of matrix D.
+ // Each work item computes 4 elements of matrix D by gathering
+ // their corresponding row & col matrix fragments of length k (32)
+ // from A & B matrices respectively using below mapping logic:
+ // row0 = (lane >> 2) & row1 = (lane >> 2) + 8
+        // col0 = ((lane % 4) * 4) + (i & 0x3) & col1 = ((lane % 4) * 4) + (i & 0x3)
+        // As each row & col fragment of A & B matrices is distributed across
+        // 4 work items, each iteration of below loop loads a partial fragment
+        // of matrix A (row) and matrix B (col) using the row & col offsets.
+ for (int i = 0; i < 4; i++) {
+ typename MMAType<ABType>::PackType recv_a[2], recv_b[2];
+
+ // Load partial fragment from row0 of matrix A ({a0, a1, a2, a3})
+ recv_a[0] =
+ dpct::select_from_sub_group(sg, a[0], row_load_offset + i);
+ // Load partial fragment from row1 of matrix A ({a4, a5, a6, a7})
+ recv_a[1] =
+ dpct::select_from_sub_group(sg, a[1], row_load_offset + i);
+ // Load partial fragment from col0 of matrix B ({b0, b1, b2, b3})
+ recv_b[0] =
+ dpct::select_from_sub_group(sg, b[0], col_load_offset + i);
+ // Load partial fragment from col1 of matrix B ({b0, b1, b2, b3})
+ recv_b[1] =
+ dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 4);
+
+ auto a = reinterpret_cast<ABType*>(recv_a);
+ auto b = reinterpret_cast<ABType*>(recv_b);
+
+          // Each work item calculates a partial product of A & B matrix
+          // fragments and adds it to the corresponding D matrix fragment.
+          // d0 += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 }
+          // d1 += row0{ a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 }
+          // d2 += row1{ a4, a5, a6, a7 } * col0{ b0, b1, b2, b3 }
+          // d3 += row1{ a4, a5, a6, a7 } * col1{ b0, b1, b2, b3 }
+ for (int j = 0; j < 4; j++) {
+ *d[0] += a[j] * b[j];
+ *d[1] += a[j] * b[j + 4];
+ *d[2] += a[j + 4] * b[j];
+ *d[3] += a[j + 4] * b[j + 4];
+ }
+ }
+
+ for (int i = 0; i < 4; i++) {
+ typename MMAType<ABType>::PackType recv_a[2], recv_b[2];
+
+ // Load partial fragment from row0 of matrix A ({a8, a9, a10, a11})
+ recv_a[0] =
+ dpct::select_from_sub_group(sg, a[2], row_load_offset + i);
+ // Load partial fragment from row1 of matrix A ({a12, a13, a14,
+ // a15})
+ recv_a[1] =
+ dpct::select_from_sub_group(sg, a[3], row_load_offset + i);
+ // Load partial fragment from col0 of matrix B ({b4, b5, b6, b7})
+ recv_b[0] =
+ dpct::select_from_sub_group(sg, b[1], col_load_offset + i);
+ // Load partial fragment from col1 of matrix B ({b4, b5, b6, b7})
+ recv_b[1] =
+ dpct::select_from_sub_group(sg, b[1], col_load_offset + i + 4);
+
+ auto a = reinterpret_cast<ABType*>(recv_a);
+ auto b = reinterpret_cast<ABType*>(recv_b);
+
+          // Each work item calculates a partial product of A & B matrix
+          // fragments and adds it to the corresponding D matrix fragment.
+          // d0 += row0{ a8, a9, a10, a11 } * col0{ b4, b5, b6, b7 }
+          // d1 += row0{ a8, a9, a10, a11 } * col1{ b4, b5, b6, b7 }
+          // d2 += row1{ a12, a13, a14, a15 } * col0{ b4, b5, b6, b7 }
+          // d3 += row1{ a12, a13, a14, a15 } * col1{ b4, b5, b6, b7 }
+ for (int j = 0; j < 4; j++) {
+ *d[0] += a[j] * b[j];
+ *d[1] += a[j] * b[j + 4];
+ *d[2] += a[j + 4] * b[j];
+ *d[3] += a[j + 4] * b[j + 4];
+ }
+ }
+ }
+ }
+ }
} // COPY from DPCT header files
#endif // GGML_SYCL_DPCT_HELPER_HPP
--- /dev/null
+#pragma once
+
+#include <sycl/sycl.hpp>
+#include "dpct/helper.hpp"
+#include "common.hpp"
+#include "convert.hpp"
+#include "vecdotq.hpp"
+
+#include "ggml.h"
+
+#include <cstdint>
+#include <cmath>
+#include <float.h>
+
+
+#define FATTN_KQ_STRIDE 256
+#define HALF_MAX_HALF sycl::half(65504.0f/2) // Use neg. of this instead of -INFINITY to initialize KQ max vals to avoid NaN upon subtraction.
+#define SOFTMAX_FTZ_THRESHOLD -20.0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.
+#define FATTN_KQ_MAX_OFFSET (3.0f*0.6931f)
+
+typedef void (*fattn_kernel_t)(
+ const char* Q,
+ const char* K,
+ const char* V,
+ const char* mask,
+ const char* sinks,
+ const int* KV_max,
+ float* dst,
+ sycl::float2* dst_meta,
+ const float scale,
+ const float max_bias,
+ const float m0,
+ const float m1,
+ const uint32_t n_head_log2,
+ const float logit_softcap,
+ const int32_t ne00,
+ const sycl::uint3 ne01,
+ const int32_t ne02,
+ const int32_t ne03,
+ const int32_t nb01,
+ const int32_t nb02,
+ const int32_t nb03,
+ const int32_t ne10,
+ const int32_t ne11,
+ const int32_t ne12,
+ const int32_t ne13,
+ const int32_t nb11,
+ const int32_t nb12,
+ const int64_t nb13,
+ const int32_t nb21,
+ const int32_t nb22,
+ const int64_t nb23,
+ const int32_t ne31,
+ const int32_t ne32,
+ const int32_t ne33,
+ const int32_t nb31,
+ const int32_t nb32,
+ const int64_t nb33);
+
+typedef float (*vec_dot_KQ_t)(
+ const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);
+
+template <int D, int nthreads>
+static __dpct_inline__ float vec_dot_fattn_vec_KQ_f16(const char * __restrict__ K_c,
+ const void * __restrict__ Q_v,
+ const int * __restrict__ Q_q8,
+ const void * __restrict__ Q_ds_v) {
+ const sycl::half2 * K_h2 = (const sycl::half2 *) K_c;
+ GGML_UNUSED(Q_q8);
+ GGML_UNUSED(Q_ds_v);
+
+ constexpr int cpy_nb = ggml_sycl_get_max_cpy_bytes();
+ constexpr int cpy_ne = cpy_nb / 4;
+
+ float sum = 0.0f;
+
+#pragma unroll
+ for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += nthreads*cpy_ne) {
+ sycl::half2 tmp[cpy_ne];
+ ggml_sycl_memcpy_1<sizeof(tmp)>(
+ tmp,
+ K_h2 + k_KQ_0 + (sycl::ext::oneapi::this_work_item::get_nd_item<3>().get_local_id(2) % nthreads) * cpy_ne);
+#pragma unroll
+ for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) {
+#ifdef GGML_SYCL_F16
+ ggml_sycl_mad(sum, tmp[k_KQ_1] , ((const sycl::half2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);
+#else
+ ggml_sycl_mad(sum, __half22float2(tmp[k_KQ_1]), ((const sycl::float2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);
+#endif // GGML_SYCL_F16
+ }
+ }
+
+ return sum;
+}
+
+template <int D, int nthreads, int warp_size>
+static __dpct_inline__ float vec_dot_fattn_vec_KQ_q4_0(const char * __restrict__ K_c,
+ const void * __restrict__ Q_v,
+ const int * __restrict__ Q_q8,
+ const void * __restrict__ Q_ds_v) {
+ auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+
+ const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c;
+ GGML_UNUSED(Q_v);
+
+ float sum = 0.0f;
+
+#pragma unroll
+ for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
+ const int k_KQ =
+ k_KQ_0 + (nthreads == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads);
+
+ const int ib = k_KQ / QI8_1;
+ const int iqs4 = k_KQ % QI4_0;
+ const int shift = k_KQ & (QI8_1/2);
+
+ int v;
+ ggml_sycl_memcpy_1<sizeof(int), 2>(&v, K_q4_0[ib].qs + sizeof(int)*iqs4);
+ v = (v >> shift) & 0x0F0F0F0F;
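+        // (shift is either 0 or 4 here, so this extracts the low or high
+        // nibble of each packed byte, 4 bits per quant value)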
+ const int u = Q_q8[k_KQ_0/nthreads];
+
+ const int sumi = ggml_sycl_dp4a(v, u, 0);
+
+ const sycl::float2 Q_ds = ((const sycl::float2 *) Q_ds_v)[k_KQ_0 / nthreads];
+ sum += __half2float(K_q4_0[ib].d) * (sumi*Q_ds.x() - (8/QI8_1)*Q_ds.y());
+ }
+
+ return sum;
+}
+
+template <int D, int nthreads, int warp_size>
+static __dpct_inline__ float vec_dot_fattn_vec_KQ_q4_1(const char * __restrict__ K_c,
+ const void * __restrict__ Q_v,
+ const int * __restrict__ Q_q8,
+ const void * __restrict__ Q_ds_v) {
+ auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+ const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c;
+ GGML_UNUSED(Q_v);
+
+ float sum = 0.0f;
+
+#pragma unroll
+ for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
+ const int k_KQ =
+ k_KQ_0 + (nthreads == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads);
+
+ const int ib = k_KQ / QI8_1;
+ const int iqs4 = k_KQ % QI4_1;
+ const int shift = k_KQ & (QI8_1/2);
+
+ int v;
+ ggml_sycl_memcpy_1<sizeof(int)>(&v, K_q4_1[ib].qs + sizeof(int)*iqs4);
+ v = (v >> shift) & 0x0F0F0F0F;
+ const int u = Q_q8[k_KQ_0/nthreads];
+
+ const int sumi = ggml_sycl_dp4a(v, u, 0);
+
+ const sycl::float2 K_dm = (K_q4_1[ib].dm).template convert<float, sycl::rounding_mode::automatic>();
+ const sycl::float2 Q_ds = ((const sycl::float2 *) Q_ds_v)[k_KQ_0 / nthreads];
+
+ sum += K_dm.x()*Q_ds.x()*sumi + K_dm.y()*Q_ds.y()/QI8_1;
+ }
+
+ return sum;
+}
+
+template <int D, int nthreads, int warp_size>
+static __dpct_inline__ float vec_dot_fattn_vec_KQ_q5_0(const char * __restrict__ K_c,
+ const void * __restrict__ Q_v,
+ const int * __restrict__ Q_q8,
+ const void * __restrict__ Q_ds_v) {
+ auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+ const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c;
+ GGML_UNUSED(Q_v);
+
+ float sum = 0.0f;
+
+#pragma unroll
+ for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
+ const int k_KQ =
+ k_KQ_0 + (nthreads == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads);
+
+ const int ib = k_KQ / QI8_1;
+ const int iqs4 = k_KQ % QI5_0;
+ const int iqs8 = k_KQ % QI8_1;
+ const int shift = k_KQ & (QI8_1/2);
+
+ int v;
+ ggml_sycl_memcpy_1<sizeof(int), 2>(&v, K_q5_0[ib].qs + sizeof(int)*iqs4);
+ v = (v >> shift) & 0x0F0F0F0F;
+
+ {
+ int vh;
+ ggml_sycl_memcpy_1<sizeof(int), 2>(&vh, K_q5_0[ib].qh);
+ vh >>= iqs8 * QI5_0;
+
+ v |= (vh << 4) & 0x00000010; // 0 -> 4
+ v |= (vh << 11) & 0x00001000; // 1 -> 12
+ v |= (vh << 18) & 0x00100000; // 2 -> 20
+ v |= (vh << 25) & 0x10000000; // 3 -> 28
+ }
+
+ const int u = Q_q8[k_KQ_0/nthreads];
+
+ const int sumi = ggml_sycl_dp4a(v, u, 0);
+
+ const sycl::float2 Q_ds = ((const sycl::float2 *) Q_ds_v)[k_KQ_0 / nthreads];
+
+ sum += __half2float(K_q5_0[ib].d) * (sumi*Q_ds.x() - (16/QI8_1)*Q_ds.y());
+ }
+
+ return sum;
+}
+
+template <int D, int nthreads, int warp_size>
+static __dpct_inline__ float vec_dot_fattn_vec_KQ_q5_1(const char * __restrict__ K_c,
+ const void * __restrict__ Q_v,
+ const int * __restrict__ Q_q8,
+ const void * __restrict__ Q_ds_v) {
+ auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+ const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c;
+ GGML_UNUSED(Q_v);
+
+ float sum = 0.0f;
+
+#pragma unroll
+ for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
+ const int k_KQ =
+ k_KQ_0 + (nthreads == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads);
+
+ const int ib = k_KQ / QI8_1;
+ const int iqs4 = k_KQ % QI5_1;
+ const int iqs8 = k_KQ % QI8_1;
+ const int shift = k_KQ & (QI8_1/2);
+
+ int v;
+ ggml_sycl_memcpy_1<sizeof(int)>(&v, K_q5_1[ib].qs + sizeof(int)*iqs4);
+ v = (v >> shift) & 0x0F0F0F0F;
+
+ {
+ int vh;
+ ggml_sycl_memcpy_1<sizeof(int)>(&vh, K_q5_1[ib].qh);
+ vh >>= iqs8 * QI5_0;
+
+ v |= (vh << 4) & 0x00000010; // 0 -> 4
+ v |= (vh << 11) & 0x00001000; // 1 -> 12
+ v |= (vh << 18) & 0x00100000; // 2 -> 20
+ v |= (vh << 25) & 0x10000000; // 3 -> 28
+ }
+
+ const int u = Q_q8[k_KQ_0/nthreads];
+
+ const int sumi = ggml_sycl_dp4a(v, u, 0);
+
+ const sycl::float2 K_dm = (K_q5_1[ib].dm).template convert<float, sycl::rounding_mode::automatic>();
+ const sycl::float2 Q_ds = ((const sycl::float2 *) Q_ds_v)[k_KQ_0 / nthreads];
+
+ sum += K_dm.x()*Q_ds.x()*sumi + K_dm.y()*Q_ds.y()/QI8_1;
+ }
+
+ return sum;
+}
+
+template <int D, int nthreads, int warp_size>
+static __dpct_inline__ float vec_dot_fattn_vec_KQ_q8_0(const char * __restrict__ K_c,
+ const void * __restrict__ Q_v,
+ const int * __restrict__ Q_q8,
+ const void * __restrict__ Q_ds_v) {
+ auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+ const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c;
+ GGML_UNUSED(Q_v);
+
+ float sum = 0.0f;
+
+#pragma unroll
+ for (int k_KQ_0 = 0; k_KQ_0 < int(D/sizeof(int)); k_KQ_0 += nthreads) {
+ const int k_KQ =
+ k_KQ_0 + (nthreads == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads);
+
+ const int ib = k_KQ / QI8_0;
+ const int iqs = k_KQ % QI8_0;
+
+ int v;
+ ggml_sycl_memcpy_1<sizeof(v), 2>(&v, K_q8_0[ib].qs + 4*iqs);
+
+ const sycl::float2 * Q_ds = (const sycl::float2 *) Q_ds_v;
+ const float Q_d = Q_ds[k_KQ_0 / nthreads].x();
+
+ sum += vec_dot_q8_0_q8_1_impl<float, 1>(&v, &Q_q8[k_KQ_0/nthreads], K_q8_0[ib].d, Q_d);
+ }
+
+ return sum;
+}
+
+template <typename Tds, int ni, int warp_size>
+static __dpct_inline__ void quantize_q8_1_to_shared(const float * __restrict__ x,
+ const float scale,
+ int * __restrict__ yq32,
+ void * __restrict__ yds) {
+ auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+
+ float vals[sizeof(int)] = { 0.0f };
+#pragma unroll
+ for (int l = 0; l < int(sizeof(int)); ++l) {
+ vals[l] =
+ (ni == warp_size || item_ct1.get_local_id(2) < ni) ? scale * x[4 * item_ct1.get_local_id(2) + l] : 0.0f;
+ }
+
+ float amax = sycl::fabs(vals[0]);
+ float sum = vals[0];
+#pragma unroll
+ for (int l = 1; l < int(sizeof(int)); ++l) {
+ amax = sycl::fmax(amax, sycl::fabs(vals[l]));
+ sum += vals[l];
+ }
+#pragma unroll
+ for (int mask = QI8_1/2; mask > 0; mask >>= 1) {
+ amax = sycl::fmax(
+ amax, dpct::permute_sub_group_by_xor(sycl::ext::oneapi::this_work_item::get_sub_group(), amax, mask));
+ sum += dpct::permute_sub_group_by_xor(sycl::ext::oneapi::this_work_item::get_sub_group(), sum, mask);
+ }
+
+ const float d = amax / 127;
+ int q32 = 0;
+ int8_t * q8 = (int8_t *) &q32;
+
+ if (d != 0.0f) {
+#pragma unroll
+ for (int l = 0; l < int(sizeof(int)); ++l) {
+ q8[l] = sycl::round(vals[l] / d);
+ }
+ }
+
+ yq32[item_ct1.get_local_id(2)] = q32;
+ if (item_ct1.get_local_id(2) % QI8_1 == 0 && (ni == warp_size || item_ct1.get_local_id(2) < ni)) {
+ if (std::is_same<Tds, sycl::half2>::value) {
+ ((sycl::half2 *) yds)[item_ct1.get_local_id(2)/QI8_1] = make_half2(d, sum);
+ } else {
+ ((sycl::float2 *) yds)[item_ct1.get_local_id(2)/QI8_1] = make_float2(d, sum);
+ }
+ }
+}
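+// Illustrative: if the reduced amax across the sub-group is 12.7f, the scale
+// is d = 0.1f and each value is stored as round(val / d); d and the running
+// sum are kept per 32-value block for the q8_1 dot-product correction terms.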
+
+typedef void (*dequantize_V_t)(const void *, void *, const int64_t);
+
+template <typename T, int ne>
+static __dpct_inline__ void dequantize_V_f16(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
+ if constexpr (std::is_same_v<T, sycl::half>) {
+ ggml_sycl_memcpy_1<ne * sizeof(sycl::half)>(dst, (const sycl::half *) vx + i0);
+ } else if constexpr (std::is_same_v<T, float>) {
+ static_assert(ne % 2 == 0, "bad ne");
+ sycl::half2 tmp[ne / 2];
+ ggml_sycl_memcpy_1<ne * sizeof(sycl::half)>(tmp, (const sycl::half *) vx + i0);
+ sycl::float2 * dst_f2 = (sycl::float2 *) dst;
+#pragma unroll
+ for (int l = 0; l < ne/2; ++l) {
+ dst_f2[l] = tmp[l].template convert<float, sycl::rounding_mode::automatic>();
+ }
+ } else {
+ static_assert(std::is_same_v<T, void>, "unsupported type");
+ }
+}
+
+template <typename T, int ne>
+static __dpct_inline__ void dequantize_V_q4_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
+ const block_q4_0 * x = (const block_q4_0 *) vx;
+
+ const int64_t ib = i0 / QK4_0;
+ const int iqs = i0 % (QK4_0/2);
+ const int shift = (i0 % QK4_0) / (QK4_0/2);
+
+ int q;
+ static_assert(ne == 2 || ne == 4, "bad ne");
+ ggml_sycl_memcpy_1<ne, 2>(&q, x[ib].qs + iqs);
+ q >>= 4*shift;
+ q &= 0x0F0F0F0F;
+ q = dpct::vectorized_binary<sycl::char4>(q, 0x08080808, dpct::sub_sat());
+
+ const int8_t * q8 = (const int8_t *) &q;
+
+#ifdef GGML_SYCL_F16
+ if constexpr (std::is_same_v<T, sycl::half>) {
+ const sycl::half2 d = sycl::half2(x[ib].d);
+
+#pragma unroll
+ for (int l0 = 0; l0 < ne; l0 += 2) {
+ ((sycl::half2 *) dst)[l0 / 2] = d * sycl::half2(q8[l0 + 0], q8[l0 + 1]);
+ }
+ } else
+#endif // GGML_SYCL_F16
+ if constexpr (std::is_same_v<T, float>) {
+ const float d = x[ib].d;
+
+#pragma unroll
+ for (int l = 0; l < ne; ++l) {
+ ((float *) dst)[l] = d * q8[l];
+ }
+ } else {
+ static_assert(std::is_same_v<T, void>, "bad type");
+ }
+}
+
+template <typename T, int ne>
+static __dpct_inline__ void dequantize_V_q4_1(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
+ const block_q4_1 * x = (const block_q4_1 *) vx;
+
+ const int64_t ib = i0 / QK4_1;
+ const int iqs = i0 % (QK4_1/2);
+ const int shift = (i0 % QK4_1) / (QK4_1/2);
+
+ int q;
+ static_assert(ne == 2 || ne == 4, "bad ne");
+ ggml_sycl_memcpy_1<ne>(&q, x[ib].qs + iqs);
+ q >>= 4*shift;
+ q &= 0x0F0F0F0F;
+
+ const int8_t * q8 = (const int8_t *) &q;
+
+#ifdef GGML_SYCL_F16
+ if constexpr (std::is_same_v<T, sycl::half>) {
+ const sycl::half2 dm = x[ib].dm;
+ const sycl::half2 d = sycl::half2(dm[0]);
+ const sycl::half2 m = sycl::half2(dm[1]);
+
+#pragma unroll
+ for (int l0 = 0; l0 < ne; l0 += 2) {
+ ((sycl::half2 *) dst)[l0 / 2] = d * sycl::half2(q8[l0 + 0], q8[l0 + 1]) + m;
+ }
+ } else
+#endif // GGML_SYCL_F16
+ if constexpr (std::is_same_v<T, float>) {
+ const sycl::float2 dm = (x[ib].dm).template convert<float, sycl::rounding_mode::automatic>();
+
+#pragma unroll
+ for (int l = 0; l < ne; ++l) {
+ ((float *) dst)[l] = dm.x() * q8[l] + dm.y();
+ }
+ } else {
+ static_assert(std::is_same_v<T, void>, "bad type");
+ }
+}
+
+template <typename T, int ne>
+static __dpct_inline__ void dequantize_V_q5_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
+ const block_q5_0 * x = (const block_q5_0 *) vx;
+
+ const int64_t ib = i0 / QK5_0;
+ const int idq = i0 % QK5_0;
+ const int iqs = i0 % (QK5_0/2);
+ const int shift = (i0 % QK5_0) / (QK5_0/2);
+
+ int q;
+ static_assert(ne == 2 || ne == 4, "bad ne");
+ ggml_sycl_memcpy_1<ne, 2>(&q, x[ib].qs + iqs);
+ q >>= 4*shift;
+ q &= 0x0F0F0F0F;
+
+ {
+ int qh;
+ ggml_sycl_memcpy_1<ne, 2>(&qh, x[ib].qh);
+#pragma unroll
+ for (int l = 0; l < ne; ++l) {
+ q |= ((qh >> (idq + l)) & 0x00000001) << (8*l + 4);
+ }
+ }
+
+ q = dpct::vectorized_binary<sycl::char4>(q, 0x10101010, dpct::sub_sat());
+
+ const int8_t * q8 = (const int8_t *) &q;
+
+#ifdef GGML_SYCL_F16
+ if constexpr (std::is_same_v<T, sycl::half>) {
+ const sycl::half2 d = sycl::half2(x[ib].d);
+
+#pragma unroll
+ for (int l0 = 0; l0 < ne; l0 += 2) {
+ ((sycl::half2 *) dst)[l0 / 2] = d * sycl::half2(q8[l0 + 0], q8[l0 + 1]);
+ }
+ } else
+#endif // GGML_SYCL_F16
+ if constexpr (std::is_same_v<T, float>) {
+ const float d = x[ib].d;
+
+#pragma unroll
+ for (int l = 0; l < ne; ++l) {
+ ((float *) dst)[l] = d * q8[l];
+ }
+ } else {
+ static_assert(std::is_same_v<T, void>, "bad type");
+ }
+}
+
+template <typename T, int ne>
+static __dpct_inline__ void dequantize_V_q5_1(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
+ const block_q5_1 * x = (const block_q5_1 *) vx;
+
+ const int64_t ib = i0 / QK5_1;
+ const int idq = i0 % QK5_1;
+ const int iqs = i0 % (QK5_1/2);
+ const int shift = (i0 % QK5_1) / (QK5_1/2);
+
+ int q;
+ static_assert(ne == 2 || ne == 4, "bad ne");
+ ggml_sycl_memcpy_1<ne>(&q, x[ib].qs + iqs);
+ q >>= 4*shift;
+ q &= 0x0F0F0F0F;
+
+ {
+ int qh;
+ ggml_sycl_memcpy_1<ne>(&qh, x[ib].qh);
+#pragma unroll
+ for (int l = 0; l < ne; ++l) {
+ q |= ((qh >> (idq + l)) & 0x00000001) << (8*l + 4);
+ }
+ }
+
+ const int8_t * q8 = (const int8_t *) &q;
+
+#ifdef GGML_SYCL_F16
+ if constexpr (std::is_same_v<T, sycl::half>) {
+ const sycl::half2 dm = x[ib].dm;
+ const sycl::half2 d = sycl::half2(dm[0]);
+ const sycl::half2 m = sycl::half2(dm[1]);
+
+#pragma unroll
+ for (int l0 = 0; l0 < ne; l0 += 2) {
+ ((sycl::half2 *) dst)[l0 / 2] = d * sycl::half2(q8[l0 + 0], q8[l0 + 1]) + m;
+ }
+ } else
+#endif // GGML_SYCL_F16
+ if constexpr (std::is_same_v<T, float>) {
+ const sycl::float2 dm = (x[ib].dm).template convert<float, sycl::rounding_mode::automatic>();
+
+#pragma unroll
+ for (int l = 0; l < ne; ++l) {
+ ((float *) dst)[l] = dm.x() * q8[l] + dm.y();
+ }
+ } else {
+ static_assert(std::is_same_v<T, void>, "bad type");
+ }
+}
+
+template <typename T, int ne>
+static __dpct_inline__ void dequantize_V_q8_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) {
+ const block_q8_0 * x = (const block_q8_0 *) vx;
+
+ const int64_t ib = i0 / QK8_0;
+ const int iqs = i0 % QK8_0;
+
+ static_assert(ne % 2 == 0, "bad ne");
+ int8_t qs[ne];
+ ggml_sycl_memcpy_1<ne, 2>(qs, x[ib].qs + iqs);
+
+#ifdef GGML_SYCL_F16
+ if constexpr (std::is_same<T, sycl::half>::value) {
+ const sycl::half2 d = sycl::half2(x[ib].d);
+
+#pragma unroll
+ for (int l0 = 0; l0 < ne; l0 += 2) {
+ ((sycl::half2 *) dst)[l0 / 2] = d * make_half2(qs[l0 + 0], qs[l0 + 1]);
+ }
+ } else
+#endif // GGML_SYCL_F16
+ if constexpr (std::is_same<T, float>::value) {
+ const float d = x[ib].d;
+
+#pragma unroll
+ for (int l = 0; l < ne; ++l) {
+ ((float *) dst)[l] = d * qs[l];
+ }
+ } else {
+ static_assert(std::is_same_v<T, void>, "unsupported type");
+ }
+}
+
+template <int type_K, int D, int nthreads, int warp_size>
+constexpr vec_dot_KQ_t get_vec_dot_KQ() {
+ if constexpr (type_K == GGML_TYPE_F16) {
+ return vec_dot_fattn_vec_KQ_f16<D, nthreads>;
+ } else if constexpr (type_K == GGML_TYPE_Q4_0) {
+ return vec_dot_fattn_vec_KQ_q4_0<D, nthreads, warp_size>;
+ } else if constexpr (type_K == GGML_TYPE_Q4_1) {
+ return vec_dot_fattn_vec_KQ_q4_1<D, nthreads, warp_size>;
+ } else if constexpr (type_K == GGML_TYPE_Q5_0) {
+ return vec_dot_fattn_vec_KQ_q5_0<D, nthreads, warp_size>;
+ } else if constexpr (type_K == GGML_TYPE_Q5_1) {
+ return vec_dot_fattn_vec_KQ_q5_1<D, nthreads, warp_size>;
+ } else if constexpr (type_K == GGML_TYPE_Q8_0) {
+ return vec_dot_fattn_vec_KQ_q8_0<D, nthreads, warp_size>;
+ } else {
+ static_assert(type_K == -1, "bad type");
+ return nullptr;
+ }
+}
+
+template <int type_V, typename T, int ne>
+constexpr dequantize_V_t get_dequantize_V() {
+ if constexpr (type_V == GGML_TYPE_F16) {
+ return dequantize_V_f16<T, ne>;
+ } else if constexpr (type_V == GGML_TYPE_Q4_0) {
+ return dequantize_V_q4_0<T, ne>;
+ } else if constexpr (type_V == GGML_TYPE_Q4_1) {
+ return dequantize_V_q4_1<T, ne>;
+ } else if constexpr (type_V == GGML_TYPE_Q5_0) {
+ return dequantize_V_q5_0<T, ne>;
+ } else if constexpr (type_V == GGML_TYPE_Q5_1) {
+ return dequantize_V_q5_1<T, ne>;
+ } else if constexpr (type_V == GGML_TYPE_Q8_0) {
+ return dequantize_V_q8_0<T, ne>;
+ } else {
+ static_assert(type_V == -1, "bad type");
+ return nullptr;
+ }
+}
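+
+// Both selectors resolve to a concrete function at compile time, e.g. (with
+// hypothetical buffers V_blocks of block_q8_0 and out of float):
+//   constexpr dequantize_V_t deq = get_dequantize_V<GGML_TYPE_Q8_0, float, 4>();
+//   deq(V_blocks, out, i0); // writes 4 dequantized floats starting at element i0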
+
+template <int ncols1, int warp_size>
+static void flash_attn_mask_to_KV_max(const sycl::half2 * __restrict__ mask,
+ int * __restrict__ KV_max,
+ const int ne30,
+ const int s31,
+ const int s33,
+ int * buf_iw) {
+ auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+ const int ne31 = item_ct1.get_group_range(2);
+ const int tid = item_ct1.get_local_id(2);
+ const int sequence = item_ct1.get_group(1);
+ const int jt = item_ct1.get_group(2);
+
+ mask += sequence*s33 + jt*ncols1*s31;
+
+ if (tid < warp_size) {
+ buf_iw[tid] = 1;
+ }
+ item_ct1.barrier(sycl::access::fence_space::local_space);
+
+ int KV_max_sj = (ne30 - 1) * FATTN_KQ_STRIDE;
+ for (; KV_max_sj >= 0; KV_max_sj -= FATTN_KQ_STRIDE) {
+ int all_inf = 1;
+
+#pragma unroll
+ for (int j = 0; j < ncols1; ++j) {
+ const sycl::float2 tmp =
+ mask[j * s31 + KV_max_sj / 2 + tid].template convert<float, sycl::rounding_mode::automatic>();
+ all_inf = all_inf && int(sycl::isinf((float) (tmp.x()))) && int(sycl::isinf((float) (tmp.y())));
+ }
+
+ all_inf = warp_reduce_all<warp_size>(all_inf);
+ if (tid % warp_size == 0) {
+ buf_iw[tid / warp_size] = all_inf;
+ }
+ item_ct1.barrier(sycl::access::fence_space::local_space);
+ all_inf = buf_iw[tid % warp_size];
+ item_ct1.barrier(sycl::access::fence_space::local_space);
+ all_inf = warp_reduce_all<warp_size>(all_inf);
+
+ if (!all_inf) {
+ break;
+ }
+ }
+
+ // If the break in the loop was not triggered, KV_max_sj is now -FATTN_KQ_STRIDE.
+ // If the break was triggered it's the lower edge of the tile with the first non-masked values.
+ // In either case, undo the final decrement by FATTN_KQ_STRIDE.
+ KV_max_sj += FATTN_KQ_STRIDE;
+
+ if (item_ct1.get_local_id(2) != 0) {
+ return;
+ }
+
+ KV_max[sequence*ne31 + jt] = KV_max_sj;
+}
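+
+// On return, KV_max[sequence*ne31 + jt] is an upper bound (a multiple of
+// FATTN_KQ_STRIDE) on the number of KV tokens with at least one unmasked entry
+// for the Q tile (sequence, jt), so the attention kernels can stop their KV loop
+// at that point.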
+
+template <int D, int ncols1, int ncols2> // D == head size
+static void flash_attn_stream_k_fixup(float * __restrict__ dst,
+ const sycl::float2 * __restrict__ dst_fixup,
+ const int ne01,
+ const int ne02,
+ const int ne03,
+ const int ne11,
+ const int ne12,
+ const int nbatch_fa) {
+ auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+ constexpr int ncols = ncols1 * ncols2;
+
+ const int bidx0 = item_ct1.get_group(2);
+ const int j = item_ct1.get_group(1);
+ const int c = item_ct1.get_group(0);
+ const int jc = j*ncols2 + c;
+ const int tid = item_ct1.get_local_id(2);
+
+ const float * dst_fixup_data = ((const float *) dst_fixup) + item_ct1.get_group_range(2) * (2 * 2 * ncols);
+
+ const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
+
+ const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
+ const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
+ const int iter_z_gqa = (gqa_ratio + (ncols2 - 1)) / ncols2;
+
+ const int kbc0 = int64_t(bidx0 + 0) * (iter_k * iter_j * iter_z_gqa * ne12 * ne03) / item_ct1.get_group_range(2);
+ const int kbc0_stop =
+ int64_t(bidx0 + 1) * (iter_k * iter_j * iter_z_gqa * ne12 * ne03) / item_ct1.get_group_range(2);
+
+ const bool did_not_have_any_data = kbc0 == kbc0_stop;
+ const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
+ const bool did_not_write_last = kbc0/iter_k == kbc0_stop/iter_k && kbc0_stop % iter_k != 0;
+ if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) {
+ return;
+ }
+
+ // z_KV = K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index
+ const int sequence = kbc0 /(iter_k*iter_j*iter_z_gqa*ne12);
+ const int z_KV = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa);
+ const int zt_gqa = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j);
+ const int jt = (kbc0 - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k;
+
+ const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.
+
+ if (jt*ncols1 + j >= ne01 || zt_gqa*ncols2 + c >= gqa_ratio) {
+ return;
+ }
+
+ dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + zt_Q*D + (j*ne02 + c)*D + tid;
+
+ // Load the partial result that needs a fixup:
+ float dst_val = 0.0f;
+ float max_val = 0.0f;
+ float rowsum = 0.0f;
+ {
+ dst_val = *dst;
+
+ const sycl::float2 tmp = dst_fixup[bidx0 * ncols + jc];
+ max_val = tmp.x();
+ rowsum = tmp.y();
+ }
+
+ // Iterate over previous blocks and compute the combined results.
+ // All SYCL blocks that get here must have a previous block that needs a fixup.
+ int bidx = bidx0 - 1;
+ int kbc_stop = kbc0;
+ while(true) {
+ const int kbc = int64_t(bidx) * (iter_k * iter_j * iter_z_gqa * ne12 * ne03) / item_ct1.get_group_range(2);
+ if (kbc == kbc_stop) { // Did not have any data.
+ bidx--;
+ kbc_stop = kbc;
+ continue;
+ }
+
+ const float dst_add = dst_fixup_data[bidx*ncols*D + jc*D + tid];
+
+ const sycl::float2 tmp = dst_fixup[(item_ct1.get_group_range(2) + bidx) * ncols + jc];
+
+ // Scale the current and new value accumulators depending on the max. values.
+ const float max_val_new = sycl::fmax(max_val, tmp.x());
+
+ const float diff_val = max_val - max_val_new;
+ const float diff_add = tmp.x() - max_val_new;
+
+ const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? sycl::native::exp(diff_val) : 0.0f;
+ const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? sycl::native::exp(diff_add) : 0.0f;
+
+ dst_val = scale_val*dst_val + scale_add*dst_add;
+ rowsum = scale_val * rowsum + scale_add * tmp.y();
+
+ max_val = max_val_new;
+
+ // If this block started in a previous tile we are done and don't need to combine additional partial results.
+ if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) {
+ break;
+ }
+ bidx--;
+ kbc_stop = kbc;
+ }
+
+ // Write back final result:
+ *dst = dst_val / rowsum;
+}
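+
+// The merge above is the standard online-softmax combination of partial results
+// with running maxima m1, m2 and row sums s1, s2:
+//   m = max(m1, m2); v = exp(m1-m)*v1 + exp(m2-m)*v2; s = exp(m1-m)*s1 + exp(m2-m)*s2
+// applied iteratively over all preceding blocks, followed by the final
+// normalization v/s.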
+
+template <int D> // D == head size
+static void flash_attn_combine_results(const float * __restrict__ VKQ_parts,
+ const sycl::float2 * __restrict__ VKQ_meta,
+ float * __restrict__ dst,
+ const int parallel_blocks,
+ uint8_t * dpct_local) {
+ // Dimension 0: threadIdx.x
+ // Dimension 1: blockIdx.x
+ // Dimension 2: blockIdx.y
+ // Dimension 3: blockIdx.z
+ // Memory layout is permuted with [0, 2, 1, 3]
+
+ auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+ const int ne01 = item_ct1.get_group_range(2);
+ const int ne02 = item_ct1.get_group_range(1);
+
+ const int col = item_ct1.get_group(2);
+ const int head = item_ct1.get_group(1);
+ const int sequence = item_ct1.get_group(0);
+
+ const int j_dst_unrolled = (sequence*ne01 + col)*ne02 + head;
+
+ VKQ_parts += j_dst_unrolled * parallel_blocks*D;
+ VKQ_meta += j_dst_unrolled * parallel_blocks;
+ dst += j_dst_unrolled * D;
+
+ const int tid = item_ct1.get_local_id(2);
+ __builtin_assume(tid < D);
+
+ auto meta = (sycl::float2 *) dpct_local;
+ for (int i = tid; i < 2*parallel_blocks; i += D) {
+ ((float *) meta)[i] = ((const float *) VKQ_meta)[i];
+ }
+
+ item_ct1.barrier(sycl::access::fence_space::local_space);
+
+ float kqmax = meta[0].x();
+ for (int l = 1; l < parallel_blocks; ++l) {
+ kqmax = sycl::max(kqmax, meta[l].x());
+ }
+
+ float VKQ_numerator = 0.0f;
+ float VKQ_denominator = 0.0f;
+ for (int l = 0; l < parallel_blocks; ++l) {
+ const float KQ_max_scale = sycl::native::exp(meta[l].x() - kqmax);
+
+ VKQ_numerator += KQ_max_scale * VKQ_parts[l*D + tid];
+ VKQ_denominator += KQ_max_scale * meta[l].y();
+ }
+
+ dst[tid] = VKQ_numerator / VKQ_denominator;
+}
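+
+// In other words, with per-block maxima m_l and row sums s_l in VKQ_meta the
+// final output is dst = sum_l(exp(m_l - m) * VKQ_l) / sum_l(exp(m_l - m) * s_l)
+// with m = max_l(m_l), which matches a single-pass softmax over the full row.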
+
+template <fattn_kernel_t fattn_kernel, int warp_size>
+static void launch_kernel(
+ dpct::dim3 group_range,
+ dpct::dim3 local_range,
+ queue_ptr q,
+ unsigned int local_mem_size,
+ const char* __restrict__ Q,
+ const char* __restrict__ K,
+ const char* __restrict__ V,
+ const char* __restrict__ mask,
+ const char* __restrict__ sinks,
+ const int* __restrict__ KV_max,
+ float* __restrict__ dst,
+ sycl::float2* __restrict__ dst_meta,
+ const float scale,
+ const float max_bias,
+ const float m0,
+ const float m1,
+ const uint32_t n_head_log2,
+ const float logit_softcap,
+ const int32_t ne00,
+ const sycl::uint3 ne01,
+ const int32_t ne02,
+ const int32_t ne03,
+ const int32_t nb01,
+ const int32_t nb02,
+ const int32_t nb03,
+ const int32_t ne10,
+ const int32_t ne11,
+ const int32_t ne12,
+ const int32_t ne13,
+ const int32_t nb11,
+ const int32_t nb12,
+ const int64_t nb13,
+ const int32_t nb21,
+ const int32_t nb22,
+ const int64_t nb23,
+ const int32_t ne31,
+ const int32_t ne32,
+ const int32_t ne33,
+ const int32_t nb31,
+ const int32_t nb32,
+ const int64_t nb33) {
+ GGML_UNUSED(local_mem_size);
+ q->submit([&](sycl::handler &cgh) {
+ cgh.parallel_for(
+ sycl::nd_range<3>(
+ static_cast<sycl::range<3>>(group_range * local_range),
+ static_cast<sycl::range<3>>(local_range)),
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(warp_size)]] {
+ GGML_UNUSED(item_ct1);
+ fattn_kernel(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
+ max_bias, m0, m1, n_head_log2, logit_softcap, ne00,
+ ne01, ne02, ne03, nb01, nb02, nb03, ne10, ne11,
+ ne12, ne13, nb11, nb12, nb13, nb21, nb22, nb23,
+ ne31, ne32, ne33, nb31, nb32, nb33);
+ });
+ });
+}
+
+template <int DV, int ncols1, int ncols2, fattn_kernel_t fattn_kernel, int warp_size>
+void launch_fattn(
+ ggml_backend_sycl_context & ctx, ggml_tensor * dst, const int nwarps, const size_t nbytes_shared,
+ const int nbatch_fa, const bool need_f16_K, const bool need_f16_V, const bool stream_k) {
+
+ constexpr int ncols = ncols1 * ncols2;
+
+ const ggml_tensor * Q = dst->src[0];
+ const ggml_tensor * K = dst->src[1];
+ const ggml_tensor * V = dst->src[2];
+
+ const bool V_is_K_view = V->view_src && (V->view_src == K || (V->view_src == K->view_src && V->view_offs == K->view_offs));
+
+ const ggml_tensor * mask = dst->src[3];
+ const ggml_tensor * sinks = dst->src[4];
+
+ ggml_tensor * KQV = dst;
+
+ GGML_ASSERT(Q->type == GGML_TYPE_F32);
+ GGML_ASSERT(KQV->type == GGML_TYPE_F32);
+
+ GGML_ASSERT(Q->nb[0] == ggml_element_size(Q));
+ GGML_ASSERT(K->nb[0] == ggml_element_size(K));
+ GGML_ASSERT(V->nb[0] == ggml_element_size(V));
+
+ GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
+
+ ggml_sycl_pool & pool = ctx.pool();
+ dpct::queue_ptr main_stream = ctx.stream();
+ const int id = ggml_sycl_get_device();
+ const int nsm = ggml_sycl_info().devices[id].nsm;
+
+ ggml_sycl_pool_alloc<sycl::half> K_f16(pool);
+ ggml_sycl_pool_alloc<sycl::half> V_f16(pool);
+ ggml_sycl_pool_alloc<int> KV_max(pool);
+ ggml_sycl_pool_alloc<float> dst_tmp(pool);
+ ggml_sycl_pool_alloc<sycl::float2> dst_tmp_meta(pool);
+
+ const char * K_data = (const char *) K->data;
+ size_t nb11 = K->nb[1];
+ size_t nb12 = K->nb[2];
+ size_t nb13 = K->nb[3];
+
+ const char * V_data = (const char *) V->data;
+ size_t nb21 = V->nb[1];
+ size_t nb22 = V->nb[2];
+ size_t nb23 = V->nb[3];
+
+ if (need_f16_K && K->type != GGML_TYPE_F16) {
+ const size_t bs = ggml_blck_size(K->type);
+ const size_t ts = ggml_type_size(K->type);
+
+ K_f16.alloc(ggml_nelements(K));
+ if (ggml_is_contiguously_allocated(K)) {
+ to_fp16_sycl_t to_fp16 = ggml_get_to_fp16_sycl(K->type, dst);
+ to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream);
+
+ nb11 = nb11 * bs * sizeof(sycl::half) / ts;
+ nb12 = nb12 * bs * sizeof(sycl::half) / ts;
+ nb13 = nb13 * bs * sizeof(sycl::half) / ts;
+ } else {
+ GGML_ASSERT(K->nb[0] == ts);
+ to_fp16_nc_sycl_t to_fp16 = ggml_get_to_fp16_nc_sycl(K->type);
+ const int64_t s01 = nb11 / ts;
+ const int64_t s02 = nb12 / ts;
+ const int64_t s03 = nb13 / ts;
+ to_fp16(K_data, K_f16.ptr, K->ne[0], K->ne[1], K->ne[2], K->ne[3], s01, s02, s03, main_stream);
+
+ nb11 = K->ne[0] * sizeof(sycl::half);
+ nb12 = K->ne[1] * nb11;
+ nb13 = K->ne[2] * nb12;
+ }
+ K_data = (char *) K_f16.ptr;
+ }
+
+ if (need_f16_V && V->type != GGML_TYPE_F16) {
+ if (V_is_K_view) {
+ V_data = K_data;
+ nb21 = nb11;
+ nb22 = nb12;
+ nb23 = nb13;
+ } else {
+ const size_t bs = ggml_blck_size(V->type);
+ const size_t ts = ggml_type_size(V->type);
+
+ V_f16.alloc(ggml_nelements(V));
+ if (ggml_is_contiguously_allocated(V)) {
+ to_fp16_sycl_t to_fp16 = ggml_get_to_fp16_sycl(V->type, dst);
+ to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
+
+ nb21 = nb21 * bs * sizeof(sycl::half) / ts;
+ nb22 = nb22 * bs * sizeof(sycl::half) / ts;
+ nb23 = nb23 * bs * sizeof(sycl::half) / ts;
+ } else {
+ GGML_ASSERT(V->nb[0] == ts);
+ to_fp16_nc_sycl_t to_fp16 = ggml_get_to_fp16_nc_sycl(V->type);
+ const int64_t s01 = nb21 / ts;
+ const int64_t s02 = nb22 / ts;
+ const int64_t s03 = nb23 / ts;
+ to_fp16(V_data, V_f16.ptr, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream);
+
+ nb21 = V->ne[0] * sizeof(sycl::half);
+ nb22 = V->ne[1] * nb21;
+ nb23 = V->ne[2] * nb22;
+ }
+ V_data = (char *) V_f16.ptr;
+ }
+ }
+
+ const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
+ const int gqa_ratio = Q->ne[2] / K->ne[2];
+ const int ntiles_z_gqa = ((gqa_ratio + ncols2 - 1) / ncols2);
+ const int ntiles_total = ntiles_x * ntiles_z_gqa * K->ne[2] * Q->ne[3];
+
+ // Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped.
+ // Only worth the overhead if there is at least one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or
+ // multiple sequences of possibly different lengths.
+ if (mask && K->ne[1] % FATTN_KQ_STRIDE == 0 && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) {
+ const int s31 = mask->nb[1] / sizeof(sycl::half2);
+ const int s33 = mask->nb[3] / sizeof(sycl::half2);
+
+ const dpct::dim3 blocks_num_KV_max(ntiles_x, Q->ne[3], 1);
+ const dpct::dim3 block_dim_KV_max(FATTN_KQ_STRIDE / 2, 1, 1);
+
+ const int ne_KV_max = blocks_num_KV_max.x*blocks_num_KV_max.y;
+ const int iter_k = K->ne[1] / FATTN_KQ_STRIDE;
+
+ KV_max.alloc(ne_KV_max);
+ {
+ dpct::has_capability_or_fail(main_stream->get_device(), { sycl::aspect::fp16 });
+
+ main_stream->submit([&](sycl::handler & cgh) {
+ sycl::local_accessor<int, 1> buf_iw_acc_ct1(sycl::range<1>(warp_size), cgh);
+
+ auto mask_data_ct0 = (const sycl::half2 *) mask->data;
+ auto KV_max_ptr_ct1 = KV_max.ptr;
+
+ cgh.parallel_for(sycl::nd_range<3>(blocks_num_KV_max * block_dim_KV_max, block_dim_KV_max),
+ [=](sycl::nd_item<3> item_ct1) {
+ GGML_UNUSED(item_ct1);
+ flash_attn_mask_to_KV_max<ncols1, warp_size>(
+ mask_data_ct0, KV_max_ptr_ct1, iter_k, s31, s33,
+ buf_iw_acc_ct1.get_multi_ptr<sycl::access::decorated::no>().get());
+ });
+ });
+ }
+ SYCL_CHECK(0);
+ }
+
+ const dpct::dim3 block_dim(warp_size, nwarps, 1);
+
+ // Max. number of active blocks limited by occupancy.
+ int max_blocks_per_sm = ggml_sycl_info().devices[id].max_wg_per_cu;
+ int parallel_blocks = max_blocks_per_sm;
+ dpct::dim3 blocks_num;
+ if (stream_k) {
+ // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
+ const int max_blocks = max_blocks_per_sm*nsm;
+ const int nblocks_stream_k = max_blocks;
+ const bool use_stream_k = true;
+
+ blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total;
+ blocks_num.y = 1;
+ blocks_num.z = 1;
+
+ if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
+ dst_tmp_meta.alloc((size_t(blocks_num.x) * ncols * (2 + DV/2)));
+ }
+ } else {
+ const int ntiles_KQ = (K->ne[1] + nbatch_fa - 1) / nbatch_fa; // Max. number of parallel blocks limited by tensor size.
+
+ // parallel_blocks must not be larger than what the tensor size allows:
+ parallel_blocks = std::min(parallel_blocks, ntiles_KQ);
+ // TODO: fix the hard-coded change:
+ // parallel_blocks = ntiles_KQ;
+
+ // If ntiles_total % blocks_per_wave != 0 then some efficiency is lost due to tail effects.
+ // Test whether parallel_blocks can be set to a higher value for better efficiency.
+ const int blocks_per_wave = nsm * max_blocks_per_sm;
+ int nwaves_best = 0;
+ int efficiency_percent_best = 0;
+ for (int parallel_blocks_test = parallel_blocks; parallel_blocks_test <= ntiles_KQ; ++parallel_blocks_test) {
+ const int nblocks_total = ntiles_total * parallel_blocks_test;
+ const int nwaves = (nblocks_total + blocks_per_wave - 1) / blocks_per_wave;
+ const int efficiency_percent = 100 * nblocks_total / (nwaves*blocks_per_wave);
+
+ // Stop trying configurations with more waves if we already have good efficiency to avoid excessive overhead.
+ if (efficiency_percent_best >= 95 && nwaves > nwaves_best) {
+ break;
+ }
+
+ if (efficiency_percent > efficiency_percent_best) {
+ nwaves_best = nwaves;
+ efficiency_percent_best = efficiency_percent;
+ parallel_blocks = parallel_blocks_test;
+ }
+ }
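+
+ // Worked example (hypothetical device): with blocks_per_wave = 64 and
+ // ntiles_total = 24, parallel_blocks = 2 yields 48 blocks in 1 wave (75%
+ // efficiency) while parallel_blocks = 8 yields 192 blocks in exactly 3 waves
+ // (100%), so the search above settles on 8, assuming ntiles_KQ >= 8.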
+
+ blocks_num.x = ntiles_x;
+ blocks_num.y = parallel_blocks;
+ blocks_num.z = ntiles_z_gqa*K->ne[2]*Q->ne[3];
+
+ if (parallel_blocks > 1) {
+ dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
+ dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
+ }
+ }
+
+ float scale = 1.0f;
+ float max_bias = 0.0f;
+ float logit_softcap = 0.0f;
+
+ memcpy(&scale, (const float *) KQV->op_params + 0, sizeof(float));
+ memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
+ memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
+ if (logit_softcap != 0.0f) {
+ scale /= logit_softcap;
+ }
+
+ const uint32_t n_head = Q->ne[2];
+ const uint32_t n_head_log2 = 1u << uint32_t(floorf(log2f(float(n_head))));
+
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+ // TODO other tensor dimensions after removal of WMMA kernel:
+ const sycl::uint3 ne01 = init_fastdiv_values(Q->ne[1]);
+
+ GGML_ASSERT(block_dim.x % warp_size == 0);
+
+ launch_kernel<fattn_kernel, warp_size>(
+ blocks_num, block_dim, main_stream, (unsigned int) nbytes_shared, (const char *) Q->data, K_data, V_data,
+ mask ? ((const char *) mask->data) : nullptr, sinks ? ((const char *) sinks->data) : nullptr, KV_max.ptr,
+ !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, (sycl::float2 *)dst_tmp_meta.ptr, scale, max_bias, m0, m1,
+ n_head_log2, logit_softcap, Q->ne[0], ne01, Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3], K->ne[0],
+ K->ne[1], K->ne[2], K->ne[3], nb11, nb12, nb13, nb21, nb22, nb23, mask ? mask->ne[1] : 0,
+ mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0, mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0,
+ mask ? mask->nb[3] : 0);
+ SYCL_CHECK(0);
+
+ if (stream_k) {
+ if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
+ const dpct::dim3 block_dim_combine(DV, 1, 1);
+ const dpct::dim3 blocks_num_combine = { blocks_num.x, ncols1, ncols2 };
+
+ main_stream->submit([&](sycl::handler & cgh) {
+ auto KQV_data_ct0 = (float *) KQV->data;
+ auto dst_tmp_meta_ptr_ct1 = dst_tmp_meta.ptr;
+ auto Q_ne_ct2 = Q->ne[1];
+ auto Q_ne_ct3 = Q->ne[2];
+ auto Q_ne_ct4 = Q->ne[3];
+ auto K_ne_ct5 = K->ne[1];
+ auto K_ne_ct6 = K->ne[2];
+
+ cgh.parallel_for(sycl::nd_range<3>(blocks_num_combine * block_dim_combine, block_dim_combine),
+ [=](sycl::nd_item<3> item_ct1) {
+ GGML_UNUSED(item_ct1);
+ flash_attn_stream_k_fixup<DV, ncols1, ncols2>(KQV_data_ct0, dst_tmp_meta_ptr_ct1,
+ Q_ne_ct2, Q_ne_ct3, Q_ne_ct4,
+ K_ne_ct5, K_ne_ct6, nbatch_fa);
+ });
+ });
+ }
+ } else if (parallel_blocks > 1) {
+ const dpct::dim3 block_dim_combine(DV, 1, 1);
+ const dpct::dim3 blocks_num_combine(Q->ne[1], Q->ne[2], Q->ne[3]);
+ const size_t nbytes_shared_combine = parallel_blocks * sizeof(sycl::float2);
+ main_stream->submit([&](sycl::handler & cgh) {
+ sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(sycl::range<1>(nbytes_shared_combine), cgh);
+
+ auto dst_tmp_ptr_ct0 = dst_tmp.ptr;
+ auto dst_tmp_meta_ptr_ct1 = dst_tmp_meta.ptr;
+ auto KQV_data_ct2 = (float *) KQV->data;
+
+ cgh.parallel_for(sycl::nd_range<3>(blocks_num_combine * block_dim_combine, block_dim_combine),
+ [=](sycl::nd_item<3> item_ct1) {
+ GGML_UNUSED(item_ct1);
+ flash_attn_combine_results<DV>(
+ dst_tmp_ptr_ct0, dst_tmp_meta_ptr_ct1, KQV_data_ct2, parallel_blocks,
+ dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>().get());
+ });
+ });
+ }
+ SYCL_CHECK(0);
+}
--- /dev/null
+#include <sycl/sycl.hpp>
+#include <sycl/ext/oneapi/work_group_static.hpp>
+#include "dpct/helper.hpp"
+#include "common.hpp"
+#include "fattn-common.hpp"
+#include "fattn-tile.hpp"
+#include <cmath>
+#include <float.h>
+namespace syclex = sycl::ext::oneapi::experimental;
+
+void ggml_sycl_flash_attn_ext_tile(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * K = dst->src[1];
+ const ggml_tensor * V = dst->src[2];
+ switch (K->ne[0]) {
+ case 40: {
+ GGML_ASSERT(V->ne[0] == K->ne[0]);
+ ggml_sycl_flash_attn_ext_tile_case< 40, 40>(ctx, dst);
+ } break;
+ case 64: {
+ GGML_ASSERT(V->ne[0] == K->ne[0]);
+ ggml_sycl_flash_attn_ext_tile_case< 64, 64>(ctx, dst);
+ } break;
+ case 72: {
+ GGML_ASSERT(V->ne[0] == K->ne[0]);
+ ggml_sycl_flash_attn_ext_tile_case< 72, 72>(ctx, dst);
+ } break;
+ case 80: {
+ GGML_ASSERT(V->ne[0] == K->ne[0]);
+ ggml_sycl_flash_attn_ext_tile_case< 80, 80>(ctx, dst);
+ } break;
+ case 96: {
+ GGML_ASSERT(V->ne[0] == K->ne[0]);
+ ggml_sycl_flash_attn_ext_tile_case< 96, 96>(ctx, dst);
+ } break;
+ case 112: {
+ GGML_ASSERT(V->ne[0] == K->ne[0]);
+ ggml_sycl_flash_attn_ext_tile_case<112, 112>(ctx, dst);
+ } break;
+ case 128: {
+ GGML_ASSERT(V->ne[0] == K->ne[0]);
+ ggml_sycl_flash_attn_ext_tile_case<128, 128>(ctx, dst);
+ } break;
+ case 256: {
+ GGML_ASSERT(V->ne[0] == K->ne[0]);
+ ggml_sycl_flash_attn_ext_tile_case<256, 256>(ctx, dst);
+ } break;
+ case 576: {
+ GGML_ASSERT(V->ne[0] == 512);
+ ggml_sycl_flash_attn_ext_tile_case<576, 512>(ctx, dst);
+ } break;
+ default: {
+ GGML_ABORT("Unsupported head size");
+ } break;
+ }
+}
--- /dev/null
+#include <sycl/sycl.hpp>
+#include <sycl/ext/oneapi/work_group_static.hpp>
+#include "dpct/helper.hpp"
+#include "common.hpp"
+#include "fattn-common.hpp"
+
+#include <cmath>
+#include <float.h>
+
+namespace syclex = sycl::ext::oneapi::experimental;
+
+#define GGML_SYCL_FATTN_TILE_CONFIG_CASE(DKQ_, DV_, ncols_, nthreads, occupancy, nbatch_fa, nbatch_K) \
+ if (DKQ == (DKQ_) && DV == (DV_) && ncols == (ncols_)) { \
+ static_assert((nthreads) <= 512, "bad nthreads"); \
+ static_assert((occupancy) <= 8, "bad occupancy"); \
+ static_assert((nbatch_fa) <= 256, "bad nbatch_fa"); \
+ static_assert((nbatch_K) <= 256, "bad nbatch_K"); \
+ return ((nthreads) << 0) | ((occupancy) << 10) | ((nbatch_fa) << 14) | ((nbatch_K) << 23); \
+ }
+
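+// Each GGML_SYCL_FATTN_TILE_CONFIG_CASE packs four tuning parameters into one
+// 32-bit word: bits 0-9 nthreads, bits 10-13 occupancy, bits 14-22 nbatch_fa,
+// bits 23-31 nbatch_K. For example, (nthreads=256, occupancy=2, nbatch_fa=64,
+// nbatch_K=64) encodes as 256 | (2 << 10) | (64 << 14) | (64 << 23) = 0x20100900;
+// the getters further below recover each field by shifting and masking.
+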
+static constexpr uint32_t ggml_sycl_fattn_tile_get_config_fp16(const int DKQ, const int DV, const int ncols) {
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 2, 64, 2, 64, 40)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 4, 128, 2, 64, 40)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 8, 256, 2, 64, 40)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 16, 256, 2, 64, 40)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 32, 256, 2, 64, 40)
+
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 2, 64, 2, 64, 64)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 4, 128, 2, 64, 64)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 8, 256, 2, 64, 64)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 256, 2, 64, 64)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 256, 2, 64, 64)
+
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 2, 64, 2, 64, 72)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 4, 128, 2, 64, 72)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 8, 256, 2, 64, 72)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 16, 256, 2, 64, 72)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 32, 256, 2, 64, 72)
+
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 64, 40)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 64, 40)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 64, 40)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 16, 256, 2, 64, 40)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 64, 40)
+
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 64, 48)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 64, 48)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 64, 48)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 16, 256, 2, 64, 48)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 32, 256, 2, 64, 48)
+
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 2, 64, 2, 64, 56)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 4, 128, 2, 64, 56)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 8, 256, 2, 64, 56)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2, 64, 56)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2, 64, 56)
+
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 2, 64, 2, 64, 64)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 4, 128, 2, 64, 64)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 8, 256, 2, 64, 64)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 2, 64, 64)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64)
+
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 2, 64, 2, 64, 64)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 4, 128, 2, 64, 64)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 8, 256, 2, 64, 64)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64)
+
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
+
+ return 0;
+}
+
+static constexpr uint32_t ggml_sycl_fattn_tile_get_config_fp32(const int DKQ, const int DV, const int ncols) {
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 2, 64, 2, 32, 40)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 4, 128, 2, 32, 40)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 8, 256, 2, 32, 40)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 16, 256, 2, 32, 40)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 32, 256, 2, 32, 40)
+
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 2, 128, 3, 64, 64)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 4, 128, 3, 32, 64)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 8, 128, 3, 32, 64)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 128, 3, 64, 64)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 256, 2, 64, 64)
+
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 2, 64, 2, 32, 72)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 4, 128, 2, 32, 72)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 8, 256, 2, 32, 72)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 16, 256, 2, 32, 72)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 32, 256, 2, 32, 72)
+
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 32, 40)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 32, 40)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 32, 40)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 16, 256, 2, 32, 40)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 32, 40)
+
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 32, 48)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 32, 48)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 32, 48)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 16, 256, 2, 32, 48)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 32, 256, 2, 32, 48)
+
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 2, 64, 2, 32, 56)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 4, 128, 2, 32, 56)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 8, 256, 2, 32, 56)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2, 32, 56)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2, 32, 56)
+
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 2, 128, 3, 64, 64)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 4, 128, 3, 32, 128)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 8, 128, 3, 64, 128)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 16, 128, 3, 32, 128)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64)
+
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 2, 128, 3, 64, 64)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 4, 128, 3, 32, 64)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 8, 256, 2, 32, 256)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64)
+
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 32, 64)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 32, 64)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64)
+
+ return 0;
+}
+
+static constexpr uint32_t ggml_sycl_fattn_tile_get_config_amd(const int DKQ, const int DV, const int ncols) {
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 2, 64, 2, 32, 40)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 4, 128, 2, 32, 40)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 8, 256, 2, 32, 40)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 16, 256, 2, 32, 40)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 32, 256, 2, 32, 40)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 64, 256, 2, 32, 40)
+
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 2, 64, 3, 32, 64)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 4, 128, 3, 64, 64)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 8, 128, 2, 32, 64)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 256, 2, 128, 64)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 256, 2, 64, 64)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 64, 256, 2, 64, 64)
+
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 2, 64, 2, 32, 72)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 4, 128, 2, 32, 72)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 8, 256, 2, 32, 72)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 16, 256, 2, 32, 72)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 32, 256, 2, 32, 72)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 64, 256, 2, 32, 72)
+
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 32, 40)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 32, 40)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 32, 40)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 16, 256, 2, 32, 40)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 32, 40)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 64, 256, 2, 32, 40)
+
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 32, 48)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 32, 48)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 32, 48)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 16, 256, 2, 32, 48)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 32, 256, 2, 32, 48)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 64, 256, 2, 32, 48)
+
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 2, 64, 2, 32, 56)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 4, 128, 2, 32, 56)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 8, 256, 2, 32, 56)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2, 32, 56)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2, 32, 56)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 64, 256, 2, 32, 56)
+
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 2, 256, 2, 128, 64)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 4, 128, 2, 64, 128)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 8, 256, 2, 64, 128)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 2, 64, 128)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 2, 64, 32)
+
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 2, 256, 2, 128, 64)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 4, 256, 2, 64, 128)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 8, 256, 2, 64, 128)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128)
+
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128, 64)
+
+ return 0;
+}
+
+static constexpr uint32_t ggml_sycl_fattn_tile_get_config_amd_rdna(const int DKQ, const int DV, const int ncols) {
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 2, 64, 2, 32, 40)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 4, 128, 2, 32, 40)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 8, 256, 2, 32, 40)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 16, 256, 2, 32, 40)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 32, 256, 2, 32, 40)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 40, 40, 64, 256, 2, 32, 40)
+
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 2, 64, 8, 32, 64)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 4, 64, 8, 32, 64)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 8, 128, 5, 128, 64)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 128, 5, 128, 64)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 128, 4, 64, 64)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 64, 64, 64, 128, 5, 64, 64)
+
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 2, 64, 2, 32, 72)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 4, 128, 2, 32, 72)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 8, 256, 2, 32, 72)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 16, 256, 2, 32, 72)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 32, 256, 2, 32, 72)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 72, 72, 64, 256, 2, 32, 72)
+
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 32, 40)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 32, 40)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 32, 40)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 16, 256, 2, 32, 40)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 32, 40)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 80, 80, 64, 256, 2, 32, 40)
+
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 32, 48)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 32, 48)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 32, 48)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 16, 256, 2, 32, 48)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 32, 256, 2, 32, 48)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE( 96, 96, 64, 256, 2, 32, 48)
+
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 2, 64, 2, 32, 56)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 4, 128, 2, 32, 56)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 8, 256, 2, 32, 56)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2, 32, 56)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2, 32, 56)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(112, 112, 64, 256, 2, 32, 56)
+
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 2, 64, 8, 32, 64)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 4, 128, 8, 64, 64)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 8, 128, 8, 64, 64)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 3, 128, 128)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 3, 128, 64)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 3, 64, 64)
+
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 2, 64, 8, 32, 64)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 4, 128, 6, 32, 256)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 8, 128, 6, 32, 256)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128)
+
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 4, 128, 2, 64, 64)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 8, 256, 2, 64, 64)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4, 64, 64)
+ GGML_SYCL_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 128, 64)
+
+ return 0;
+}
+
+static constexpr uint32_t ggml_sycl_fattn_tile_get_config(const int DKQ, const int DV, const int ncols, const int cc) {
+ if (fast_fp16_available(cc)) {
+ return ggml_sycl_fattn_tile_get_config_fp16(DKQ, DV, ncols);
+ } else {
+ return ggml_sycl_fattn_tile_get_config_fp32(DKQ, DV, ncols);
+ }
+}
+
+static constexpr uint32_t ggml_sycl_fattn_tile_get_config(const int DKQ, const int DV, const int ncols) {
+#ifdef SYCL_FAST_FP16
+ return ggml_sycl_fattn_tile_get_config_fp16(DKQ, DV, ncols);
+#else
+ return ggml_sycl_fattn_tile_get_config_fp32(DKQ, DV, ncols);
+#endif // SYCL_FAST_FP16
+}
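+
+// Of these two overloads, the first picks the FP16 or FP32 table at run time
+// from the device's compute capability, while the constexpr variant is resolved
+// at compile time via SYCL_FAST_FP16 so the tuning constants can be used as
+// template arguments; the same pairing is repeated for each getter below.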
+
+static int ggml_sycl_fattn_tile_get_nthreads(const int DKQ, const int DV, const int ncols, const int cc) {
+ return (ggml_sycl_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 0) & ((1 << 10) - 1);
+}
+
+static constexpr int ggml_sycl_fattn_tile_get_nthreads(const int DKQ, const int DV, const int ncols) {
+ return (ggml_sycl_fattn_tile_get_config(DKQ, DV, ncols) >> 0) & ((1 << 10) - 1);
+}
+
+static int ggml_sycl_fattn_tile_get_occupancy(const int DKQ, const int DV, const int ncols, const int cc) {
+ return (ggml_sycl_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 10) & ((1 << 4) - 1);
+}
+
+static constexpr int ggml_sycl_fattn_tile_get_occupancy(const int DKQ, const int DV, const int ncols) {
+ return (ggml_sycl_fattn_tile_get_config(DKQ, DV, ncols) >> 10) & ((1 << 4) - 1);
+}
+
+static int ggml_sycl_fattn_tile_get_nbatch_fa(const int DKQ, const int DV, const int ncols, const int cc) {
+ return (ggml_sycl_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 14) & ((1 << 9) - 1);
+}
+
+static constexpr int ggml_sycl_fattn_tile_get_nbatch_fa(const int DKQ, const int DV, const int ncols) {
+ return (ggml_sycl_fattn_tile_get_config(DKQ, DV, ncols) >> 14) & ((1 << 9) - 1);
+}
+
+static int ggml_sycl_fattn_tile_get_nbatch_K(const int DKQ, const int DV, const int ncols, const int cc) {
+ return (ggml_sycl_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 23) & ((1 << 9) - 1);
+}
+
+static constexpr int ggml_sycl_fattn_tile_get_nbatch_K(const int DKQ, const int DV, const int ncols) {
+ return (ggml_sycl_fattn_tile_get_config(DKQ, DV, ncols) >> 23) & ((1 << 9) - 1);
+}
+
+template <int warp_size, int nwarps, int I, int J, int J_padding, bool oob_check>
+static __dpct_inline__ void flash_attn_tile_load_tile(const sycl::half2 * const __restrict__ KV,
+ sycl::half2 * const __restrict__ tile_KV,
+ const int stride_KV,
+ const int i_sup) {
+ constexpr int cpy_nb = ggml_sycl_get_max_cpy_bytes();
+ constexpr int cpy_ne = cpy_nb / 4;
+
+ auto load = [&] (const int n) {
+ auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+ const int stride_j = warp_size >> n;
+
+ if (stride_j == 0) {
+ return;
+ }
+
+ const int j0_start = stride_j == warp_size ? 0 : ((J/2)/cpy_ne) - ((J/2)/cpy_ne) % (2*stride_j);
+ const int j0_stop = ((J/2)/cpy_ne) - ((J/2)/cpy_ne) % (1*stride_j);
+ const int stride_i = warp_size / stride_j;
+
+ if (j0_start == j0_stop) {
+ return;
+ }
+
+#pragma unroll
+ for (int i0 = 0; i0 < I; i0 += nwarps*stride_i) {
+ const int i = i0 + item_ct1.get_local_id(1) * stride_i +
+ (stride_j == warp_size ? 0 : item_ct1.get_local_id(2) / stride_j);
+
+ if (i0 + nwarps*stride_i <= I || i < I) {
+#pragma unroll
+ for (int j0 = j0_start; j0 < j0_stop; j0 += stride_j) {
+ const int j = j0 * cpy_ne + (stride_j == warp_size ? item_ct1.get_local_id(2) :
+ item_ct1.get_local_id(2) % stride_j) *
+ cpy_ne;
+
+ const __dpct_align__(16) sycl::half2 zero[cpy_ne] = {
+ { 0.0f, 0.0f }
+ };
+ ggml_sycl_memcpy_1<cpy_nb>(
+ tile_KV + i*(J/2 + J_padding) + j,
+ !oob_check || i < i_sup ? KV + i*stride_KV + j : zero);
+ }
+ }
+ }
+ };
+ // 1: max 64*16=512 bytes, 512 half
+ // 2: max 32*16=512 bytes, 256 half
+ // 3: max 16*16=256 bytes, 128 half
+ // 4: max 8*16=128 bytes, 64 half
+ // 5: max 4*16= 64 bytes, 32 half
+ // 6: max 2*16= 32 bytes, 16 half
+ // 7: max 1*16= 16 bytes, 8 half
+ static_assert(J % 8 == 0, "bad J");
+ static_assert((J/2) % cpy_ne == 0, "bad J");
+ ggml_sycl_unroll<7>{}(load);
+}
+
+template <int warp_size, int nwarps, int I, int J, int J_padding, bool oob_check>
+static __dpct_inline__ void flash_attn_tile_load_tile(const sycl::half2 * const __restrict__ KV,
+ float * const __restrict__ tile_KV,
+ const int stride_KV,
+ const int i_sup) {
+ constexpr int cpy_nb = ggml_sycl_get_max_cpy_bytes();
+ constexpr int cpy_ne = cpy_nb / 4;
+
+ auto load = [&] (const int n) {
+ auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+ const int stride_j = warp_size >> n;
+
+ if (stride_j == 0) {
+ return;
+ }
+
+ const int j0_start = stride_j == warp_size ? 0 : (J/cpy_ne) - (J/cpy_ne) % (2*stride_j);
+ const int j0_stop = (J/cpy_ne) - (J/cpy_ne) % (1*stride_j);
+ const int stride_i = warp_size / stride_j;
+
+ if (j0_start == j0_stop) {
+ return;
+ }
+
+#pragma unroll
+ for (int i0 = 0; i0 < I; i0 += nwarps*stride_i) {
+ const int i = i0 + item_ct1.get_local_id(1) * stride_i +
+ (stride_j == warp_size ? 0 : item_ct1.get_local_id(2) / stride_j);
+
+ if (i0 + nwarps*stride_i <= I || i < I) {
+#pragma unroll
+ for (int j0 = j0_start; j0 < j0_stop; j0 += stride_j) {
+ const int j = j0 * (cpy_ne / 2) + (stride_j == warp_size ? item_ct1.get_local_id(2) :
+ item_ct1.get_local_id(2) % stride_j) *
+ (cpy_ne / 2);
+
+ const sycl::half2 zero[cpy_ne / 2] = {
+ { 0.0f, 0.0f }
+ };
+ __dpct_align__(16) sycl::half2 tmp_h2[cpy_ne / 2];
+ ggml_sycl_memcpy_1<sizeof(tmp_h2)>(
+ tmp_h2, !oob_check || i < i_sup ? KV + i*stride_KV + j : zero);
+
+ __dpct_align__(16) sycl::float2 tmp_f2[cpy_ne / 2];
+#pragma unroll
+ for (int l = 0; l < cpy_ne/2; ++l) {
+ tmp_f2[l] = tmp_h2[l].template convert<float, sycl::rounding_mode::automatic>();
+ }
+ ggml_sycl_memcpy_1<sizeof(tmp_f2)>(tile_KV + i*(J + J_padding) + 2*j, tmp_f2);
+ }
+ }
+ }
+ };
+ // 1: max 32*16=512 bytes, 128 float
+ // 2: max 16*16=256 bytes, 64 float
+ // 3: max 8*16=128 bytes, 32 float
+ // 4: max 4*16= 64 bytes, 16 float
+ // 5: max 2*16= 32 bytes, 8 float
+ static_assert(J % 8 == 0, "bad J");
+ static_assert(J % cpy_ne == 0, "bad J");
+ ggml_sycl_unroll<5>{}(load);
+}
+
+// Function that performs a single iteration of the KQ matrix multiplication:
+template <int warp_size,
+ int nwarps,
+ int ncols1,
+ int ncols2,
+ int DKQ,
+ int nbatch_fa,
+ int nbatch_K,
+ bool use_logit_softcap,
+ bool oob_check,
+ typename T_vec_dot>
+static __dpct_inline__ void flash_attn_tile_iter_KQ(T_vec_dot * const Q_tmp,
+ const sycl::half2 * const __restrict__ K_h2,
+ T_vec_dot * const KV_tmp,
+ const int stride_K2,
+ const int k_VKQ_0,
+ const int k_VKQ_sup,
+ const int k_KQ_0,
+ float * KQ_acc) {
+ auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+ constexpr int cpy_nb = ggml_sycl_get_max_cpy_bytes();
+ constexpr int cpy_ne = cpy_nb / 4;
+
+ constexpr int ncols = ncols1*ncols2;
+ constexpr int cpw = ncols > nwarps ? ncols/nwarps : 1; // Q columns per warp
+ constexpr int np = nwarps > ncols ? nwarps/ncols : 1; // number of parallel warps per Q column
+
+ flash_attn_tile_load_tile<warp_size, nwarps, nbatch_fa, nbatch_K, cpy_ne, oob_check>
+ (K_h2 + int64_t(k_VKQ_0)*stride_K2 + k_KQ_0/2, KV_tmp, stride_K2, k_VKQ_sup);
+ item_ct1.barrier();
+
+#ifdef SYCL_FAST_FP16
+ static_assert((nbatch_K/2) % cpy_ne == 0, "bad nbatch_K");
+#pragma unroll
+ for (int k_KQ_1 = 0; k_KQ_1 < nbatch_K/2; k_KQ_1 += cpy_ne) {
+ __dpct_align__(16) sycl::half2 K_k[nbatch_fa / (np * warp_size)][cpy_ne];
+ __dpct_align__(16) sycl::half2 Q_k[cpw][cpy_ne];
+#else
+ static_assert(nbatch_K % cpy_ne == 0, "bad nbatch_K");
+#pragma unroll
+ for (int k_KQ_1 = 0; k_KQ_1 < nbatch_K; k_KQ_1 += cpy_ne) {
+ __dpct_align__(16) float K_k[nbatch_fa/(np*warp_size)][cpy_ne];
+ __dpct_align__(16) float Q_k[cpw][cpy_ne];
+#endif // SYCL_FAST_FP16
+
+#pragma unroll
+ for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) {
+ const int i_KQ = i_KQ_0 + (item_ct1.get_local_id(1) % np) * warp_size + item_ct1.get_local_id(2);
+
+#ifdef SYCL_FAST_FP16
+ ggml_sycl_memcpy_1<cpy_nb>(&K_k[i_KQ_0/(np*warp_size)], &KV_tmp[i_KQ*(nbatch_K/2 + cpy_ne) + k_KQ_1]);
+#else
+ ggml_sycl_memcpy_1<cpy_nb>(&K_k[i_KQ_0/(np*warp_size)], &KV_tmp[i_KQ*(nbatch_K + cpy_ne) + k_KQ_1]);
+#endif // SYCL_FAST_FP16
+ }
+#pragma unroll
+ for (int jc0 = 0; jc0 < cpw; ++jc0) {
+ const int jc = jc0 + (item_ct1.get_local_id(1) / np) * cpw;
+
+#ifdef SYCL_FAST_FP16
+ ggml_sycl_memcpy_1<cpy_nb>(&Q_k[jc0], &Q_tmp[jc*(DKQ/2) + k_KQ_0/2 + k_KQ_1]);
+#else
+ ggml_sycl_memcpy_1<cpy_nb>(&Q_k[jc0], &Q_tmp[jc* DKQ + k_KQ_0 + k_KQ_1]);
+#endif // SYCL_FAST_FP16
+ }
+
+#pragma unroll
+ for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) {
+#pragma unroll
+ for (int jc0 = 0; jc0 < cpw; ++jc0) {
+#pragma unroll
+ for (int k = 0; k < cpy_ne; ++k) {
+ ggml_sycl_mad(KQ_acc[i_KQ_0/(np*warp_size)*cpw + jc0], K_k[i_KQ_0/(np*warp_size)][k], Q_k[jc0][k]);
+ }
+ }
+ }
+ }
+
+ if (k_KQ_0 + nbatch_K < DKQ) {
+ item_ct1.barrier(); // Sync not needed on last iteration.
+ }
+}
+
+// Function that performs a single iteration of the main loop over up to nbatch_fa tokens.
+template <int warp_size,
+ int nwarps,
+ int ncols1,
+ int ncols2,
+ int DKQ,
+ int DV,
+ int nbatch_fa,
+ int nbatch_K,
+ bool use_logit_softcap,
+ bool oob_check,
+ typename T_vec_dot,
+ typename T_KQ,
+ typename T_acc>
+// NOTE: the compiler warns that the total declared local variable size in
+// flash_attn_tile_iter exceeds 128 bytes and may cause high register pressure;
+// if that becomes a problem, reduce the tile sizes or use a smaller sub-group size.
+static __dpct_inline__ void flash_attn_tile_iter(T_vec_dot * const Q_tmp,
+ const sycl::half2 * const __restrict__ K_h2,
+ const sycl::half2 * const __restrict__ V_h2,
+ const sycl::half * const __restrict__ mask,
+ const sycl::uint3 ne01,
+ const float logit_softcap,
+ const float slope,
+ T_KQ * const KQ,
+ T_vec_dot * const KV_tmp,
+ const int stride_K2,
+ const int stride_V2,
+ const int stride_mask,
+ float * const KQ_max,
+ float * const KQ_sum,
+ T_acc * const VKQ,
+ const int k_VKQ_0,
+ const int k_VKQ_max,
+ const int col_Q_0,
+ float * KQ_max_new_shared) {
+ auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+ constexpr int cpy_nb = ggml_sycl_get_max_cpy_bytes();
+ constexpr int cpy_ne = cpy_nb / 4;
+
+ constexpr int ncols = ncols1*ncols2;
+ constexpr int cpw = ncols > nwarps ? ncols/nwarps : 1; // Q columns per warp
+ constexpr int np = nwarps > ncols ? nwarps/ncols : 1; // number of parallel warps per Q column
+
+ constexpr int DVp = (DV + 2*warp_size - 1) & ~(2*warp_size - 1); // DV padded to multiple of 2*warp_size.
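+ // e.g. DV = 40 with warp_size = 32 pads to DVp = 64.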
+
+#ifdef SYCL_FAST_FP16
+ constexpr int KQ_cs = cpw < 2*cpy_ne ? cpw : 2*cpy_ne;
+#else
+ constexpr int KQ_cs = cpw < 1*cpy_ne ? cpw : 1*cpy_ne;
+#endif // SYCL_FAST_FP16
+ static_assert(cpw % KQ_cs == 0, "bad KQ_cs");
+ const int k_VKQ_sup = k_VKQ_max - k_VKQ_0; // k supremum, only smaller k values have valid KV data
+
+ float KQ_max_new[cpw];
+#pragma unroll
+ for (int jc0 = 0; jc0 < cpw; ++jc0) {
+ KQ_max_new[jc0] = KQ_max[jc0];
+ }
+
+ float KQ_acc[nbatch_fa/(np*warp_size) * cpw] = {0.0f}; // Accumulators for KQ matrix multiplication.
+
+ // KQ = K @ Q matrix multiplication:
+ constexpr int nbatch_K_last = DKQ % nbatch_K;
+#pragma unroll
+ for (int k_KQ_0 = 0; k_KQ_0 < DKQ - nbatch_K_last; k_KQ_0 += nbatch_K) {
+ flash_attn_tile_iter_KQ<warp_size, nwarps, ncols1, ncols2, DKQ, nbatch_fa, nbatch_K, use_logit_softcap, oob_check>(
+ Q_tmp, K_h2, KV_tmp, stride_K2, k_VKQ_0, k_VKQ_sup, k_KQ_0, KQ_acc);
+ }
+ if (nbatch_K_last > 0) {
+ constexpr int k_KQ_0 = DKQ - nbatch_K_last;
+ flash_attn_tile_iter_KQ<warp_size, nwarps, ncols1, ncols2, DKQ, nbatch_fa, nbatch_K_last, use_logit_softcap, oob_check>(
+ Q_tmp, K_h2, KV_tmp, stride_K2, k_VKQ_0, k_VKQ_sup, k_KQ_0, KQ_acc);
+ }
+
+ // Apply logit softcap + mask, update KQ_max:
+#pragma unroll
+ for (int jc0 = 0; jc0 < cpw; ++jc0) {
+ const int j = fastmodulo(col_Q_0 + (jc0 + (item_ct1.get_local_id(1) / np) * cpw) / ncols2, ne01);
+
+#pragma unroll
+ for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) {
+ const int i_KQ = i_KQ_0 + (item_ct1.get_local_id(1) % np) * warp_size + item_ct1.get_local_id(2);
+
+#if defined(SYCL_FAST_FP16) && !defined(GGML_SYCL_F16)
+ // Without the v_dot2_f32_f16 instruction there is a higher risk of numerical overflow in the KQ calculation.
+ // Therefore, scale down the Q values and apply the inverse scale to the FP32 KQ values afterwards.
+ KQ_acc[i_KQ_0/(np*warp_size)*cpw + jc0] *= 4.0f;
+#endif // defined(SYCL_FAST_FP16) && !defined(GGML_SYCL_F16)
+
+ if (use_logit_softcap) {
+ KQ_acc[(i_KQ_0 / (np * warp_size)) * cpw + jc0] =
+ logit_softcap * sycl::tanh((float) KQ_acc[(i_KQ_0 / (np * warp_size)) * cpw + jc0]);
+ }
+
+ if (!oob_check || i_KQ < k_VKQ_sup) {
+ KQ_acc[(i_KQ_0 / (np * warp_size)) * cpw + jc0] +=
+ (ncols2 > 1 || mask) ? slope * sycl::vec<sycl::half, 1>(mask[j * stride_mask + k_VKQ_0 + i_KQ])
+ .convert<float, sycl::rounding_mode::automatic>()[0] :
+ 0.0f;
+
+ KQ_max_new[jc0] =
+ sycl::fmax((float) KQ_max_new[jc0],
+ (float) (KQ_acc[(i_KQ_0 / (np * warp_size)) * cpw + jc0] + FATTN_KQ_MAX_OFFSET));
+ }
+ }
+
+ KQ_max_new[jc0] = warp_reduce_max<warp_size>(KQ_max_new[jc0]);
+ }
+
+ if constexpr (np == 1) {
+ item_ct1.barrier();
+ } else {
+ static_assert(cpw == 1, "bad cpw");
+
+ if (item_ct1.get_local_id(2) == 0) {
+ KQ_max_new_shared[item_ct1.get_local_id(1)] = KQ_max_new[0];
+ }
+ item_ct1.barrier();
+ KQ_max_new[0] = KQ_max_new_shared[(item_ct1.get_local_id(1) & ~(np - 1)) + item_ct1.get_local_id(2) % np];
+ KQ_max_new[0] = warp_reduce_max<np>(KQ_max_new[0]);
+ }
+
+ // Calculate KQ softmax, write to shared KQ buffer, re-scale VKQ accumulators:
+#pragma unroll
+ for (int jc0 = 0; jc0 < cpw; jc0 += KQ_cs) {
+#ifdef SYCL_FAST_FP16
+ __dpct_align__(16) sycl::half tmp[nbatch_fa / (np * warp_size)][KQ_cs];
+#else
+ __dpct_align__(16) float tmp[nbatch_fa/(np*warp_size)][KQ_cs];
+#endif // SYCL_FAST_FP16
+
+#pragma unroll
+ for (int jc1 = 0; jc1 < KQ_cs; ++jc1) {
+ const int jc = jc0 + jc1;
+
+ const float KQ_max_scale = sycl::native::exp((float) (KQ_max[jc] - KQ_max_new[jc]));
+ KQ_max[jc] = KQ_max_new[jc];
+
+ float KQ_sum_add = 0.0f;
+#pragma unroll
+ for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) {
+ const float val =
+ !oob_check || i0 + (item_ct1.get_local_id(1) % np) * warp_size + item_ct1.get_local_id(2) <
+ static_cast<uint32_t>(k_VKQ_sup) ?
+ sycl::native::exp((float) (KQ_acc[(i0 / (np * warp_size)) * cpw + jc] - KQ_max[jc])) :
+ 0.0f;
+ KQ_sum_add += val;
+ tmp[i0/(np*warp_size)][jc1] = val;
+ }
+ KQ_sum[jc] = KQ_sum[jc]*KQ_max_scale + KQ_sum_add;
+
+#ifdef SYCL_FAST_FP16
+ const sycl::half2 KQ_max_scale_h2 = sycl::half2(KQ_max_scale, KQ_max_scale);
+#pragma unroll
+ for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {
+ VKQ[jc*((DVp/2)/warp_size) + i0/warp_size].x() *= KQ_max_scale_h2.x();
+ VKQ[jc*((DVp/2)/warp_size) + i0/warp_size].y() *= KQ_max_scale_h2.y();
+ }
+#else
+#pragma unroll
+ for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {
+ VKQ[jc*((DVp/2)/warp_size) + i0/warp_size].x() *= KQ_max_scale;
+ VKQ[jc*((DVp/2)/warp_size) + i0/warp_size].y() *= KQ_max_scale;
+ }
+#endif // SYCL_FAST_FP16
+ }
+
+#pragma unroll
+ for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) {
+ const int i = i0 + (item_ct1.get_local_id(1) % np) * warp_size + item_ct1.get_local_id(2);
+
+ ggml_sycl_memcpy_1<sizeof(tmp[0])>(
+ KQ + (jc0 / KQ_cs + (item_ct1.get_local_id(1) / np) * (cpw / KQ_cs)) * (nbatch_fa * KQ_cs) + i * KQ_cs,
+ tmp[i0 / (np * warp_size)]);
+ }
+ }
+
+ // VKQ = V @ KQ matrix multiplication:
+ static_assert(DV <= DKQ, "bad DV");
+ static_assert(DV % nbatch_K == 0 || (nbatch_K % 3 == 0 && DV % (nbatch_K*2/3) == 0), "bad nbatch_K");
+ constexpr int nbatch_V = (DV % nbatch_K == 0 ? nbatch_K : nbatch_K*2/3) * nbatch_fa / DV; // Number of V columns that fit in SRAM for K.
+ static_assert(nbatch_fa % nbatch_V == 0, "bad nbatch_V");
+ static_assert(nbatch_V % np == 0, "bad nbatch_V");
+#pragma unroll
+ for (int k0 = 0; k0 < nbatch_fa; k0 += nbatch_V) {
+ flash_attn_tile_load_tile<warp_size, nwarps, nbatch_V, DV, 0, oob_check>
+ (V_h2 + int64_t(k_VKQ_0 + k0)*stride_V2, KV_tmp, stride_V2, k_VKQ_sup - k0);
+ item_ct1.barrier();
+
+#ifdef SYCL_FAST_FP16
+#pragma unroll
+ for (int k1 = 0; k1 < nbatch_V; k1 += np) {
+ __dpct_align__(16) sycl::half2 V_k[(DVp / 2) / warp_size];
+ __dpct_align__(16) sycl::half2 KQ_k[cpw];
+
+ constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? cpy_ne/2 : (DVp/2)/warp_size;
+#pragma unroll
+ for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) {
+ ggml_sycl_memcpy_1<cpy_ne_D * 4>(&V_k[i0 / warp_size],
+ &KV_tmp[(k1 + item_ct1.get_local_id(1) % np) * (DV / 2) + i0 +
+ item_ct1.get_local_id(2) * cpy_ne_D]);
+ }
+#pragma unroll
+ for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; jc_VKQ_0 += KQ_cs) {
+ const int jc_KQ = jc_VKQ_0 / KQ_cs + (item_ct1.get_local_id(1) / np) * (cpw / KQ_cs);
+
+ __dpct_align__(16) sycl::half tmp[KQ_cs];
+ ggml_sycl_memcpy_1<KQ_cs * sizeof(sycl::half)>(
+ &tmp, KQ + jc_KQ * (nbatch_fa * KQ_cs) + (k0 + k1 + item_ct1.get_local_id(1) % np) * KQ_cs);
+#pragma unroll
+ for (int jc_VKQ_1 = 0; jc_VKQ_1 < KQ_cs; ++jc_VKQ_1) {
+ KQ_k[jc_VKQ_0 + jc_VKQ_1] = sycl::half2(tmp[jc_VKQ_1]);
+ }
+ }
+
+#pragma unroll
+ for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {
+#pragma unroll
+ for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; ++jc_VKQ_0) {
+ VKQ[jc_VKQ_0*((DVp/2)/warp_size) + i0/warp_size].x() +=
+ V_k[i0/warp_size].x()*KQ_k[jc_VKQ_0].x();
+ VKQ[jc_VKQ_0*((DVp/2)/warp_size) + i0/warp_size].y() +=
+ V_k[i0/warp_size].y()*KQ_k[jc_VKQ_0].y();
+ }
+ }
+ }
+#else
+#pragma unroll
+ for (int k1 = 0; k1 < nbatch_V; k1 += np) {
+ __dpct_align__(16) sycl::float2 V_k[(DVp/2)/warp_size];
+ __dpct_align__(16) float KQ_k[cpw];
+
+ constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size;
+#pragma unroll
+ for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) {
+ ggml_sycl_memcpy_1<cpy_ne_D*4>(&V_k[i0/(2*warp_size)], &KV_tmp[(k1 + item_ct1.get_local_id(1) % np)*DV + i0 + item_ct1.get_local_id(2)*cpy_ne_D]);
+ }
+#pragma unroll
+ for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; jc_VKQ_0 += KQ_cs) {
+ const int jc_KQ = jc_VKQ_0/KQ_cs + (item_ct1.get_local_id(1) / np)*(cpw/KQ_cs);
+
+ ggml_sycl_memcpy_1<KQ_cs*sizeof(float)>(
+ &KQ_k[jc_VKQ_0], KQ + jc_KQ*(nbatch_fa*KQ_cs) + (k0 + k1 + item_ct1.get_local_id(1) % np)*KQ_cs);
+ }
+
+#pragma unroll
+ for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {
+#pragma unroll
+ for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; ++jc_VKQ_0) {
+ VKQ[jc_VKQ_0*((DVp/2)/warp_size) + i0/warp_size].x() += V_k[i0/warp_size].x()*KQ_k[jc_VKQ_0];
+ VKQ[jc_VKQ_0*((DVp/2)/warp_size) + i0/warp_size].y() += V_k[i0/warp_size].y()*KQ_k[jc_VKQ_0];
+ }
+ }
+ }
+#endif // SYCL_FAST_FP16
+ item_ct1.barrier();
+ }
+}
+
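+// flash_attn_tile processes the KV cache in tiles of nbatch_fa rows using the
+// streaming ("online") softmax recurrence: for each new tile,
+//     KQ_max_new = max(KQ_max, max_i KQ_i)
+//     KQ_sum    <- KQ_sum * exp(KQ_max - KQ_max_new) + sum_i exp(KQ_i - KQ_max_new)
+// and the VKQ accumulators are rescaled by exp(KQ_max - KQ_max_new) before the new
+// V rows are accumulated. This keeps the exponentials bounded without a second pass
+// over the KV cache; see flash_attn_tile_iter above for the per-tile implementation.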
+// NOTE: the compiler may warn that the total declared local variable size in
+// flash_attn_tile exceeds 128 bytes and may cause high register pressure.
+// If this degrades performance on a given device, reduce the tile sizes or
+// use a smaller sub-group size.
+template <int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap, int warp_size> // DKQ/DV == K/V head sizes
+static void flash_attn_tile(const char * Q,
+ const char * K,
+ const char * V,
+ const char * mask,
+ const char * sinks,
+ const int * KV_max,
+ float * dst,
+ sycl::float2 * dst_meta,
+ const float scale,
+ const float max_bias,
+ const float m0,
+ const float m1,
+ const uint32_t n_head_log2,
+ const float logit_softcap,
+ const int32_t ne00,
+ const sycl::uint3 ne01,
+ const int32_t ne02,
+ const int32_t ne03,
+ const int32_t nb01,
+ const int32_t nb02,
+ const int32_t nb03,
+ const int32_t ne10,
+ const int32_t ne11,
+ const int32_t ne12,
+ const int32_t ne13,
+ const int32_t nb11,
+ const int32_t nb12,
+ const int64_t nb13,
+ const int32_t nb21,
+ const int32_t nb22,
+ const int64_t nb23,
+ const int32_t ne31,
+ const int32_t ne32,
+ const int32_t ne33,
+ const int32_t nb31,
+ const int32_t nb32,
+ const int64_t nb33) {
+#ifdef SYCL_FLASH_ATTN
+ // Skip unused kernel variants for faster compilation:
+ auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+    if (use_logit_softcap && !(DV == 128 || DV == 256)) {
+ GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
+ max_bias, m0, m1, n_head_log2, logit_softcap,
+ ne00, ne01, ne02, ne03,
+ nb01, nb02, nb03,
+ ne10, ne11, ne12, ne13,
+ nb11, nb12, nb13,
+ nb21, nb22, nb23,
+ ne31, ne32, ne33,
+ nb31, nb32, nb33);
+ return;
+ }
+
+ static_assert(ggml_sycl_fattn_tile_get_config(DKQ, DV, ncols1*ncols2) != 0, "kernel config not defined");
+
+ constexpr int ncols = ncols1*ncols2;
+
+ constexpr int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, ncols1*ncols2) / warp_size;
+ constexpr int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, ncols1*ncols2);
+ constexpr int nbatch_K = ggml_sycl_fattn_tile_get_nbatch_K (DKQ, DV, ncols1*ncols2);
+
+ // In this kernel Q, K, V are matrices while i, j, k are matrix indices.
+
+ const int col_Q_0 = item_ct1.get_group(2) * ncols1; // Index of the first Q column for this SYCL block to work on.
+
+ const int sequence = item_ct1.get_group(0) / (ne02 / ncols2);
+ const int head0 = item_ct1.get_group(0) * ncols2 - sequence * ne02; // == item_ct1.get_group(0) % (ne02/ncols2)
+ const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
+ const float * Q_f = (const float *) (Q + nb03*sequence + nb02* head0);
+ const sycl::half2 * K_h2 = (const sycl::half2 *) (K + nb13 * sequence + nb12 * (head0 / gqa_ratio));
+ const sycl::half2 * V_h2 =
+ (const sycl::half2 *) (V + nb23 * sequence + nb22 * (head0 / gqa_ratio)); // K and V have same shape
+
+ const sycl::half * maskh = mask ? (const sycl::half *) (mask + nb33 * (sequence % ne33)) : nullptr;
+
+ const int stride_K2 = nb11 / sizeof(sycl::half2);
+ const int stride_V2 = nb21 / sizeof(sycl::half2);
+ const int stride_mask = nb31 / sizeof(sycl::half);
+
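+    // ALiBi only applies when ncols2 == 1: the GQA-optimized path is only dispatched
+    // with max_bias == 0, so a neutral slope of 1.0f is used there.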
+ const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
+
+ constexpr int cpy_nb = ggml_sycl_get_max_cpy_bytes();
+ constexpr int cpy_ne = cpy_nb / 4;
+
+ constexpr int cpw = ncols > nwarps ? ncols/nwarps : 1; // Q columns per warp.
+ constexpr int np = nwarps > ncols ? nwarps/ncols : 1; // Number of parallel warps per Q column.
+
+ static_assert(cpw == 1 || np == 1, "bad cpw / np");
+ static_assert(nbatch_fa % (np*warp_size) == 0, "nbatch_fa % (np*warp_size) != 0");
+
+ constexpr int DKQp = (DKQ + 2*warp_size - 1) & ~(2*warp_size - 1); // DKQ padded to multiple of 2*warp_size.
+ constexpr int DVp = (DV + 2*warp_size - 1) & ~(2*warp_size - 1); // DV padded to multiple of 2*warp_size.
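+    // Example: with warp_size == 32, DV == 72 pads to DVp == 128 and DKQ == 40 pads to DKQp == 64.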
+
+ // Q_tmp == SRAM buffer to hold Q data for the entire lifetime of the kernel.
+ // KV_tmp == SRAM buffer to hold fragments of K/V data while iterating over ne11.
+ // KV_tmp is padded to avoid memory conflicts for K (cpy_ne) and OOB accesses for V (DVp-DV).
+ // KQ == SRAM buffer to hold KQ fragments between KQ and VKQ matrix multiplications.
+ // VKQ == Accumulators in registers for the final VKQ result.
+
+#ifdef SYCL_FAST_FP16
+    constexpr size_t lsm_size1 = ncols * DKQ/2;
+    constexpr size_t lsm_size2 = nbatch_fa * (nbatch_K/2 + cpy_ne) + DVp - DV;
+    constexpr size_t lsm_size3 = ncols * nbatch_fa;
+    constexpr size_t lsm_size4 = nwarps;
+
+    constexpr size_t local_share_mem_size = lsm_size1 * sizeof(sycl::half2) +
+                                            lsm_size2 * sizeof(sycl::half2) +
+                                            lsm_size3 * sizeof(sycl::half) +
+                                            lsm_size4 * sizeof(float);
+
+    syclex::work_group_static<char[local_share_mem_size]> lsm;
+
+    sycl::half2 * Q_tmp             = (sycl::half2 *) &lsm;
+    sycl::half2 * KV_tmp            = (sycl::half2 *) (Q_tmp + lsm_size1);
+    sycl::half  * KQ                = (sycl::half *) (KV_tmp + lsm_size2);
+    float       * KQ_max_new_shared = (float *) (KQ + lsm_size3);
+
+ __dpct_align__(16) sycl::half2 VKQ[cpw * ((DVp / 2) / warp_size)] = {
+ { 0.0f, 0.0f }
+ };
+#else
+    constexpr size_t lsm_size1 = ncols * DKQ;
+    constexpr size_t lsm_size2 = nbatch_fa * (nbatch_K + cpy_ne) + DVp - DV;
+    constexpr size_t lsm_size3 = ncols * nbatch_fa;
+    constexpr size_t lsm_size4 = nwarps;
+
+    constexpr size_t local_share_mem_size = (lsm_size1 + lsm_size2 + lsm_size3 + lsm_size4) * sizeof(float);
+
+    syclex::work_group_static<char[local_share_mem_size]> lsm;
+
+    float * Q_tmp             = (float *) &lsm;
+    float * KV_tmp            = Q_tmp + lsm_size1;
+    float * KQ                = KV_tmp + lsm_size2;
+    float * KQ_max_new_shared = KQ + lsm_size3;
+
+    __dpct_align__(16) sycl::float2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}};
+#endif // SYCL_FAST_FP16
+
+ float KQ_max[cpw] = {};
+
+#pragma unroll
+ for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+ KQ_max[j0/nwarps] = -FLT_MAX/2.0f;
+ }
+ float KQ_sum[cpw] = {0.0f};
+
+ // Load Q data, convert to FP16 if fast:
+#pragma unroll
+ for (int jc0 = 0; jc0 < cpw; ++jc0) {
+ const int jc = jc0 + (item_ct1.get_local_id(1) / np) * cpw;
+
+ const int j = jc / ncols2;
+ const int c = jc % ncols2;
+
+ constexpr int cpy_ne_D = cpy_ne < DKQp/warp_size ? cpy_ne : DKQp/warp_size;
+
+#pragma unroll
+ for (int i0 = 0; i0 < DKQp; i0 += np*warp_size*cpy_ne_D) {
+ if (i0 + np * warp_size * cpy_ne_D <= DKQ ||
+ i0 + (item_ct1.get_local_id(1) % np) * (warp_size * cpy_ne_D) + item_ct1.get_local_id(2) * cpy_ne_D <
+ DKQ) {
+ __dpct_align__(16) float tmp_f[cpy_ne_D] = { 0.0f };
+ ggml_sycl_memcpy_1<sizeof(tmp_f)>(
+ tmp_f, &Q_f[c * (nb02 / sizeof(float)) + fastmodulo(col_Q_0 + j, ne01) * (nb01 / sizeof(float)) +
+ i0 + (item_ct1.get_local_id(1) % np) * (warp_size * cpy_ne_D) +
+ item_ct1.get_local_id(2) * cpy_ne_D]);
+
+#pragma unroll
+ for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
+ tmp_f[i1] *= scale;
+ }
+
+#ifdef SYCL_FAST_FP16
+ __dpct_align__(16) sycl::half2 tmp_h2[cpy_ne_D / 2];
+#pragma unroll
+ for (int i1 = 0; i1 < cpy_ne_D; i1 += 2) {
+ tmp_h2[i1/2] = make_half2(tmp_f[i1 + 0], tmp_f[i1 + 1]);
+#if defined(SYCL_FAST_FP16) && !defined(GGML_SYCL_F16)
+                    // Without the v_dot2_f32_f16 instruction there is a higher risk of numerical overflow in the KQ calculation.
+                    // Therefore, scale down the Q values here; the inverse scale is applied to the FP32 KQ values afterwards.
+ tmp_h2[i1 / 2] *= sycl::half2(0.25f, 0.25f);
+#endif // defined(SYCL_FAST_FP16) && !defined(GGML_SYCL_F16)
+ }
+ ggml_sycl_memcpy_1<sizeof(tmp_h2)>(
+ &Q_tmp[jc * (DKQ / 2) + i0 / 2 + (item_ct1.get_local_id(1) % np) * (warp_size * cpy_ne_D / 2) +
+ item_ct1.get_local_id(2) * (cpy_ne_D / 2)],
+ tmp_h2);
+#else
+ ggml_sycl_memcpy_1<sizeof(tmp_f)>(
+ &Q_tmp[jc* DKQ + i0 + (item_ct1.get_local_id(1) % np)*(warp_size*cpy_ne_D) + item_ct1.get_local_id(2)* cpy_ne_D],
+ tmp_f);
+#endif // SYCL_FAST_FP16
+ }
+ }
+ }
+
+ item_ct1.barrier();
+
+ // Main loop over KV cache:
+ const int k_VKQ_max = KV_max ? KV_max[sequence * item_ct1.get_group_range(2) + item_ct1.get_group(2)] : ne11;
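+    // KV_max, when provided, holds an upper bound on the KV range this Q block has to
+    // process (rows beyond it are fully masked out), so whole tiles can be skipped.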
+ if (ncols2 == 1) {
+ // Branch with out-of-bounds checks.
+ int k_VKQ_0 = item_ct1.get_group(1) * nbatch_fa;
+ while (k_VKQ_0 < k_VKQ_max - nbatch_fa) {
+ constexpr bool oob_check = false;
+ flash_attn_tile_iter<warp_size, nwarps, ncols1, ncols2, DKQ, DV, nbatch_fa, nbatch_K, use_logit_softcap,
+ oob_check>(Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp, stride_K2,
+ stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0,
+ KQ_max_new_shared);
+ k_VKQ_0 += item_ct1.get_group_range(1) * nbatch_fa;
+ }
+ if (k_VKQ_0 < k_VKQ_max) {
+ constexpr bool oob_check = true;
+ flash_attn_tile_iter<warp_size, nwarps, ncols1, ncols2, DKQ, DV, nbatch_fa, nbatch_K, use_logit_softcap,
+ oob_check>(Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp, stride_K2,
+ stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0,
+ KQ_max_new_shared);
+ }
+ } else {
+ // Branch without out-of-bounds checks.
+ for (int k_VKQ_0 = item_ct1.get_group(1) * nbatch_fa; k_VKQ_0 < k_VKQ_max;
+ k_VKQ_0 += item_ct1.get_group_range(1) * nbatch_fa) {
+
+ constexpr bool oob_check = false;
+ flash_attn_tile_iter<warp_size, nwarps, ncols1, ncols2, DKQ, DV, nbatch_fa, nbatch_K, use_logit_softcap,
+ oob_check>(Q_tmp, K_h2, V_h2, maskh, ne01, logit_softcap, slope, KQ, KV_tmp, stride_K2,
+ stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max, col_Q_0,
+ KQ_max_new_shared);
+ }
+ }
+
+#pragma unroll
+ for (int jc0 = 0; jc0 < cpw; ++jc0) {
+ KQ_sum[jc0] = warp_reduce_sum<warp_size>(KQ_sum[jc0]);
+ }
+
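+    // With np > 1 parallel warps per Q column, the warps with (local_id(1) % np) != 0
+    // publish their partial VKQ accumulators and KQ sums to shared memory and exit
+    // early; the remaining warp of each group then folds those partials into its own
+    // results below.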
+ if constexpr (np > 1) {
+ static_assert(cpw == 1, "bad cpw");
+ static_assert(nbatch_fa*nbatch_K >= nwarps*DVp, "KV_tmp too small");
+
+#ifdef SYCL_FAST_FP16
+ sycl::half2 * VKQ_combine = (sycl::half2 *) KV_tmp;
+#else
+ float * VKQ_combine = (float *) KV_tmp;
+#endif // SYCL_FAST_FP16
+
+ float * KQ_sum_combine = (float *) Q_tmp;
+
+ if (item_ct1.get_local_id(1) % np != 0) {
+
+#ifdef SYCL_FAST_FP16
+ constexpr int cpy_ne_D = cpy_ne < (DVp/2)/warp_size ? cpy_ne : (DVp/2)/warp_size;
+#pragma unroll
+ for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) {
+ ggml_sycl_memcpy_1<cpy_ne_D * 4>(
+ &VKQ_combine[item_ct1.get_local_id(1) * (DVp / 2) + i0 + item_ct1.get_local_id(2) * cpy_ne_D],
+ &VKQ[i0 / warp_size]);
+ }
+#else
+
+ constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size;
+
+#pragma unroll
+ for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) {
+ ggml_sycl_memcpy_1<cpy_ne_D*4>(
+ &VKQ_combine[item_ct1.get_local_id(1)*DVp + i0 + item_ct1.get_local_id(2)*cpy_ne_D], ((const float *) VKQ) + i0/warp_size);
+ }
+#endif // SYCL_FAST_FP16
+
+ if (item_ct1.get_local_id(2) == 0) {
+ KQ_sum_combine[item_ct1.get_local_id(1)] = KQ_sum[0];
+ }
+ return;
+ }
+
+ item_ct1.barrier();
+
+#pragma unroll
+ for (int ip = 1; ip < np; ++ip) {
+#ifdef SYCL_FAST_FP16
+ constexpr int cpy_ne_D = cpy_ne < (DVp/2)/warp_size ? cpy_ne : (DVp/2)/warp_size;
+#pragma unroll
+ for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) {
+ __dpct_align__(16) sycl::half2 tmp[cpy_ne_D];
+ ggml_sycl_memcpy_1<cpy_ne_D * 4>(tmp, &VKQ_combine[(item_ct1.get_local_id(1) + ip) * (DVp / 2) + i0 +
+ item_ct1.get_local_id(2) * cpy_ne_D]);
+#pragma unroll
+ for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
+ VKQ[i0/warp_size + i1] += tmp[i1];
+ }
+ }
+#else
+ constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size;
+#pragma unroll
+ for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) {
+ __dpct_align__(16) float tmp[cpy_ne_D];
+ ggml_sycl_memcpy_1<cpy_ne_D*4>(tmp, &VKQ_combine[(item_ct1.get_local_id(1) + ip)*DVp + i0 + item_ct1.get_local_id(2)*cpy_ne_D]);
+#pragma unroll
+ for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
+ ((float *)VKQ)[i0/warp_size + i1] += tmp[i1];
+ }
+ }
+#endif // SYCL_FAST_FP16
+
+ KQ_sum[0] += KQ_sum_combine[item_ct1.get_local_id(1) + ip];
+ }
+ }
+
+ // Attention sink: adjust KQ max and sum only for the first of all parallel blocks:
+ if (sinks && item_ct1.get_group(1) == 0) {
+#pragma unroll
+ for (int jc0 = 0; jc0 < cpw; ++jc0) {
+ const int jc = jc0 + (item_ct1.get_local_id(1) / np) * cpw;
+ const float sink = ((const float *) sinks)[head0 + jc % ncols2];
+
+ float KQ_max_new_j = sycl::fmax((float) KQ_max[jc0], sink);
+ const float KQ_max_scale = sycl::native::exp((float) (KQ_max[jc0] - KQ_max_new_j));
+ KQ_max[jc0] = KQ_max_new_j;
+
+ const float val = sycl::native::exp((float) (sink - KQ_max[jc0]));
+ KQ_sum[jc0] = KQ_sum[jc0]*KQ_max_scale + val;
+
+#ifdef SYCL_FAST_FP16
+ const sycl::half2 KQ_max_scale_h2 = sycl::half2(KQ_max_scale, KQ_max_scale);
+#pragma unroll
+ for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {
+ VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size] *= KQ_max_scale_h2;
+ }
+#else
+#pragma unroll
+ for (int i0 = 0; i0 < DVp/2; i0 += warp_size) {
+ VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size].x() *= KQ_max_scale;
+ VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size].y() *= KQ_max_scale;
+ }
+#endif // SYCL_FAST_FP16
+ }
+ }
+
+ // Write back results:
+#pragma unroll
+ for (int jc0 = 0; jc0 < cpw; ++jc0) {
+ const int jc = jc0 + (item_ct1.get_local_id(1) / np) * cpw;
+
+ const int j = jc / ncols2;
+ const int c = jc % ncols2;
+
+ if (ncols1 > 1 && col_Q_0 + j >= int(ne01.z())) {
+ return;
+ }
+
+ const float scale = item_ct1.get_group_range(1) == 1 ? 1.0f / KQ_sum[jc0] : 1.0f;
+
+ const int j_dst_unrolled =
+ ((sequence * int(ne01.z()) + col_Q_0 + j) * ne02 + head0 + c) * item_ct1.get_group_range(1) +
+ item_ct1.get_group(1);
+
+#ifdef SYCL_FAST_FP16
+ constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? cpy_ne/2 : (DVp/2)/warp_size;
+#pragma unroll
+ for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) {
+ __dpct_align__(16) sycl::float2 tmp[cpy_ne_D];
+#pragma unroll
+ for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
+ tmp[i1] = VKQ[jc0 * ((DVp / 2) / warp_size) + i0 / warp_size + i1]
+ .template convert<float, sycl::rounding_mode::automatic>();
+ tmp[i1].x() *= scale;
+ tmp[i1].y() *= scale;
+ }
+ if (i0 + warp_size * cpy_ne_D <= DV / 2 || i0 + item_ct1.get_local_id(2) * cpy_ne_D < DV / 2) {
+ ggml_sycl_memcpy_1<sizeof(tmp)>(
+ &dst[j_dst_unrolled * DV + 2 * i0 + item_ct1.get_local_id(2) * (2 * cpy_ne_D)], tmp);
+ }
+ }
+#else
+ constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size;
+#pragma unroll
+ for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) {
+ if (i0 + warp_size*cpy_ne_D <= DV || i0 + item_ct1.get_local_id(2)*cpy_ne_D < DV) {
+#pragma unroll
+ for (int i1 = 0; i1 < cpy_ne_D/2; ++i1) {
+ VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size) + i1].x() *= scale;
+ VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size) + i1].y() *= scale;
+ }
+ ggml_sycl_memcpy_1<cpy_ne_D*4>(
+ &dst[j_dst_unrolled*DV + i0 + item_ct1.get_local_id(2)*cpy_ne_D],
+ &VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size)]);
+ }
+ }
+#endif // SYCL_FAST_FP16
+
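+        // With more than one parallel KV block the output is only a partial result;
+        // store the per-block (KQ max, KQ sum) so the partials can be combined in a
+        // follow-up reduction pass.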
+ if (item_ct1.get_group_range(1) != 1 && item_ct1.get_local_id(2) == 0) {
+ dst_meta[j_dst_unrolled] = make_float2(KQ_max[jc0], KQ_sum[jc0]);
+ }
+ }
+#else
+ GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
+ max_bias, m0, m1, n_head_log2, logit_softcap,
+ ne00, ne01, ne02, ne03,
+ nb01, nb02, nb03,
+ ne10, ne11, ne12, ne13,
+ nb11, nb12, nb13,
+ nb21, nb22, nb23,
+ ne31, ne32, ne33,
+ nb31, nb32, nb33);
+#endif // SYCL_FLASH_ATTN
+}
+
+template <int DKQ, int DV, int ncols2, bool use_logit_softcap>
+static void launch_fattn_tile_switch_ncols1(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * Q = dst->src[0];
+
+ const int id = ggml_sycl_get_device();
+ const int cc = ggml_sycl_info().devices[id].cc;
+    const int warp_size = WARP_32_SIZE; // WARP_16_SIZE is not supported by this kernel
+
+ constexpr size_t nbytes_shared = 0;
+
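+    // Pick the widest cols_per_block that the batch (Q->ne[1]) can still fill and fall
+    // through to progressively narrower instantiations for smaller batches: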
+ if constexpr (DV <= 256) {
+ if (Q->ne[1] > 16/ncols2) {
+ constexpr int cols_per_block = 32;
+ const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
+ const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
+ launch_fattn<DV, cols_per_block/ncols2, ncols2,
+ flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
+ (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
+ return;
+ }
+ }
+
+ if (Q->ne[1] > 8/ncols2) {
+ constexpr int cols_per_block = 16;
+ const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
+ const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
+ launch_fattn<DV, cols_per_block/ncols2, ncols2,
+ flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
+ (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
+ return;
+ }
+
+ if constexpr (ncols2 <= 8) {
+ if (Q->ne[1] > 4/ncols2) {
+ constexpr int cols_per_block = 8;
+ const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
+ const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
+ launch_fattn<DV, cols_per_block/ncols2, ncols2,
+ flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
+ (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
+ return;
+ }
+ }
+
+ if constexpr (ncols2 <= 4) {
+ if (Q->ne[1] > 2/ncols2) {
+ constexpr int cols_per_block = 4;
+ const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
+ const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
+ launch_fattn<DV, cols_per_block/ncols2, ncols2,
+ flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
+ (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
+ return;
+ }
+ }
+
+ if constexpr (ncols2 <= 2) {
+ constexpr int cols_per_block = 2;
+ const int nwarps = ggml_sycl_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size;
+ const int nbatch_fa = ggml_sycl_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc);
+ launch_fattn<DV, cols_per_block/ncols2, ncols2,
+ flash_attn_tile<DKQ, DV, cols_per_block / ncols2, ncols2, use_logit_softcap, warp_size>, warp_size>
+ (ctx, dst, nwarps, nbytes_shared, nbatch_fa, true, true, false);
+ return;
+ }
+
+ GGML_ABORT("fatal error");
+}
+
+template <int DKQ, int DV, bool use_logit_softcap>
+static void launch_fattn_tile_switch_ncols2(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * KQV = dst;
+ const ggml_tensor * Q = dst->src[0];
+ const ggml_tensor * K = dst->src[1];
+ const ggml_tensor * mask = dst->src[3];
+
+ float max_bias = 0.0f;
+ memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
+
+ GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
+ const int gqa_ratio = Q->ne[2] / K->ne[2];
+
+ // On NVIDIA (Pascal and older) the GQA optimizations seem to be detrimental in some cases.
+ // However, for DKQ == 576, DV == 512 only the kernel variant with GQA optimizations is implemented.
+ //const bool nvidia = GGML_SYCL_CC_IS_NVIDIA(ggml_sycl_info().devices[ggml_sycl_get_device()].cc);
+ const int gqa_limit = gqa_ratio <= 4 && DV <= 256 ? 16 : INT_MAX;
+ const bool use_gqa_opt = mask && max_bias == 0.0f && Q->ne[1] <= gqa_limit && K->ne[1] % FATTN_KQ_STRIDE == 0;
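+    // Example: 32 Q heads attending over 8 KV heads give gqa_ratio == 4, so up to
+    // ncols2 == 4 Q heads can share each K/V tile loaded from memory.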
+
+ if constexpr (DV == 512) {
+ if (use_gqa_opt && gqa_ratio % 16 == 0) {
+ launch_fattn_tile_switch_ncols1<DKQ, DV, 16, use_logit_softcap>(ctx, dst);
+ return;
+ }
+ if (use_gqa_opt && gqa_ratio % 4 == 0) {
+ launch_fattn_tile_switch_ncols1<DKQ, DV, 4, use_logit_softcap>(ctx, dst);
+ return;
+ }
+ }
+
+ if constexpr (DV <= 256) {
+ if (use_gqa_opt && gqa_ratio % 8 == 0) {
+ launch_fattn_tile_switch_ncols1<DKQ, DV, 8, use_logit_softcap>(ctx, dst);
+ return;
+ }
+
+ if (use_gqa_opt && gqa_ratio % 4 == 0) {
+ launch_fattn_tile_switch_ncols1<DKQ, DV, 4, use_logit_softcap>(ctx, dst);
+ return;
+ }
+
+ if (use_gqa_opt && gqa_ratio % 2 == 0) {
+ launch_fattn_tile_switch_ncols1<DKQ, DV, 2, use_logit_softcap>(ctx, dst);
+ return;
+ }
+
+ launch_fattn_tile_switch_ncols1<DKQ, DV, 1, use_logit_softcap>(ctx, dst);
+ return;
+ }
+ GGML_ABORT("fatal error");
+}
+
+template <int DKQ, int DV>
+void ggml_sycl_flash_attn_ext_tile_case(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * KQV = dst;
+
+ float logit_softcap;
+ memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
+ if (logit_softcap == 0.0f) {
+ constexpr bool use_logit_softcap = false;
+ launch_fattn_tile_switch_ncols2<DKQ, DV, use_logit_softcap>(ctx, dst);
+ } else {
+ constexpr bool use_logit_softcap = true;
+ launch_fattn_tile_switch_ncols2<DKQ, DV, use_logit_softcap>(ctx, dst);
+ }
+}
+
+void ggml_sycl_flash_attn_ext_tile(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+
+#define DECL_FATTN_TILE_CASE(DKQ, DV) \
+ template void ggml_sycl_flash_attn_ext_tile_case \
+ <DKQ, DV>(ggml_backend_sycl_context & ctx, ggml_tensor * dst) \
+
+extern DECL_FATTN_TILE_CASE( 40, 40);
+extern DECL_FATTN_TILE_CASE( 64, 64);
+extern DECL_FATTN_TILE_CASE( 72, 72);
+extern DECL_FATTN_TILE_CASE( 80, 80);
+extern DECL_FATTN_TILE_CASE( 96, 96);
+extern DECL_FATTN_TILE_CASE(112, 112);
+extern DECL_FATTN_TILE_CASE(128, 128);
+extern DECL_FATTN_TILE_CASE(256, 256);
+extern DECL_FATTN_TILE_CASE(576, 512);
+
--- /dev/null
+#ifndef GGML_SYCL_FATTN_VEC_HPP
+#define GGML_SYCL_FATTN_VEC_HPP
+
+#include <sycl/sycl.hpp>
+#include <sycl/ext/oneapi/work_group_static.hpp>
+#include <iostream>
+#include <iomanip>
+
+#include "dpct/helper.hpp"
+#include "common.hpp"
+#include "ggml.h"
+#include "fattn-common.hpp"
+#include <cmath>
+#include <float.h>
+
+namespace syclex = sycl::ext::oneapi::experimental;
+
+static int ggml_sycl_fattn_vec_get_nthreads_host(const int cc) {
+ return 128;
+ GGML_UNUSED(cc);
+}
+
+static constexpr int ggml_sycl_fattn_vec_get_nthreads_device() {
+ return 128;
+}
+
+// Currently, LLVM with the amdgcn target does not support unrolling loops
+// that contain a break that cannot be resolved at compile time.
+#ifdef __clang__
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wpass-failed"
+#endif // __clang__
+
+template <int D,
+ int ncols,
+ int type_K,
+ int type_V,
+ bool use_logit_softcap,
+ int warp_size> // D == head size
+static void flash_attn_ext_vec(const char* __restrict__ Q,
+ const char* __restrict__ K,
+ const char* __restrict__ V,
+ const char* __restrict__ mask,
+ const char* __restrict__ sinks,
+ const int* __restrict__ KV_max,
+ float* __restrict__ dst,
+ sycl::float2* __restrict__ dst_meta,
+ const float scale,
+ const float max_bias,
+ const float m0,
+ const float m1,
+ const uint32_t n_head_log2,
+ const float logit_softcap,
+ const int32_t ne00,
+ const sycl::uint3 ne01,
+ const int32_t ne02,
+ const int32_t ne03,
+ const int32_t nb01,
+ const int32_t nb02,
+ const int32_t nb03,
+ const int32_t ne10,
+ const int32_t ne11,
+ const int32_t ne12,
+ const int32_t ne13,
+ const int32_t nb11,
+ const int32_t nb12,
+ const int64_t nb13,
+ const int32_t nb21,
+ const int32_t nb22,
+ const int64_t nb23,
+ const int32_t ne31,
+ const int32_t ne32,
+ const int32_t ne33,
+ const int32_t nb31,
+ const int32_t nb32,
+ const int64_t nb33) {
+#ifdef SYCL_FLASH_ATTN
+ // Skip unused kernel variants for faster compilation:
+
+ auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+ if (use_logit_softcap && !(D == 128 || D == 256)) {
+ GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
+ max_bias, m0, m1, n_head_log2, logit_softcap,
+ ne00, ne01, ne02, ne03,
+ nb01, nb02, nb03,
+ ne10, ne11, ne12, ne13,
+ nb11, nb12, nb13,
+ nb21, nb22, nb23,
+ ne31, ne32, ne33,
+ nb31, nb32, nb33);
+ return;
+ }
+
+ //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
+
+ constexpr int cpy_nb = ggml_sycl_get_max_cpy_bytes();
+ constexpr int cpy_ne = cpy_nb / 4;
+
+ constexpr int nthreads_KQ_q = (D/4 < warp_size ? D/4 : warp_size);
+ constexpr int nthreads_V_q = (D/4 < warp_size ? D/4 : warp_size);
+
+ constexpr int nthreads = ggml_sycl_fattn_vec_get_nthreads_device();
+ constexpr int nthreads_KQ = type_K == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_KQ_q;
+ constexpr int nthreads_V = type_V == GGML_TYPE_F16 ? 128 / cpy_nb : nthreads_V_q;
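+    // nthreads_KQ lanes cooperate on each K-row dot product and nthreads_V lanes on
+    // each V row; e.g. with cpy_nb == 16, f16 K/V tensors use 128/16 == 8 lanes per row.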
+
+ static_assert(warp_size % nthreads_KQ == 0, "bad nthreads_K");
+ static_assert(warp_size % nthreads_V == 0, "bad nthreads_V");
+
+ constexpr int V_rows_per_thread = type_V == GGML_TYPE_F16 ? 2*cpy_ne : 4;
+ constexpr int V_cols_per_iter = warp_size / nthreads_V;
+
+ constexpr vec_dot_KQ_t vec_dot_KQ = get_vec_dot_KQ<type_K, D, nthreads_KQ, warp_size>();
+ constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
+#ifdef GGML_SYCL_F16
+ constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, sycl::half, V_rows_per_thread>();
+#else
+ constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, float, V_rows_per_thread>();
+#endif // GGML_SYCL_F16
+
+ const int ic0 = item_ct1.get_group(2) * ncols; // Index of the Q/QKV column to work on.
+
+ const int sequence = item_ct1.get_group(0) / ne02;
+ const int head = item_ct1.get_group(0) - sequence * ne02;
+ const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
+ Q += nb03*sequence + nb02* head + nb01*ic0;
+ K += nb13*sequence + nb12*(head / gqa_ratio);
+ V += nb23*sequence + nb22*(head / gqa_ratio);
+
+ const sycl::half * maskh = (const sycl::half *) (mask + nb33 * (sequence % ne33) + nb31 * ic0);
+
+ const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
+
+    static_assert(D % (2*warp_size) == 0, "D not divisible by 2*warp_size.");
+ constexpr int nwarps = nthreads / warp_size;
+ const int tid = warp_size * item_ct1.get_local_id(1) + item_ct1.get_local_id(2);
+ __builtin_assume(tid < nthreads);
+
+ constexpr int ne_KQ = ncols*D;
+ constexpr int ne_combine = nwarps*V_cols_per_iter*D;
+
+ constexpr size_t lsm_size1 = ncols * warp_size;
+ constexpr size_t lsm_size2 = ncols * warp_size;
+#ifdef GGML_SYCL_F16
+ sycl::half2 VKQ[ncols][(D / 2) / nthreads_V] = { { { 0.0f, 0.0f } } };
+ constexpr size_t lsm_size3 = (ne_KQ > ne_combine ? ne_KQ : ne_combine);
+ constexpr size_t local_share_mem_size = (lsm_size1 + lsm_size2)*sizeof(float) + lsm_size3*sizeof(sycl::half);
+
+ syclex::work_group_static<char[local_share_mem_size]> lsm;
+
+    float *      KQ_max_shared = (float *) &lsm;
+    float *      KQ_sum_shared = KQ_max_shared + lsm_size1;
+    sycl::half * KQ            = (sycl::half *) (KQ_sum_shared + lsm_size2);
+
+#else
+ sycl::float2 VKQ[ncols][(D/2)/nthreads_V] = {{{0.0f, 0.0f}}};
+
+ constexpr size_t lsm_size3 = (ne_KQ > ne_combine ? ne_KQ : ne_combine);
+    constexpr size_t local_share_mem_size = (lsm_size1 + lsm_size2 + lsm_size3) * sizeof(float);
+
+    syclex::work_group_static<char[local_share_mem_size]> lsm;
+
+    float * KQ_max_shared = (float *) &lsm;
+    float * KQ_sum_shared = KQ_max_shared + lsm_size1;
+    float * KQ            = KQ_sum_shared + lsm_size2;
+
+#endif // GGML_SYCL_F16
+
+ float KQ_max[ncols];
+ float KQ_sum[ncols];
+#pragma unroll
+ for (int j = 0; j < ncols; ++j) {
+ KQ_max[j] = -FLT_MAX/2.0f;
+ KQ_sum[j] = 0.0f;
+ }
+
+ // Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers:
+#ifdef GGML_SYCL_F16
+ sycl::half2 Q_reg[ncols][(D / 2) / nthreads_KQ] = {{{0.0f, 0.0f}}}; // Will be initialized completely.
+#else
+ sycl::float2 Q_reg[ncols][(D/2)/nthreads_KQ] = {{{0.0f, 0.0f}}}; // May be only partially initialized.
+#endif // GGML_SYCL_F16
+ int Q_i32[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)];
+ sycl::float2 Q_ds[ncols][1 > D / (sizeof(int) * nthreads_KQ) ? 1 : D / (sizeof(int) * nthreads_KQ)];
+ if constexpr (Q_q8_1) {
+#pragma unroll
+ for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+ const int j = j0 + item_ct1.get_local_id(1);
+
+ if (j0 + nwarps > ncols && j >= ncols) {
+ break;
+ }
+
+ // Reuse KQ as temporary storage for converting Q to q8_1:
+ int * tmp_q_i32 = (int *) &KQ[j*D];
+ sycl::float2 * tmp_q_ds = (sycl::float2 *) (tmp_q_i32 + D / sizeof(int));
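+            // Shared layout per column j: D/sizeof(int) packed q8_1 values followed by
+            // one (scale, sum) float2 per QK8_1 quantized values.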
+
+ // Set memory to zero if out of bounds:
+ if (ncols > 1 && ic0 + j >= int(ne01.z())) {
+#pragma unroll
+ for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += warp_size) {
+ const int i = i0 + item_ct1.get_local_id(2);
+
+ if (i0 + warp_size <= int(D/sizeof(int)) || i < int(D/sizeof(int))) {
+ tmp_q_i32[i] = 0;
+ }
+ }
+ if (item_ct1.get_local_id(2) < D/QK8_1) {
+ tmp_q_ds[item_ct1.get_local_id(2)] = sycl::float2(0.0f, 0.0f);
+ }
+ } else {
+ const float * Q_f = (const float *) (Q + j*nb01);
+ constexpr int nthreads_quantize = D/sizeof(int) < warp_size ? D/sizeof(int) : warp_size;
+#pragma unroll
+ for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += nthreads_quantize) {
+ quantize_q8_1_to_shared<sycl::float2, nthreads_quantize, warp_size>
+ (Q_f + i0*sizeof(int), scale, tmp_q_i32 + i0, tmp_q_ds + i0/QI8_1);
+ }
+ }
+ }
+
+
+ item_ct1.barrier(sycl::access::fence_space::local_space);
+
+#pragma unroll
+ for (int j = 0; j < ncols; ++j) {
+ int * tmp_q_i32 = (int *) &KQ[j*D];
+ sycl::float2 * tmp_q_ds = (sycl::float2 *) (tmp_q_i32 + D / sizeof(int));
+
+#pragma unroll
+ for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += nthreads_KQ) {
+ const int i =
+ i0 + (nthreads_KQ == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads_KQ);
+
+ Q_i32[j][i0/nthreads_KQ] = tmp_q_i32[i];
+ Q_ds[j][i0/nthreads_KQ] = tmp_q_ds[i/QI8_1];
+ }
+ }
+
+ item_ct1.barrier(sycl::access::fence_space::local_space);
+
+ } else {
+#ifdef GGML_SYCL_F16
+ const sycl::half2 scale_h2 = sycl::half2(scale, scale);
+#pragma unroll
+ for (int j = 0; j < ncols; ++j) {
+ const sycl::float2 * Q_j = (const sycl::float2 *) (Q + j * nb01);
+#pragma unroll
+ for (int i0 = 0; i0 < D/2; i0 += nthreads_KQ*cpy_ne) {
+ const int i = i0 + (nthreads_KQ == warp_size ? item_ct1.get_local_id(2) :
+ item_ct1.get_local_id(2) % nthreads_KQ) *
+ cpy_ne;
+
+ sycl::float2 tmp[cpy_ne] = {
+ { 0.0f, 0.0f }
+ };
+ if (ncols == 1 || ic0 + j < int(ne01.z())) {
+ ggml_sycl_memcpy_1<cpy_nb>(tmp, &Q_j[i]);
+ ggml_sycl_memcpy_1<cpy_nb>(tmp + cpy_ne/2, &Q_j[i + cpy_ne/2]);
+ }
+#pragma unroll
+ for (int i1 = 0; i1 < cpy_ne; ++i1) {
+ Q_reg[j][i0 / nthreads_KQ + i1] = sycl::half2(tmp[i1].x(), tmp[i1].y());
+ }
+ }
+#pragma unroll
+ for (int k = 0; k < (D/2)/nthreads_KQ; ++k) {
+ Q_reg[j][k] *= scale_h2;
+ }
+ }
+#else
+#pragma unroll
+ for (int j = 0; j < ncols; ++j) {
+ const sycl::float2 * Q_j = (const sycl::float2 *) (Q + j*nb01);
+#pragma unroll
+ for (int i0 = 0; i0 < D/2; i0 += nthreads_KQ*cpy_ne) {
+ const int i = i0 + (nthreads_KQ == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads_KQ)*cpy_ne;
+ if (ncols == 1 || ic0 + j < int(ne01.z())) {
+ ggml_sycl_memcpy_1<cpy_nb>(&Q_reg[j][i0/nthreads_KQ], &Q_j[i]);
+ ggml_sycl_memcpy_1<cpy_nb>(&Q_reg[j][i0/nthreads_KQ + cpy_ne/2], &Q_j[i + cpy_ne/2]);
+ }
+ }
+#pragma unroll
+ for (int k = 0; k < (D/2)/nthreads_KQ; ++k) {
+ Q_reg[j][k].x() *= scale;
+ Q_reg[j][k].y() *= scale;
+ }
+ }
+#endif // GGML_SYCL_F16
+ }
+
+ const int k_VKQ_max = KV_max ? KV_max[sequence * item_ct1.get_group_range(2) + item_ct1.get_group(2)] : ne11;
+ K += item_ct1.get_group(1) * nthreads * nb11;
+ V += item_ct1.get_group(1) * nthreads * nb21;
+ maskh += item_ct1.get_group(1) * nthreads;
+ for (int k_VKQ_0 = item_ct1.get_group(1) * nthreads; k_VKQ_0 < k_VKQ_max;
+ k_VKQ_0 += item_ct1.get_group_range(1) * nthreads,
+ // Increment pointers after each loop:
+ K += item_ct1.get_group_range(1) * nthreads * nb11, V += item_ct1.get_group_range(1) * nthreads * nb21,
+ maskh += item_ct1.get_group_range(1) * nthreads) {
+ // Calculate KQ tile and keep track of new maximum KQ values:
+        float KQ_reg[ncols]     = {}; // KQ in registers.
+        float KQ_max_new[ncols] = {};
+
+#pragma unroll
+ for (int j = 0; j < ncols; ++j) {
+ KQ_max_new[j] = KQ_max[j];
+ }
+
+#pragma unroll
+ for (int i_KQ_0 = 0; i_KQ_0 < nthreads_KQ; ++i_KQ_0) {
+ const int i_KQ = item_ct1.get_local_id(1) * warp_size +
+ (nthreads_KQ == warp_size ? 0 : (item_ct1.get_local_id(2) & ~(nthreads_KQ - 1))) + i_KQ_0;
+
+#pragma unroll
+ for (int j = 0; j < ncols; ++j) {
+ float sum = vec_dot_KQ(K + i_KQ*nb11, Q_reg[j], Q_i32[j], Q_ds[j]);
+ sum = warp_reduce_sum<nthreads_KQ>(sum);
+
+ if (use_logit_softcap) {
+ sum = logit_softcap * sycl::tanh(sum);
+ }
+ if (mask) {
+ sum += slope * sycl::vec<sycl::half, 1>(maskh[j * ne11 + i_KQ])
+ .convert<float, sycl::rounding_mode::automatic>()[0];
+ }
+
+ KQ_max_new[j] = sycl::fmax((float) KQ_max_new[j], sum);
+
+                if (int(nthreads_KQ == warp_size ? item_ct1.get_local_id(2)
+                                                 : item_ct1.get_local_id(2) % nthreads_KQ) == i_KQ_0) {
+ KQ_reg[j] = sum;
+ }
+ }
+ }
+
+#pragma unroll
+ for (int j = 0; j < ncols; ++j) {
+#pragma unroll
+ for (int offset = nthreads_KQ; offset < warp_size; offset <<= 1) {
+ KQ_max_new[j] = sycl::fmax(
+ (float)KQ_max_new[j],
+ (float)dpct::permute_sub_group_by_xor(
+ sycl::ext::oneapi::this_work_item::get_sub_group(),
+ KQ_max_new[j],
+ offset,
+ warp_size));
+ }
+ const float KQ_max_scale = sycl::native::exp((float) (KQ_max[j] - KQ_max_new[j]));
+ KQ_max[j] = KQ_max_new[j];
+
+ KQ_reg[j] = sycl::native::exp((float) (KQ_reg[j] - KQ_max[j]));
+ KQ_sum[j] = KQ_sum[j]*KQ_max_scale + KQ_reg[j];
+ KQ[j*nthreads + tid] = KQ_reg[j];
+
+#ifdef GGML_SYCL_F16
+ const sycl::half2 KQ_max_scale_h2 = sycl::half2(KQ_max_scale, KQ_max_scale);
+#pragma unroll
+ for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
+ VKQ[j][i_VKQ_0/nthreads_V] *= KQ_max_scale_h2;
+ }
+#else
+#pragma unroll
+ for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
+ VKQ[j][i_VKQ_0/nthreads_V].x() *= KQ_max_scale;
+ VKQ[j][i_VKQ_0/nthreads_V].y() *= KQ_max_scale;
+ }
+#endif // GGML_SYCL_F16
+ }
+
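+        // Each warp only reads back the KQ slice it wrote itself (k stays within
+        // local_id(1)*warp_size .. +warp_size), so a sub-group barrier is sufficient.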
+ sycl::group_barrier(sycl::ext::oneapi::this_work_item::get_sub_group());
+
+#pragma unroll
+ for (int k0 = 0; k0 < warp_size; k0 += V_cols_per_iter) {
+ const int k = item_ct1.get_local_id(1) * warp_size + k0 +
+ (nthreads_V == warp_size ? 0 : item_ct1.get_local_id(2) / nthreads_V);
+
+#ifdef GGML_SYCL_F16
+ sycl::half2 KQ_k[ncols];
+#pragma unroll
+ for (int j = 0; j < ncols; ++j) {
+ KQ_k[j] = sycl::half2(KQ[j * nthreads + k]);
+ }
+#pragma unroll
+ for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
+ sycl::half2 tmp[V_rows_per_thread / 2];
+ dequantize_V(V + k * nb21, tmp,
+ 2 * i_VKQ_0 + (nthreads_V == warp_size ? item_ct1.get_local_id(2) :
+ item_ct1.get_local_id(2) % nthreads_V) *
+ V_rows_per_thread);
+#pragma unroll
+ for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) {
+#pragma unroll
+ for (int j = 0; j < ncols; ++j) {
+ VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1] += tmp[i_VKQ_1]*KQ_k[j];
+ }
+ }
+ }
+#else
+ float KQ_k[ncols];
+#pragma unroll
+ for (int j = 0; j < ncols; ++j) {
+ KQ_k[j] = KQ[j*nthreads + k];
+ }
+#pragma unroll
+ for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
+ sycl::float2 tmp[V_rows_per_thread/2];
+ dequantize_V(V + k*nb21, tmp,
+ 2*i_VKQ_0 + (nthreads_V == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads_V)*V_rows_per_thread);
+#pragma unroll
+ for (int i_VKQ_1 = 0; i_VKQ_1 < V_rows_per_thread/2; ++i_VKQ_1) {
+#pragma unroll
+ for (int j = 0; j < ncols; ++j) {
+ VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1].x() += tmp[i_VKQ_1].x()*KQ_k[j];
+ VKQ[j][i_VKQ_0/nthreads_V + i_VKQ_1].y() += tmp[i_VKQ_1].y()*KQ_k[j];
+ }
+ }
+ }
+#endif // GGML_SYCL_F16
+ }
+ }
+
+ if (sinks && item_ct1.get_group(1) == 0) {
+ const float sink = ((const float *) sinks)[head];
+
+#pragma unroll
+ for (int j0 = 0; j0 < ncols; j0 += nwarps) {
+ const int j = j0 + item_ct1.get_local_id(1);
+
+ if (j0 + nwarps > ncols && j >= ncols) {
+ break;
+ }
+ const float kqmax_new_j = sycl::fmax(sink, (float) KQ_max[j]);
+ const float KQ_max_scale = sycl::native::exp((float) (KQ_max[j] - kqmax_new_j));
+ KQ_max[j] = kqmax_new_j;
+
+ KQ_sum[j] = KQ_sum[j] * KQ_max_scale +
+ (item_ct1.get_local_id(2) == 0 ? sycl::native::exp((float) (sink - KQ_max[j])) : 0.0f);
+#ifdef GGML_SYCL_F16
+ const sycl::half2 KQ_max_scale_h2 = sycl::half2(KQ_max_scale, KQ_max_scale);
+#pragma unroll
+ for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
+ VKQ[j][i_VKQ_0/nthreads_V] *= KQ_max_scale_h2;
+ }
+#else
+#pragma unroll
+ for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
+ VKQ[j][i_VKQ_0/nthreads_V].x() *= KQ_max_scale;
+ VKQ[j][i_VKQ_0/nthreads_V].y() *= KQ_max_scale;
+ }
+#endif // GGML_SYCL_F16
+ }
+ }
+
+#pragma unroll
+ for (int j = 0; j < ncols; ++j) {
+ if (item_ct1.get_local_id(1) == 0) {
+ KQ_max_shared[j*warp_size+item_ct1.get_local_id(2)] = -FLT_MAX / 2.0f;
+ KQ_sum_shared[j*warp_size+item_ct1.get_local_id(2)] = 0.0f;
+ }
+ }
+
+ item_ct1.barrier(sycl::access::fence_space::local_space);
+
+#pragma unroll
+ for (int j = 0; j < ncols; ++j) {
+ if (item_ct1.get_local_id(2) == 0) {
+ KQ_max_shared[j*warp_size+item_ct1.get_local_id(1)] = KQ_max[j];
+ }
+ }
+
+
+ item_ct1.barrier(sycl::access::fence_space::local_space);
+
+#pragma unroll
+ for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
+ if (ncols > 1 && ic0 + j_VKQ >= int(ne01.z())) {
+ break;
+ }
+
+ float kqmax_new = KQ_max_shared[j_VKQ*warp_size+item_ct1.get_local_id(2)];
+ kqmax_new = warp_reduce_max<warp_size>(kqmax_new);
+ const float kqmax_scale = sycl::native::exp((float) (KQ_max[j_VKQ] - kqmax_new));
+ KQ_max[j_VKQ] = kqmax_new;
+
+#ifdef GGML_SYCL_F16
+ sycl::half2 * VKQ_tmp = (sycl::half2 *) KQ + item_ct1.get_local_id(1) * (V_cols_per_iter * D / 2) +
+ (nthreads_V == warp_size ? 0 : item_ct1.get_local_id(2) / nthreads_V) * (D / 2);
+
+ const sycl::half2 kqmax_scale_h2 = sycl::half2(kqmax_scale, kqmax_scale);
+#pragma unroll
+ for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
+ VKQ[j_VKQ][i_VKQ_0/nthreads_V] *= kqmax_scale_h2;
+ }
+#pragma unroll
+ for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
+ const int i_VKQ =
+ i_VKQ_0 + (nthreads_V == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads_V) *
+ (V_rows_per_thread / 2);
+
+ ggml_sycl_memcpy_1<V_rows_per_thread * sizeof(sycl::half)>(VKQ_tmp + i_VKQ,
+ &VKQ[j_VKQ][i_VKQ_0 / nthreads_V]);
+ }
+#else
+ sycl::float2 * VKQ_tmp = (sycl::float2 *) KQ + item_ct1.get_local_id(1)*(V_cols_per_iter*D/2)
+ + (nthreads_V == warp_size ? 0 : item_ct1.get_local_id(2) / nthreads_V)*(D/2);
+#pragma unroll
+ for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
+ VKQ[j_VKQ][i_VKQ_0/nthreads_V].x() *= kqmax_scale;
+ VKQ[j_VKQ][i_VKQ_0/nthreads_V].y() *= kqmax_scale;
+ }
+#pragma unroll
+ for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V*V_rows_per_thread/2) {
+ const int i_VKQ = i_VKQ_0 + (nthreads_V == warp_size ? item_ct1.get_local_id(2) : item_ct1.get_local_id(2) % nthreads_V)*(V_rows_per_thread/2);
+
+ ggml_sycl_memcpy_1<V_rows_per_thread/2*sizeof(float)>(VKQ_tmp + i_VKQ, &VKQ[j_VKQ][i_VKQ_0/nthreads_V]);
+ ggml_sycl_memcpy_1<V_rows_per_thread/2*sizeof(float)>(VKQ_tmp + i_VKQ + V_rows_per_thread/4, &VKQ[j_VKQ][i_VKQ_0/nthreads_V + V_rows_per_thread/4]);
+ }
+#endif // GGML_SYCL_F16
+
+ KQ_sum[j_VKQ] *= kqmax_scale;
+ KQ_sum[j_VKQ] = warp_reduce_sum<warp_size>(KQ_sum[j_VKQ]);
+ if (item_ct1.get_local_id(2) == 0) {
+ KQ_sum_shared[j_VKQ*warp_size+item_ct1.get_local_id(1)] = KQ_sum[j_VKQ];
+ }
+
+ item_ct1.barrier(sycl::access::fence_space::local_space);
+
+
+ if (nthreads <= D || tid < D) {
+ KQ_sum[j_VKQ] = KQ_sum_shared[j_VKQ*warp_size+item_ct1.get_local_id(2)];
+ KQ_sum[j_VKQ] = warp_reduce_sum<warp_size>(KQ_sum[j_VKQ]);
+
+#pragma unroll
+ for (int i0 = 0; i0 < D; i0 += nthreads) {
+ float dst_val = 0;
+#pragma unroll
+ for (int w = 0; w < nwarps; ++w) {
+#pragma unroll
+ for (int v = 0; v < V_cols_per_iter; ++v) {
+ dst_val += float(KQ[w*V_cols_per_iter*D + v*D + i0 + tid]);
+ }
+ }
+ if (item_ct1.get_group_range(1) == 1) {
+ dst_val /= KQ_sum[j_VKQ];
+ }
+ dst[(((sequence * int(ne01.z()) + ic0 + j_VKQ) * ne02 + head) * item_ct1.get_group_range(1) +
+ item_ct1.get_group(1)) *
+ D +
+ i0 + tid] = dst_val;
+ }
+ }
+
+ if (j_VKQ < ncols-1) {
+ item_ct1.barrier(sycl::access::fence_space::local_space);
+ }
+
+ }
+
+ if (item_ct1.get_group_range(1) != 1 && tid < ncols && (ncols == 1 || ic0 + tid < int(ne01.z()))) {
+ dst_meta[((sequence * int(ne01.z()) + ic0 + tid) * ne02 + head) * item_ct1.get_group_range(1) +
+ item_ct1.get_group(1)] = make_float2(KQ_max[tid], KQ_sum[tid]);
+ }
+#else
+ GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
+ max_bias, m0, m1, n_head_log2, logit_softcap,
+ ne00, ne01, ne02, ne03,
+ nb01, nb02, nb03,
+ ne10, ne11, ne12, ne13,
+ nb11, nb12, nb13,
+ nb21, nb22, nb23,
+ ne31, ne32, ne33,
+ nb31, nb32, nb33);
+
+#endif // SYCL_FLASH_ATTN
+}
+#ifdef __clang__
+#pragma clang diagnostic pop
+#endif // __clang__
+
+
+template <int D, int cols_per_block, int type_K, int type_V, bool use_logit_softcap>
+void ggml_sycl_flash_attn_ext_vec_case_impl(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+
+    const int warp_size = WARP_16_SIZE; // gives better performance than WARP_32_SIZE here
+
+ const int cc = ggml_sycl_info().devices[ggml_sycl_get_device()].cc;
+
+ const int nthreads = ggml_sycl_fattn_vec_get_nthreads_host(cc);
+ const int nwarps = nthreads / warp_size;
+
+ const bool need_f16_K = type_K == GGML_TYPE_F16;
+ const bool need_f16_V = type_V == GGML_TYPE_F16;
+ constexpr size_t nbytes_shared = 0;
+
+ launch_fattn<D, cols_per_block, 1,
+ flash_attn_ext_vec<D, cols_per_block, type_K, type_V,
+ use_logit_softcap, warp_size>, warp_size>(
+ ctx, dst, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
+}
+
+template <int D, int type_K, int type_V>
+void ggml_sycl_flash_attn_ext_vec_case(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * KQV = dst;
+ const ggml_tensor * Q = dst->src[0];
+
+ float logit_softcap;
+ memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
+ if (Q->ne[1] == 1) {
+ constexpr int cols_per_block = 1;
+ if (logit_softcap == 0.0f) {
+ constexpr bool use_logit_softcap = false;
+ ggml_sycl_flash_attn_ext_vec_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
+ } else {
+ constexpr bool use_logit_softcap = true;
+ ggml_sycl_flash_attn_ext_vec_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
+ }
+ return;
+ }
+
+ constexpr int cols_per_block = 2;
+ if (logit_softcap == 0.0f) {
+ constexpr bool use_logit_softcap = false;
+ ggml_sycl_flash_attn_ext_vec_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
+ } else {
+ constexpr bool use_logit_softcap = true;
+ ggml_sycl_flash_attn_ext_vec_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
+ }
+}
+
+#define DECL_FATTN_VEC_CASE(D, type_K, type_V) \
+ template void ggml_sycl_flash_attn_ext_vec_case \
+ <D, type_K, type_V>(ggml_backend_sycl_context & ctx, ggml_tensor * dst) \
+
+#define EXTERN_DECL_FATTN_VEC_CASES(D, type_K) \
+ extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_F16); \
+ extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q4_0); \
+ extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q4_1); \
+ extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_0); \
+ extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q5_1); \
+ extern DECL_FATTN_VEC_CASE(D, type_K, GGML_TYPE_Q8_0); \
+
+EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_F16)
+EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_0)
+EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q4_1)
+EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_0)
+EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q5_1)
+EXTERN_DECL_FATTN_VEC_CASES( 64, GGML_TYPE_Q8_0)
+
+EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_F16)
+EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_0)
+EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q4_1)
+EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_0)
+EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q5_1)
+EXTERN_DECL_FATTN_VEC_CASES(128, GGML_TYPE_Q8_0)
+
+EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_F16)
+EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_0)
+EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q4_1)
+EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_0)
+EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q5_1)
+EXTERN_DECL_FATTN_VEC_CASES(256, GGML_TYPE_Q8_0)
+
+#endif // GGML_SYCL_FATTN_VEC_HPP
--- /dev/null
+//
+// MIT license
+// Copyright (C) 2025 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+
+#include <sycl/sycl.hpp>
+#include "dpct/helper.hpp"
+#include "common.hpp"
+#include "fattn-common.hpp"
+#include "fattn-tile.hpp"
+#include "fattn-vec.hpp"
+#include "fattn.hpp"
+
+
+#define FATTN_VEC_CASE(D, type_K, type_V) \
+ { \
+ const bool type_K_okay = K->type == (type_K) || (K->type == GGML_TYPE_F32 && (type_K) == GGML_TYPE_F16); \
+ const bool type_V_okay = V->type == (type_V) || (V->type == GGML_TYPE_F32 && (type_V) == GGML_TYPE_F16); \
+ if (Q->ne[0] == (D) && type_K_okay && type_V_okay) { \
+ ggml_sycl_flash_attn_ext_vec_case<D, type_K, type_V>(ctx, dst); \
+ return; \
+ } \
+ } \
+
+#define FATTN_VEC_CASES_ALL_D(type_K, type_V) \
+ FATTN_VEC_CASE( 64, type_K, type_V) \
+ FATTN_VEC_CASE(128, type_K, type_V) \
+ FATTN_VEC_CASE(256, type_K, type_V) \
+
+static void ggml_sycl_flash_attn_ext_vec(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+ ggml_tensor * Q = dst->src[0];
+ ggml_tensor * K = dst->src[1];
+ ggml_tensor * V = dst->src[2];
+
+#ifdef GGML_SYCL_FA_ALL_QUANTS
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_F16)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_F16)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_F16)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_F16)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_F16)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_F16)
+
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q4_0)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q4_0)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)
+
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q4_1)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q4_1)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)
+
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q5_0)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
+
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q5_1)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
+
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_Q8_0)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_1, GGML_TYPE_Q8_0)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_0, GGML_TYPE_Q8_0)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
+#else
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_F16, GGML_TYPE_F16)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
+ FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
+#endif // GGML_SYCL_FA_ALL_QUANTS
+
+    GGML_ABORT("no matching K/V type combination for the FlashAttention vec kernel");
+}
+
+// Best FlashAttention kernel for a specific GPU:
+enum best_fattn_kernel {
+ BEST_FATTN_KERNEL_NONE = 0,
+ BEST_FATTN_KERNEL_VEC = 100,
+ BEST_FATTN_KERNEL_TILE = 200,
+};
+
+static best_fattn_kernel ggml_sycl_get_best_fattn_kernel(const int device, const ggml_tensor * dst) {
+ GGML_UNUSED(device);
+#ifndef SYCL_FLASH_ATTN
+ GGML_UNUSED(dst);
+ return BEST_FATTN_KERNEL_NONE;
+#endif // SYCL_FLASH_ATTN
+
+    if (!g_ggml_sycl_enable_flash_attention) {
+        return BEST_FATTN_KERNEL_NONE;
+    }
+
+ const ggml_tensor * KQV = dst;
+ const ggml_tensor * Q = dst->src[0];
+ const ggml_tensor * K = dst->src[1];
+ const ggml_tensor * V = dst->src[2];
+ const ggml_tensor * mask = dst->src[3];
+
+ const int gqa_ratio = Q->ne[2] / K->ne[2];
+ GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
+
+ float max_bias = 0.0f;
+ memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
+
+ bool gqa_opt_applies = gqa_ratio >= 2 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
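+    // The GQA optimizations additionally require 16-byte aligned strides on all
+    // non-quantized inputs: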
+ for (const ggml_tensor * t : {Q, K, V, mask}) {
+ if (t == nullptr || ggml_is_quantized(t->type)) {
+ continue;
+ }
+ for (size_t i = 1; i < GGML_MAX_DIMS; ++i) {
+ if (t->nb[i] % 16 != 0) {
+ gqa_opt_applies = false;
+ break;
+ }
+ }
+ }
+
+ switch (K->ne[0]) {
+ case 40:
+ case 64:
+ case 72:
+ case 80:
+ case 96:
+        case 112:
+        case 128:
+ case 256:
+ if (V->ne[0] != K->ne[0]) {
+ return BEST_FATTN_KERNEL_NONE;
+ }
+ break;
+ case 576:
+ if (V->ne[0] != 512) {
+ return BEST_FATTN_KERNEL_NONE;
+ }
+ if (!gqa_opt_applies) {
+ return BEST_FATTN_KERNEL_NONE;
+ }
+ break;
+ default:
+ return BEST_FATTN_KERNEL_NONE;
+ }
+
+#ifndef GGML_SYCL_FA_ALL_QUANTS
+ if (K->type != V->type) {
+ return BEST_FATTN_KERNEL_NONE;
+ }
+#endif // GGML_SYCL_FA_ALL_QUANTS
+
+ switch (K->type) {
+ case GGML_TYPE_F32:
+ case GGML_TYPE_F16:
+ break;
+ case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q5_0:
+ case GGML_TYPE_Q5_1:
+#ifndef GGML_SYCL_FA_ALL_QUANTS
+ return BEST_FATTN_KERNEL_NONE;
+#endif // GGML_SYCL_FA_ALL_QUANTS
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q8_0:
+ break;
+ default:
+ return BEST_FATTN_KERNEL_NONE;
+ }
+
+ if (mask && mask->ne[2] != 1) {
+ return BEST_FATTN_KERNEL_NONE;
+ }
+
+ // For small batch sizes the vector kernel may be preferable over the kernels optimized for large batch sizes:
+ const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0;
+
+    // TODO: use an XMX-based kernel when one becomes available; until then choose
+    // between the vector kernel and the generic tile kernel:
+ if (can_use_vector_kernel) {
+ if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) {
+ if (Q->ne[1] == 1) {
+ if (!gqa_opt_applies) {
+ return BEST_FATTN_KERNEL_VEC;
+ }
+ }
+ } else {
+ if (Q->ne[1] <= 2) {
+ return BEST_FATTN_KERNEL_VEC;
+ }
+ }
+ }
+ return BEST_FATTN_KERNEL_TILE;
+}
+
+void ggml_sycl_flash_attn_ext(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+ ggml_sycl_set_device(ctx.device);
+ switch (ggml_sycl_get_best_fattn_kernel(ggml_sycl_get_device(), dst)) {
+ case BEST_FATTN_KERNEL_NONE:
+            GGML_ABORT("FlashAttention is not supported for this configuration");
+ case BEST_FATTN_KERNEL_TILE:
+ ggml_sycl_flash_attn_ext_tile(ctx, dst);
+ break;
+ case BEST_FATTN_KERNEL_VEC:
+ ggml_sycl_flash_attn_ext_vec(ctx, dst);
+ break;
+ }
+}
+
+bool ggml_sycl_flash_attn_ext_supported(int device, const ggml_tensor * dst) {
+ return ggml_sycl_get_best_fattn_kernel(device, dst) != BEST_FATTN_KERNEL_NONE;
+}
--- /dev/null
+//
+// MIT license
+// Copyright (C) 2025 Intel Corporation
+// SPDX-License-Identifier: MIT
+//
+
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+
+#ifndef GGML_SYCL_FATTN_HPP
+#define GGML_SYCL_FATTN_HPP
+
+#include "common.hpp"
+
+void ggml_sycl_flash_attn_ext(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
+
+bool ggml_sycl_flash_attn_ext_supported(int device, const ggml_tensor * dst);
+
+#endif // GGML_SYCL_FATTN_HPP
int g_ggml_sycl_disable_dnn = 0;
int g_ggml_sycl_prioritize_dmmv = 0;
int g_ggml_sycl_use_async_mem_op = 0;
+int g_ggml_sycl_enable_flash_attention = 1;
+
static ggml_sycl_device_info ggml_sycl_init() {
ggml_sycl_device_info info = {};
info.devices[i].cc =
100 * prop.get_major_version() + 10 * prop.get_minor_version();
- info.devices[i].nsm = prop.get_max_compute_units();
+    info.devices[i].nsm = prop.get_max_compute_units() / 16; // approximate the Xe-core count, assuming 16 vector engines per Xe-core
info.devices[i].opt_feature.reorder = device.ext_oneapi_architecture_is(syclex::arch_category::intel_gpu);
info.devices[i].smpbo = prop.get_local_mem_size();
-
info.max_work_group_sizes[i] = prop.get_max_work_group_size();
+ info.devices[i].max_wg_per_cu = info.max_work_group_sizes[i] / prop.get_max_compute_units();
+
}
for (int id = 0; id < info.device_count; ++id) {
g_ggml_sycl_disable_graph = get_sycl_env("GGML_SYCL_DISABLE_GRAPH", 1);
g_ggml_sycl_disable_dnn = get_sycl_env("GGML_SYCL_DISABLE_DNN", 0);
g_ggml_sycl_prioritize_dmmv = get_sycl_env("GGML_SYCL_PRIORITIZE_DMMV", 0);
+
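+    // Runtime switch: set GGML_SYCL_ENABLE_FLASH_ATTN=0 in the environment to disable
+    // FlashAttention even when it was compiled in.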
+#ifdef SYCL_FLASH_ATTN
+ g_ggml_sycl_enable_flash_attention = get_sycl_env("GGML_SYCL_ENABLE_FLASH_ATTN", 1);
+#else
+ g_ggml_sycl_enable_flash_attention = 0;
+#endif
+
GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n");
+
+ GGML_LOG_INFO("Build with Macros:\n");
+#if defined(GGML_SYCL_FORCE_MMQ)
+ GGML_LOG_INFO(" GGML_SYCL_FORCE_MMQ: yes\n");
+#else
+ GGML_LOG_INFO(" GGML_SYCL_FORCE_MMQ: no\n");
+#endif
+#if defined(GGML_SYCL_F16)
+ GGML_LOG_INFO(" GGML_SYCL_F16: yes\n");
+#else
+ GGML_LOG_INFO(" GGML_SYCL_F16: no\n");
+#endif
+#if defined(GGML_SYCL_GRAPH)
+ GGML_LOG_INFO(" GGML_SYCL_GRAPH: yes\n");
+#else
+ GGML_LOG_INFO(" GGML_SYCL_GRAPH: no\n");
+#endif
+#if defined(GGML_SYCL_DNNL)
+ GGML_LOG_INFO(" GGML_SYCL_DNNL: yes\n");
+#else
+ GGML_LOG_INFO(" GGML_SYCL_DNNL: no\n");
+#endif
+
GGML_LOG_INFO("Running with Environment Variables:\n");
GGML_LOG_INFO(" GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug);
GGML_LOG_INFO(" GGML_SYCL_DISABLE_OPT: %d\n", g_ggml_sycl_disable_optimize);
GGML_LOG_INFO(" GGML_SYCL_DISABLE_DNN: DNN disabled by compile flag\n");
#endif
GGML_LOG_INFO(" GGML_SYCL_PRIORITIZE_DMMV: %d\n", g_ggml_sycl_prioritize_dmmv);
- GGML_LOG_INFO("Build with Macros:\n");
-#if defined(GGML_SYCL_FORCE_MMQ)
- GGML_LOG_INFO(" GGML_SYCL_FORCE_MMQ: yes\n");
-#else
- GGML_LOG_INFO(" GGML_SYCL_FORCE_MMQ: no\n");
-#endif
-#if defined(GGML_SYCL_F16)
- GGML_LOG_INFO(" GGML_SYCL_F16: yes\n");
+
+#ifdef SYCL_FLASH_ATTN
+ GGML_LOG_INFO(" GGML_SYCL_ENABLE_FLASH_ATTN: %d\n", g_ggml_sycl_enable_flash_attention);
#else
- GGML_LOG_INFO(" GGML_SYCL_F16: no\n");
+    GGML_LOG_INFO("  GGML_SYCL_ENABLE_FLASH_ATTN: %d (disabled by compile flag)\n",
+                  g_ggml_sycl_enable_flash_attention);
#endif
/* DO NOT REMOVE: kept for a future XMX optimization.
}
#if GGML_SYCL_DNNL
- // oneDNN handles strided data and does not need overhead of get_to_fp16_nc_sycl
+ // oneDNN handles strided data and does not need overhead of ggml_get_to_fp16_nc_sycl
const int64_t ne_src1 = src1->nb[last_str] * src1->ne[last_dim] / type_size_src1;
src1_f16_alloc.alloc(ne_src1);
const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type, dst);
# else
const int64_t ne_src1 = ggml_nelements(src1);
src1_f16_alloc.alloc(ne_src1);
- const to_fp16_nc_sycl_t to_fp16_nc_sycl = get_to_fp16_nc_sycl(src1->type);
+ const to_fp16_nc_sycl_t to_fp16_nc_sycl = ggml_get_to_fp16_nc_sycl(src1->type);
GGML_ASSERT(to_fp16_nc_sycl != nullptr);
to_fp16_nc_sycl(src1_f16, src1_f16_alloc.get(), ne10, ne11, ne12, ne13, s11, s12, s13, queue);
#endif
case GGML_OP_ARANGE:
ggml_sycl_arange(ctx, dst);
break;
+ case GGML_OP_FLASH_ATTN_EXT:
+ ggml_sycl_flash_attn_ext(ctx, dst);
+ break;
default:
return false;
}
return op->type == GGML_TYPE_F32;
case GGML_OP_ARANGE:
return op->type == GGML_TYPE_F32;
+ case GGML_OP_FLASH_ATTN_EXT:
+ return ggml_sycl_flash_attn_ext_supported(device, op);
default:
return false;
}
#define MUL_MAT_SRC1_COL_STRIDE 128
#define QK_WARP_SIZE 32
+#define WARP_32_SIZE 32
+#define WARP_16_SIZE 16
+
#endif // GGML_SYCL_PRESETS_HPP
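
The two new width constants let reduction call sites name the sub-group width explicitly, now that the warp_reduce templates take it as a mandatory parameter; a hedged sketch of the intended use (the surrounding kernels are hypothetical):

    // in a kernel launched with [[sycl::reqd_sub_group_size(32)]]:
    //   sum = warp_reduce_sum<WARP_32_SIZE>(sum);
    // in a kernel launched with [[sycl::reqd_sub_group_size(16)]]:
    //   sum = warp_reduce_sum<WARP_16_SIZE>(sum);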
max_val = sycl::max(max_val, val);
}
// find the max value in the block
- max_val = warp_reduce_max(max_val);
+ max_val = warp_reduce_max<WARP_SIZE>(max_val);
if (block_size > WARP_SIZE) {
if (warp_id == 0) {
item_ct1.barrier();
max_val = buf_iw[lane_id];
- max_val = warp_reduce_max(max_val);
+ max_val = warp_reduce_max<WARP_SIZE>(max_val);
}
float tmp = 0.0f; // partial sum
vals[col] = val;
}
// find the sum of exps in the block
- tmp = warp_reduce_sum(tmp);
+ tmp = warp_reduce_sum<WARP_SIZE>(tmp);
if (block_size > WARP_SIZE) {
item_ct1.barrier();
if (warp_id == 0) {
for (size_t i = 1; i < nreduce; i += 1) {
tmp += buf_iw[lane_id + i * WARP_SIZE];
}
- tmp = warp_reduce_sum(tmp);
+ tmp = warp_reduce_sum<WARP_SIZE>(tmp);
}
if (sinks) {
tmp += sycl::native::exp(sinks[i02] - max_val);
dgf_dot += dstf[col]*grad[col];
}
- dgf_dot = warp_reduce_sum(dgf_dot);
+ dgf_dot = warp_reduce_sum<WARP_SIZE>(dgf_dot);
for (int col = tid; col < ncols; col += WARP_SIZE) {
dst[col] = scale * (grad[col] - dgf_dot) * dstf[col];
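
For reference, the code above is the softmax Jacobian-vector product: writing y for the forward softmax output (dstf), g for the incoming gradient (grad) and s for the scale,

    \frac{\partial L}{\partial x_i} = s \, y_i \left( g_i - \sum_j g_j \, y_j \right)

where the warp-reduced dgf_dot is exactly the inner sum.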
--- /dev/null
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-tile.hpp"
+
+DECL_FATTN_TILE_CASE(112, 112);
--- /dev/null
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-tile.hpp"
+
+DECL_FATTN_TILE_CASE(128, 128);
--- /dev/null
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-tile.hpp"
+
+DECL_FATTN_TILE_CASE(256, 256);
--- /dev/null
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-tile.hpp"
+
+DECL_FATTN_TILE_CASE(40, 40);
--- /dev/null
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-tile.hpp"
+
+DECL_FATTN_TILE_CASE(576, 512);
--- /dev/null
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-tile.hpp"
+
+DECL_FATTN_TILE_CASE(64, 64);
--- /dev/null
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-tile.hpp"
+
+DECL_FATTN_TILE_CASE(72, 72);
--- /dev/null
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-tile.hpp"
+
+DECL_FATTN_TILE_CASE(80, 80);
--- /dev/null
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-tile.hpp"
+
+DECL_FATTN_TILE_CASE(96, 96);
--- /dev/null
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16);
--- /dev/null
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q4_0);
--- /dev/null
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q4_1);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q4_1);
--- /dev/null
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q5_0);
--- /dev/null
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q5_1);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q5_1);
--- /dev/null
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_F16, GGML_TYPE_Q8_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_F16, GGML_TYPE_Q8_0);
--- /dev/null
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_F16);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_F16);
--- /dev/null
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0);
--- /dev/null
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1);
--- /dev/null
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0);
--- /dev/null
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1);
--- /dev/null
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0);
--- /dev/null
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_F16);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_F16);
--- /dev/null
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0);
--- /dev/null
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1);
--- /dev/null
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0);
--- /dev/null
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1);
--- /dev/null
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0);
--- /dev/null
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_F16);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_F16);
--- /dev/null
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0);
--- /dev/null
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1);
--- /dev/null
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0);
--- /dev/null
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1);
--- /dev/null
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0);
--- /dev/null
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_F16);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_F16);
--- /dev/null
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0);
--- /dev/null
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1);
--- /dev/null
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0);
--- /dev/null
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1);
--- /dev/null
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0);
--- /dev/null
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_F16);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_F16);
--- /dev/null
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0);
--- /dev/null
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1);
--- /dev/null
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0);
--- /dev/null
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1);
--- /dev/null
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../fattn-vec.hpp"
+
+DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
+DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
+DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0);
return d8_0*d8_1 * sumi;
}
+// generic (float or sycl::half) variant of the Q8_0 x Q8_1 dot product
+template <typename T, int vdr>
+static __dpct_inline__ T vec_dot_q8_0_q8_1_impl(const int * v, const int * u, const T & d8_0, const T & d8_1) {
+ int sumi = 0;
+
+#pragma unroll
+ for (int i = 0; i < vdr; ++i) {
+ // SIMD dot product of quantized values
+ sumi = ggml_sycl_dp4a(v[i], u[i], sumi);
+ }
+
+ return d8_0*d8_1 * ((T) sumi);
+}
+
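For readers unfamiliar with dp4a, here is a hedged scalar reference of what ggml_sycl_dp4a is assumed to compute (modeled on CUDA's __dp4a; the real helper maps to a hardware instruction where available):

    #include <cstdint>
    #include <cstring>

    // reference semantics: treat each 32-bit int as four packed int8 lanes and
    // accumulate their pairwise products into c
    static inline int dp4a_ref(int a, int b, int c) {
        int8_t va[4], vb[4];
        std::memcpy(va, &a, 4);
        std::memcpy(vb, &b, 4);
        for (int i = 0; i < 4; ++i) {
            c += int(va[i]) * int(vb[i]);
        }
        return c;
    }
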
template <int vdr>
static __dpct_inline__ float vec_dot_q8_1_q8_1_impl(const int *v, const int *u,
const sycl::half2 &dm8,