]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
sycl : support nvfp4 type in mul_mat (llama/21227)
authorNeo Zhang <redacted>
Wed, 1 Apr 2026 10:54:15 +0000 (18:54 +0800)
committerGeorgi Gerganov <redacted>
Thu, 2 Apr 2026 07:25:32 +0000 (10:25 +0300)
src/ggml-sycl/common.hpp
src/ggml-sycl/convert.cpp
src/ggml-sycl/dequantize.hpp
src/ggml-sycl/mmvq.cpp
src/ggml-sycl/type.hpp [new file with mode: 0644]
src/ggml-sycl/vecdotq.hpp

index fcb0db99c6b28990b2f02949b52e91c2e5bddadd..fd84c9178530f78b107aa17abb4e5e817a296d87 100644 (file)
@@ -23,6 +23,7 @@
 #include "ggml-impl.h"
 #include "ggml-sycl.h"
 #include "presets.hpp"
+#include "type.hpp"
 #include "sycl_hw.hpp"
 
 namespace syclexp = sycl::ext::oneapi::experimental;
@@ -965,4 +966,10 @@ static T block_reduce(T val, T * shared_vals, int block_size_template) {
     return val;
 }
 
+static __dpct_inline__ float ggml_sycl_ue4m3_to_fp32(uint8_t x) {
+    const uint32_t bits = x * (x != 0x7F && x != 0xFF);
+    const __nv_fp8_e4m3 xf = *reinterpret_cast<const __nv_fp8_e4m3 *>(&bits);
+    return static_cast<float>(xf) / 2;
+}
+
 #endif // GGML_SYCL_COMMON_HPP
index d17aca2cac4eac239965805f802879fbf5b857b9..d7f60cbc9ea9d7aad7c25c220f079ac65735e581 100644 (file)
@@ -482,6 +482,18 @@ static void dequantize_row_mxfp4_sycl(const void * vx, dst_t * y, const int64_t
         });
 }
 
+template <typename dst_t>
+static void dequantize_row_nvfp4_sycl(const void * vx, dst_t * y, const int64_t k, dpct::queue_ptr stream) {
+    GGML_ASSERT(k % QK_NVFP4 == 0);
+    const int nb = k / QK_NVFP4;
+    stream->parallel_for(
+        sycl::nd_range<3>(sycl::range<3>(1, 1, nb) * sycl::range<3>(1, 1, 32), sycl::range<3>(1, 1, 32)),
+        [=](sycl::nd_item<3> item_ct1) {
+            dequantize_block_nvfp4(vx, y, k);
+        });
+}
+
+
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
 static void dequantize_block_nc(const void * __restrict__ vx, dst_t * __restrict__ y,
         const int64_t ne00, const int64_t ne01, const int64_t ne02,
@@ -641,6 +653,8 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) {
             return dequantize_row_iq4_nl_sycl;
         case GGML_TYPE_MXFP4:
             return dequantize_row_mxfp4_sycl;
+        case GGML_TYPE_NVFP4:
+            return dequantize_row_nvfp4_sycl;
         case GGML_TYPE_F32:
             return convert_unary_sycl<float>;
 #ifdef GGML_SYCL_HAS_BF16
@@ -648,6 +662,7 @@ to_fp16_sycl_t ggml_get_to_fp16_sycl(ggml_type type, ggml_tensor * dst) {
             return convert_unary_sycl<sycl::ext::oneapi::bfloat16>;
 #endif
         default:
+            GGML_ABORT("fatal error: unsupport data type=%s\n", ggml_type_name(type));
             return nullptr;
     }
 }
@@ -708,6 +723,8 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) {
             return dequantize_row_iq4_nl_sycl;
         case GGML_TYPE_MXFP4:
             return dequantize_row_mxfp4_sycl;
+        case GGML_TYPE_NVFP4:
+            return dequantize_row_nvfp4_sycl;
         case GGML_TYPE_F16:
             return convert_unary_sycl<sycl::half>;
 #ifdef GGML_SYCL_HAS_BF16
@@ -715,6 +732,7 @@ to_fp32_sycl_t ggml_get_to_fp32_sycl(ggml_type type, ggml_tensor *dst) {
             return convert_unary_sycl<sycl::ext::oneapi::bfloat16>;
 #endif
         default:
+            GGML_ABORT("fatal error: unsupport data type=%s\n", ggml_type_name(type));
             return nullptr;
     }
 }
index da2a605daa8896fed95057fff22674c221a6874c..3272724f41b0d3e1e4b714d7a87be36dda6f5bb8 100644 (file)
@@ -838,4 +838,36 @@ static void dequantize_block_mxfp4(const void * __restrict__ vx, dst_t * __restr
     }
 }
 
+
+template <typename dst_t>
+static void dequantize_block_nvfp4(
+        const void * __restrict__ vx,
+        dst_t * __restrict__ yy,
+        const int64_t ne) {
+    auto          item_ct1 = sycl::ext::oneapi::this_work_item::get_nd_item<3>();
+    const int64_t i        = item_ct1.get_group(2);
+    const int     tid      = item_ct1.get_local_id(2);
+
+    const int64_t base = i * QK_NVFP4;
+    if (base >= ne) {
+        return;
+    }
+
+    const block_nvfp4 * x = (const block_nvfp4 *) vx;
+    const block_nvfp4 & xb = x[i];
+
+    const int sub = tid / (QK_NVFP4_SUB / 2);
+    const int j = tid % (QK_NVFP4_SUB / 2);
+
+    const float d = ggml_sycl_ue4m3_to_fp32(xb.d[sub]);
+    const uint8_t q = xb.qs[sub * (QK_NVFP4_SUB / 2) + j];
+
+    const int64_t y0 = base + sub * QK_NVFP4_SUB + j;
+    const int64_t y1 = y0 + QK_NVFP4_SUB / 2;
+
+    yy[y0] = ggml_sycl_cast<dst_t>(d * kvalues_mxfp4[q & 0x0F]);
+    yy[y1] = ggml_sycl_cast<dst_t>(d * kvalues_mxfp4[q >> 4]);
+}
+
+
 #endif // GGML_SYCL_DEQUANTIZE_HPP
index 316aa0d0fb58af9fc31394db1a300c39114d42fd..5abc50fabfe874de7ccad5ed07777126ab4e007b 100644 (file)
@@ -613,6 +613,23 @@ static void mul_mat_vec_mxfp4_q8_1_sycl(const void * vx, const void * vy, float
     }
 }
 
+static void mul_mat_vec_nvfp4_q8_1_sycl(const void * vx, const void * vy, float * dst, const int ncols, const int nrows,
+                                        dpct::queue_ptr stream) {
+    GGML_ASSERT(ncols % QK_NVFP4 == 0);
+    const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
+    const sycl::range<3> block_nums(1, 1, block_num_y);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+
+    {
+        stream->submit([&](sycl::handler & cgh) {
+            cgh.parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                             [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] {
+                                 mul_mat_vec_q<QK_NVFP4, QI_NVFP4, block_nvfp4, VDR_NVFP4_Q8_1_MMVQ, vec_dot_nvfp4_q8_1>(
+                                     vx, vy, dst, ncols, nrows, item_ct1);
+                             });
+        });
+    }
+}
 
 static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy,
                                        float *dst, const int ncols,
@@ -1145,8 +1162,11 @@ void ggml_sycl_op_mul_mat_vec_q(ggml_backend_sycl_context & ctx, const ggml_tens
             case GGML_TYPE_MXFP4:
                 mul_mat_vec_mxfp4_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
                 break;
+            case GGML_TYPE_NVFP4:
+                mul_mat_vec_nvfp4_q8_1_sycl(src0_dd_i, src1_ddq_i_bs, dst_dd_i_bs, ne00, row_diff, stream);
+                break;
             default:
-                GGML_ABORT("fatal error");
+                GGML_ABORT("fatal error: unsupport data type=%s\n", ggml_type_name(src0->type));
         }
     }
     GGML_UNUSED(src1);
diff --git a/src/ggml-sycl/type.hpp b/src/ggml-sycl/type.hpp
new file mode 100644 (file)
index 0000000..d7ff89d
--- /dev/null
@@ -0,0 +1,112 @@
+#pragma once
+
+#include <sycl/sycl.hpp>
+#include <cstdint>
+#include <limits>
+
+inline uint8_t float_to_e4m3(float f)
+{
+    if (sycl::isnan(f)) {
+        return 0x7F;                    // Canonical NaN (positive)
+    }
+
+    uint32_t bits = sycl::bit_cast<uint32_t>(f);
+    uint32_t sign = (bits >> 31) & 0x1u;
+    uint32_t exp  = (bits >> 23) & 0xFFu;
+    uint32_t mant = bits & 0x7FFFFFu;
+
+    // Zero
+    if (exp == 0 && mant == 0) {
+        return static_cast<uint8_t>(sign << 7);
+    }
+
+    // Extract biased exponent and mantissa for FP8
+    int e = static_cast<int>(exp) - 127;           // true exponent (IEEE bias 127)
+    uint32_t m = mant;
+
+    // Handle very large values → NaN (NVIDIA behavior for E4M3)
+    if (e > 7) {                                   // max exponent for E4M3 is 7 (biased 14)
+        return static_cast<uint8_t>((sign << 7) | 0x7F);
+    }
+
+    // Handle subnormals and normal numbers
+    if (e < -6) {                                  // smallest normal exponent is -6
+        // Subnormal in FP8: shift mantissa right
+        int shift = -6 - e;
+        m = (m | 0x800000u) >> (shift + 1);        // +1 because we lose the implicit 1 position
+        if (shift > 23) m = 0;
+    } else {
+        // Normal number: adjust exponent bias from 127 to 7
+        int new_exp = e + 7;
+        m = (m >> 20) & 0x7u;                      // take top 3 mantissa bits (after implicit 1)
+        m |= (static_cast<uint32_t>(new_exp) << 3);
+    }
+
+    // Round-to-nearest-even (simple guard + round bit)
+    // For better accuracy you can add sticky bit, but this is sufficient for most use cases
+    uint32_t round_bit = (mant >> 19) & 0x1u;      // bit after the 3 mantissa bits
+    if (round_bit) {
+        m += 1;
+        // Carry into exponent if mantissa overflows
+        if ((m & 0x8u) != 0) {
+            m = (m & 0x7u) | ((m & 0x38u) << 1);   // simple carry handling
+            // If exponent overflows after carry → NaN
+            if ((m >> 3) > 14) {
+                return static_cast<uint8_t>((sign << 7) | 0x7F);
+            }
+        }
+    }
+
+    uint8_t result = static_cast<uint8_t>((sign << 7) | (m & 0x7F));
+    return result;
+}
+
+inline float e4m3_to_float(uint8_t x)
+{
+    if (x == 0) return 0.0f;
+
+    uint8_t sign = (x >> 7) & 0x1u;
+    uint8_t exp  = (x >> 3) & 0xFu;
+    uint8_t mant = x & 0x7u;
+
+    // NaN (NVIDIA uses 0x7F / 0xFF as NaN)
+    if (exp == 0xF && mant != 0) {
+        return std::numeric_limits<float>::quiet_NaN();
+    }
+    if (exp == 0xF) {                     // 0x7F or 0xFF treated as NaN
+        return std::numeric_limits<float>::quiet_NaN();
+    }
+
+    float val;
+
+    if (exp == 0) {
+        // Subnormal
+        val = mant * (1.0f / 8.0f) * sycl::pow(2.0f, -6.0f);
+    } else {
+        // Normal: implicit leading 1 + bias 7
+        val = (1.0f + mant / 8.0f) * sycl::pow(2.0f, static_cast<float>(exp) - 7.0f);
+    }
+
+    return sign ? -val : val;
+}
+
+// The actual type definition
+struct __nv_fp8_e4m3 {
+    uint8_t raw;
+
+    __nv_fp8_e4m3() = default;
+
+    explicit __nv_fp8_e4m3(float f) : raw(float_to_e4m3(f)) {}
+    explicit __nv_fp8_e4m3(sycl::half h) : raw(float_to_e4m3(static_cast<float>(h))) {}
+
+    operator float() const { return e4m3_to_float(raw); }
+    operator sycl::half() const { return static_cast<sycl::half>(static_cast<float>(*this)); }
+
+    // Allow direct access for vector loads/stores
+    operator uint8_t&() { return raw; }
+    operator uint8_t() const { return raw; }
+};
+
+using __nv_fp8x2_e4m3 = sycl::vec<__nv_fp8_e4m3, 2>;
+using __nv_fp8x4_e4m3 = sycl::vec<__nv_fp8_e4m3, 4>;
+
index 9a267d85a0cc917d818e5222c401be3b0a955481..eab9850aed7f0078570abb1ec4f44d795e52f630 100644 (file)
@@ -15,6 +15,7 @@
 
 #include "dpct/helper.hpp"
 #include "ggml.h"
+#include "type.hpp"
 #include "quants.hpp"
 
 typedef float (*vec_dot_q_sycl_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1,
@@ -31,6 +32,18 @@ static __dpct_inline__ int get_int_b1(const void * x, const int & i32) {
     return x32;
 }
 
+static __dpct_inline__ int get_int_b2(const void * x, const int & i32) {
+    const uint16_t * x16 = (const uint16_t *) x; // assume at least 2 byte alignment
+
+    int x32  = x16[2*i32 + 0] <<  0;
+    x32     |= x16[2*i32 + 1] << 16;
+
+    return x32;
+}
+
+static __dpct_inline__ int get_int_b4(const void * x, const int & i32) {
+    return ((const int *) x)[i32]; // assume at least 4 byte alignment
+}
 
 static __dpct_inline__ int get_int_from_int8(const int8_t* x8, const int& i32) {
   const uint16_t* x16 =
@@ -755,6 +768,35 @@ static __dpct_inline__ float vec_dot_mxfp4_q8_1(const void * __restrict__ vbq,
     return d * sumi;
 }
 
+#define VDR_NVFP4_Q8_1_MMVQ 4
+#define VDR_NVFP4_Q8_1_MMQ  8
+
+static __dpct_inline__ float vec_dot_nvfp4_q8_1(const void * __restrict__ vbq,
+                                                const block_q8_1 * __restrict__ bq8_1,
+                                                const int32_t & iqs) {
+    const block_nvfp4 * bq4 = (const block_nvfp4 *) vbq;
+    float sum = 0.0f;
+#pragma unroll
+    for (int i = 0; i < VDR_NVFP4_Q8_1_MMVQ/2; i++) {
+        const int32_t iqs0 = iqs + 2*i;
+        const int32_t iqs1 = iqs0 + 1;
+        const int32_t is = iqs0 >> 1;
+        const sycl::int2   v0   = get_int_from_table_16(get_int_b4(bq4->qs, iqs0), kvalues_mxfp4);
+        const sycl::int2   v1   = get_int_from_table_16(get_int_b4(bq4->qs, iqs1), kvalues_mxfp4);
+        const block_q8_1 * bq8 = bq8_1 + (is >> 1);
+        const int32_t i8 = ((is & 1) << 2);
+
+        int sumi = ggml_sycl_dp4a(v0.x(), get_int_b4(bq8->qs, i8 + 0), 0);
+        sumi     = ggml_sycl_dp4a(v0.y(), get_int_b4(bq8->qs, i8 + 2), sumi);
+        sumi     = ggml_sycl_dp4a(v1.x(), get_int_b4(bq8->qs, i8 + 1), sumi);
+        sumi     = ggml_sycl_dp4a(v1.y(), get_int_b4(bq8->qs, i8 + 3), sumi);
+
+        const float d = ggml_sycl_ue4m3_to_fp32(bq4->d[is]) * (bq8->ds)[0];
+        sum += d * float(sumi);
+    }
+
+    return sum;
+}
 
 static __dpct_inline__ float
 vec_dot_q5_0_q8_1(const void *__restrict__ vbq,