#include "dequantize.hpp"
#include "presets.hpp"
+#if defined(__INTEL_LLVM_COMPILER)
+ #if __has_include(<sycl/ext/oneapi/bfloat16.hpp>)
+ #include <sycl/ext/oneapi/bfloat16.hpp>
+ #define GGML_SYCL_HAS_BF16
+ #endif
+#endif
+
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k,
const sycl::nd_item<3> &item_ct1) {
return dequantize_row_iq4_nl_sycl;
case GGML_TYPE_F32:
return convert_unary_sycl<float>;
+#ifdef GGML_SYCL_HAS_BF16
+ case GGML_TYPE_BF16:
+ return convert_unary_sycl<sycl::ext::oneapi::bfloat16>;
+#endif
default:
return nullptr;
}
return dequantize_row_iq4_nl_sycl;
case GGML_TYPE_F16:
return convert_unary_sycl<sycl::half>;
+#ifdef GGML_SYCL_HAS_BF16
+ case GGML_TYPE_BF16:
+ return convert_unary_sycl<sycl::ext::oneapi::bfloat16>;
+#endif
default:
return nullptr;
}
switch (type) {
case GGML_TYPE_F32:
return convert_unary_nc_sycl<float>;
+#ifdef GGML_SYCL_HAS_BF16
+ case GGML_TYPE_BF16:
+ return convert_unary_nc_sycl<sycl::ext::oneapi::bfloat16>;
+#endif
default:
return nullptr;
}