#include "ggml-impl.h"
#include "ggml-sycl.h"
#include "presets.hpp"
+#include "type.hpp"
#include "sycl_hw.hpp"
namespace syclexp = sycl::ext::oneapi::experimental;
return val;
}
+// Decode an NVFP4 E4M3 scale byte. The two E4M3 NaN encodings (0x7F, 0xFF)
+// are mapped to zero so a corrupt scale zeroes the block instead of
+// propagating NaN.
+static __dpct_inline__ float ggml_sycl_ue4m3_to_fp32(uint8_t x) {
+    __nv_fp8_e4m3 xf;
+    xf.raw = (x != 0x7F && x != 0xFF) ? x : 0;
+    // Halved because kvalues_mxfp4 stores the E2M1 lattice scaled by 2.
+    return static_cast<float>(xf) / 2;
+}
+
#endif // GGML_SYCL_COMMON_HPP
});
}
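+// NVFP4 dequantization launcher: one work-group per QK_NVFP4 block; each
+// work-item unpacks one byte (two FP4 values). The kernel fetches its
+// nd_item itself via this_work_item, so the lambda parameter is unused.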
+template <typename dst_t>
+static void dequantize_row_nvfp4_sycl(const void * vx, dst_t * y, const int64_t k, dpct::queue_ptr stream) {
+ GGML_ASSERT(k % QK_NVFP4 == 0);
+ const int nb = k / QK_NVFP4;
+ stream->parallel_for(
+        sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, QK_NVFP4 / 2), sycl::range<3>(1, 1, QK_NVFP4 / 2)),
+ [=](sycl::nd_item<3> item_ct1) {
+ dequantize_block_nvfp4(vx, y, k);
+ });
+}
+
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static void dequantize_block_nc(const void * __restrict__ vx, dst_t * __restrict__ y,
const int64_t ne00, const int64_t ne01, const int64_t ne02,
return dequantize_row_iq4_nl_sycl;
case GGML_TYPE_MXFP4:
return dequantize_row_mxfp4_sycl;
+ case GGML_TYPE_NVFP4:
+ return dequantize_row_nvfp4_sycl;
case GGML_TYPE_F32:
return convert_unary_sycl<float>;
#ifdef GGML_SYCL_HAS_BF16
return convert_unary_sycl<sycl::ext::oneapi::bfloat16>;
#endif
default:
+        GGML_ABORT("fatal error: unsupported data type=%s\n", ggml_type_name(type));
return nullptr;
}
}
return dequantize_row_iq4_nl_sycl;
case GGML_TYPE_MXFP4:
return dequantize_row_mxfp4_sycl;
+ case GGML_TYPE_NVFP4:
+ return dequantize_row_nvfp4_sycl;
case GGML_TYPE_F16:
return convert_unary_sycl<sycl::half>;
#ifdef GGML_SYCL_HAS_BF16
return convert_unary_sycl<sycl::ext::oneapi::bfloat16>;
#endif
default:
+        GGML_ABORT("fatal error: unsupported data type=%s\n", ggml_type_name(type));
return nullptr;
}
}
}
}
+
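+// Dequantize one QK_NVFP4 block: d[] holds one E4M3 scale byte per
+// QK_NVFP4_SUB-value sub-block, qs[] the FP4 values packed two per byte.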
+template <typename dst_t>
+static void dequantize_block_nvfp4(
+ const void * __restrict__ vx,
+ dst_t * __restrict__ yy,
+ const int64_t ne) {
+ auto item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+ const int64_t i = item_ct1.get_group(2);
+ const int tid = item_ct1.get_local_id(2);
+
+ const int64_t base = i * QK_NVFP4;
+ if (base >= ne) {
+ return;
+ }
+
+ const block_nvfp4 * x = (const block_nvfp4 *) vx;
+ const block_nvfp4 & xb = x[i];
+
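+    // sub: which QK_NVFP4_SUB-value sub-block; j: byte index within it.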
+ const int sub = tid / (QK_NVFP4_SUB / 2);
+ const int j = tid % (QK_NVFP4_SUB / 2);
+
+ const float d = ggml_sycl_ue4m3_to_fp32(xb.d[sub]);
+ const uint8_t q = xb.qs[sub * (QK_NVFP4_SUB / 2) + j];
+
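+    // Destination indices of the low- and high-nibble values of byte j.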
+ const int64_t y0 = base + sub * QK_NVFP4_SUB + j;
+ const int64_t y1 = y0 + QK_NVFP4_SUB / 2;
+
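+    // NVFP4's FP4 lattice is E2M1, the same as MXFP4, so kvalues_mxfp4 is
+    // reused; the scale was decoded pre-halved to cancel the doubled entries.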
+ yy[y0] = ggml_sycl_cast<dst_t>(d * kvalues_mxfp4[q & 0x0F]);
+ yy[y1] = ggml_sycl_cast<dst_t>(d * kvalues_mxfp4[q >> 4]);
+}
+
#endif // GGML_SYCL_DEQUANTIZE_HPP
}
}
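+// Launch mul_mat_vec_q for NVFP4: each work-group covers GGML_SYCL_MMV_Y
+// rows, with one WARP_SIZE-wide sub-group per row.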
+static void mul_mat_vec_nvfp4_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows,
+ dpct::queue_ptr stream) {
+ GGML_ASSERT(ncols % QK_NVFP4 == 0);
+ const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+ const sycl::range<3> block_nums(1, 1, block_num_y);
+ const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+
+ {
+ stream->submit([&](sycl::handler & cgh) {
+ cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
+ [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+ mul_mat_vec_q<QK_NVFP4, QI_NVFP4, block_nvfp4, VDR_NVFP4_Q8_1_MMVQ, vec_dot_nvfp4_q8_1>(
+ vx, vy, dst, ncols, nrows, item_ct1);
+ });
+ });
+ }
+}
static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy,
float *dst, const int ncols,
case GGML_TYPE_MXFP4:
mul_mat_vec_mxfp4_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
break;
+ case GGML_TYPE_NVFP4:
+ mul_mat_vec_nvfp4_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+ break;
default:
- GGML_ABORT("fatal error");
+            GGML_ABORT("fatal error: unsupported data type=%s\n", ggml_type_name(src0->type));
}
}
GGML_UNUSED(src1);
--- /dev/null
+#pragma once
+
+#include <sycl/sycl.hpp>
+#include <cstdint>
+#include <limits>
+
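+// Minimal SYCL-side emulation of CUDA's __nv_fp8_e4m3 (OCP FP8 E4M3):
+// 1 sign bit, 4 exponent bits (bias 7), 3 mantissa bits. Only 0x7F/0xFF
+// encode NaN, there is no Inf, and the largest finite value is 448.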
+// Convert float to E4M3 with round-to-nearest-even. NaN/Inf and values that
+// round past the largest finite E4M3 value (448) map to the canonical NaN
+// encoding, since E4M3 has no Inf.
+inline uint8_t float_to_e4m3(float f)
+{
+    const uint32_t bits = sycl::bit_cast<uint32_t>(f);
+    const uint8_t  sign = static_cast<uint8_t>((bits >> 31) << 7);
+    const uint32_t abs  = bits & 0x7FFFFFFFu;
+
+    if (abs >= 0x7F800000u) { // NaN or Inf
+        return sign | 0x7F;
+    }
+
+    const int      e   = static_cast<int>(abs >> 23) - 127; // unbiased exponent
+    const uint32_t sig = (abs & 0x7FFFFFu) | (abs >= 0x00800000u ? 0x800000u : 0u);
+
+    // Place the significand so that bits [2:0] of mag are the E4M3 mantissa.
+    uint32_t mag;
+    int      shift;
+    if (e >= -6) {
+        // Normal E4M3 range: re-bias the exponent (bias 7), keep 3 mantissa bits.
+        shift = 20;
+        mag   = (static_cast<uint32_t>(e + 7) << 3) | ((sig >> shift) & 0x7u);
+    } else {
+        // Subnormal range (exponent field 0): denormalize the significand.
+        shift = 20 + (-6 - e);
+        if (shift > 31) {
+            return sign; // too small for the smallest subnormal: signed zero
+        }
+        mag = sig >> shift;
+    }
+
+    // Round to nearest, ties to even, using the bits shifted out. The +1
+    // carries cleanly from the mantissa into the exponent field because the
+    // encoding is a contiguous biased integer.
+    const uint32_t rem  = sig & ((1u << shift) - 1u);
+    const uint32_t half = 1u << (shift - 1);
+    if (rem > half || (rem == half && (mag & 1u))) {
+        mag += 1u;
+    }
+
+    if (mag >= 0x7Fu) { // overflowed past 448 -> NaN
+        return sign | 0x7F;
+    }
+    return sign | static_cast<uint8_t>(mag);
+}
+
+inline float e4m3_to_float(uint8_t x)
+{
+    const uint8_t sign = (x >> 7) & 0x1u;
+    const uint8_t exp  = (x >> 3) & 0xFu;
+    const uint8_t mant = x & 0x7u;
+
+    // Only 0x7F/0xFF (exp == 15, mant == 7) are NaN; exp == 15 with mant 0..6
+    // encodes the finite values 256..448. E4M3 has no Inf.
+    if (exp == 0xF && mant == 0x7) {
+        return std::numeric_limits<float>::quiet_NaN();
+    }
+
+    float val;
+    if (exp == 0) {
+        // Subnormal: mant/8 * 2^-6
+        val = sycl::ldexp(static_cast<float>(mant), -9);
+    } else {
+        // Normal: (1 + mant/8) * 2^(exp - 7)
+        val = sycl::ldexp(1.0f + mant * 0.125f, static_cast<int>(exp) - 7);
+    }
+
+    return sign ? -val : val;
+}
+
+// SYCL stand-in for CUDA's __nv_fp8_e4m3 storage type.
+struct __nv_fp8_e4m3 {
+ uint8_t raw;
+
+ __nv_fp8_e4m3() = default;
+
+ explicit __nv_fp8_e4m3(float f) : raw(float_to_e4m3(f)) {}
+ explicit __nv_fp8_e4m3(sycl::half h) : raw(float_to_e4m3(static_cast<float>(h))) {}
+
+ operator float() const { return e4m3_to_float(raw); }
+ operator sycl::half() const { return static_cast<sycl::half>(static_cast<float>(*this)); }
+
+ // Allow direct access for vector loads/stores
+ operator uint8_t&() { return raw; }
+ operator uint8_t() const { return raw; }
+};
+
+// CUDA's packed fp8 types are plain storage wrappers. sycl::vec requires a
+// SYCL scalar element type, so model them as such wrappers instead.
+struct __nv_fp8x2_e4m3 { uint16_t __x; };
+struct __nv_fp8x4_e4m3 { uint32_t __x; };
+
#include "dpct/helper.hpp"
#include "ggml.h"
+#include "type.hpp"
#include "quants.hpp"
typedef float (*vec_dot_q_sycl_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1,
return x32;
}
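+// get_int_b2/get_int_b4: load one 32-bit word from a quant buffer that only
+// guarantees 2- or 4-byte alignment, respectively.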
+static __dpct_inline__ int get_int_b2(const void * x, const int & i32) {
+ const uint16_t * x16 = (const uint16_t *) x; // assume at least 2 byte alignment
+
+ int x32 = x16[2*i32 + 0] << 0;
+ x32 |= x16[2*i32 + 1] << 16;
+
+ return x32;
+}
+
+static __dpct_inline__ int get_int_b4(const void * x, const int & i32) {
+ return ((const int *) x)[i32]; // assume at least 4 byte alignment
+}
static __dpct_inline__ int get_int_from_int8(const int8_t* x8, const int& i32) {
const uint16_t* x16 =
return d * sumi;
}
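+// VDR = vec dot ratio: how many quantized words each thread consumes per
+// vec_dot call in the MMVQ and MMQ kernels, respectively.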
+#define VDR_NVFP4_Q8_1_MMVQ 4
+#define VDR_NVFP4_Q8_1_MMQ 8
+
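+// Dot product of NVFP4 quants against q8_1: nibbles are expanded through the
+// shared E2M1 table and accumulated with dp4a, one E4M3 scale per 16-value
+// sub-block.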
+static __dpct_inline__ float vec_dot_nvfp4_q8_1(const void * __restrict__ vbq,
+ const block_q8_1 * __restrict__ bq8_1,
+ const int32_t & iqs) {
+ const block_nvfp4 * bq4 = (const block_nvfp4 *) vbq;
+ float sum = 0.0f;
+#pragma unroll
+ for (int i = 0; i < VDR_NVFP4_Q8_1_MMVQ/2; i++) {
+ const int32_t iqs0 = iqs + 2*i;
+ const int32_t iqs1 = iqs0 + 1;
+ const int32_t is = iqs0 >> 1;
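+        // v*.x() holds the low-nibble values (elements 0..7 of the 16-value
+        // sub-block), v*.y() the high-nibble values (elements 8..15); the q8
+        // word indices below match that interleaving.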
+ const sycl::int2 v0 = get_int_from_table_16(get_int_b4(bq4->qs, iqs0), kvalues_mxfp4);
+ const sycl::int2 v1 = get_int_from_table_16(get_int_b4(bq4->qs, iqs1), kvalues_mxfp4);
+ const block_q8_1 * bq8 = bq8_1 + (is >> 1);
+ const int32_t i8 = ((is & 1) << 2);
+
+ int sumi = ggml_sycl_dp4a(v0.x(), get_int_b4(bq8->qs, i8 + 0), 0);
+ sumi = ggml_sycl_dp4a(v0.y(), get_int_b4(bq8->qs, i8 + 2), sumi);
+ sumi = ggml_sycl_dp4a(v1.x(), get_int_b4(bq8->qs, i8 + 1), sumi);
+ sumi = ggml_sycl_dp4a(v1.y(), get_int_b4(bq8->qs, i8 + 3), sumi);
+
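+        // (bq8->ds)[0] is the q8_1 block scale d; the E4M3 sub-block scale
+        // comes back pre-halved to match the doubled kvalues_mxfp4 entries.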
+ const float d = ggml_sycl_ue4m3_to_fp32(bq4->d[is]) * (bq8->ds)[0];
+ sum += d * float(sumi);
+ }
+
+ return sum;
+}
static __dpct_inline__ float
vec_dot_q5_0_q8_1(const void *__restrict__ vbq,