]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : add NVFP4 quantization type support (llama/19769)
authorRichard Davison <redacted>
Wed, 11 Mar 2026 20:02:54 +0000 (21:02 +0100)
committerGeorgi Gerganov <redacted>
Sun, 15 Mar 2026 19:50:13 +0000 (21:50 +0200)
* WIP: add NVFP4 quantization support

* tests

* improve NVFP4 dot product implementation performance and fix bad super call

* typo

* Use nvfp4 kvalues

* vulkan : fix NVFP4 shader compilation by including kvalues_mxfp4 lookup table

* vulcal and perf fixes

* wip

* Fix metal

* fix vulcan

* Rename threshold & fix wrong scale

* Fix MOE

* Shelf backend implementations (CUDA, Metal, Vulkan, arch-specific SIMD)

Remove NVFP4 support from GPU backends and architecture-specific
optimized dot products. These should be added in separate PRs so
backend specialists can review them independently.

Reverted files:
- ggml-cuda: common.cuh, convert.cu, mmq.cu/cuh, mmvq.cu, vecdotq.cuh,
  quantize.cu/cuh, mma.cuh, ggml-cuda.cu, fattn-tile.cuh
- ggml-metal: ggml-metal.metal, ggml-metal-device.cpp, ggml-metal-impl.h,
  ggml-metal-ops.cpp
- ggml-vulkan: ggml-vulkan.cpp, all vulkan-shaders/*
- ggml-cpu arch: arm/quants.c, x86/quants.c, powerpc/quants.c, s390/quants.c

Core NVFP4 support (type definition, CPU fallback dot product,
quantization, dequantization, conversion) is retained.

* Fix arch-fallback.h: add NVFP4 generic fallback for all platforms

After shelving backend-specific SIMD implementations, the generic
CPU dot product needs to be aliased on ARM, x86, PowerPC, and s390
platforms that previously relied on arch-specific versions.

* quantize: add NVFP4 as a quantization type option

* Fix ggml_fp32_to_ue4m3: handle subnormal values

Previously, values with ue4m3_exp <= 0 were clamped to 0, causing
all small scales to underflow. This made NVFP4 quantization via
llama-quantize produce garbage (PPL = 5.8M) since typical transformer
weights have amax/6.0 in the range 0.001-0.01, which falls in the
UE4M3 subnormal range.

Now subnormals are properly encoded as man * 2^-9 (exp=0, man=1..7),
matching the decode path in ggml_ue4m3_to_fp32.

Result: NVFP4 requantization now produces PPL = 15.25 (vs F16 = 14.33),
comparable to Q4_1 (PPL = 15.81) at slightly lower BPW (4.70 vs 5.15).

* Restore ARM NEON NVFP4 dot product implementation

Restores the optimized ggml_vec_dot_nvfp4_q8_0 for ARM NEON using
vqtbl1q_s8 lookup and ggml_vdotq_s32 dot products.

tg128 performance: 4.37 t/s (generic) -> 13.66 t/s (NEON) = 3.1x speedup

* Optimize ARM NEON NVFP4 dot product: LUT + vpaddq + vfmaq

- Add ue4m3_scale_lut[128] to ggml-common.h replacing branch-heavy
  ggml_ue4m3_to_fp32() in the hot loop
- Use vpaddq_s32 for pairwise int32 reduction instead of vaddvq_s32
- Accumulate with vfmaq_f32 into float32x4_t vector accumulators

tg128: 8.1 -> 31.0 t/s (3.8x speedup, 77% of Q4_1 speed)

* ARM NEON NVFP4: rearrange q8 to match nibble layout

Alternative approach: rearrange q8 data to match the NVFP4 lo/hi
nibble layout instead of rearranging the looked-up NVFP4 values.
Eliminates vcombine_s8(vget_low, vget_low) shuffles.

Performance is equivalent (~18.5 t/s) - the bottleneck is the 2x
block overhead from QK=16 vs QK=32, not the shuffle instructions.

* CPU only backend 64 super-block layout

* cleanup

* Remove unused LUT

* int

* exclude NVFP4 from unsupported ops in metal build

* remove quantization for now

* store scales as native UE4M3, preserve original model bits when possible

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <redacted>
* correct comment

* format

* reduce duplication and cleanup

* Address comments

* move detection to prepare_tensors

* Use math instead of const

* Move

* fix comment

* Shelf quantize tests

* Rebase and move check

* cleanup

* lint

* Update gguf-py/gguf/scripts/gguf_convert_endian.py

Co-authored-by: Sigbjørn Skjæret <redacted>
* Use fallback quant config

* Simplify

Co-authored-by: Sigbjørn Skjæret <redacted>
* organize

* Refactor

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <redacted>
* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <redacted>
* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <redacted>
* add quantize_nvfp4 (required for test_quants.py)

* add quantize_nvfp4 (required for test_quants.py)

* add quantize_nvfp4 (required for test_quants.py)

* fix return type

---------

Co-authored-by: Sigbjørn Skjæret <redacted>
15 files changed:
include/ggml.h
src/ggml-common.h
src/ggml-cpu/arch-fallback.h
src/ggml-cpu/arch/arm/quants.c
src/ggml-cpu/ggml-cpu.c
src/ggml-cpu/ops.cpp
src/ggml-cpu/quants.c
src/ggml-cpu/quants.h
src/ggml-impl.h
src/ggml-metal/ggml-metal-device.m
src/ggml-quants.c
src/ggml-quants.h
src/ggml.c
tests/test-backend-ops.cpp
tests/test-quantize-fns.cpp

index 566e2714790c18ddf90592325f9359f2bf88b062..3323f8e6c3fea805022bd44e9f11bd38da2cdb68 100644 (file)
@@ -427,7 +427,8 @@ extern "C" {
         // GGML_TYPE_IQ4_NL_4_8 = 37,
         // GGML_TYPE_IQ4_NL_8_8 = 38,
         GGML_TYPE_MXFP4   = 39, // MXFP4 (1 block)
-        GGML_TYPE_COUNT   = 40,
+        GGML_TYPE_NVFP4   = 40, // NVFP4 (4 blocks, E4M3 scale)
+        GGML_TYPE_COUNT   = 41,
     };
 
     // precision
@@ -463,6 +464,7 @@ extern "C" {
         GGML_FTYPE_MOSTLY_IQ1_M   = 23, // except 1d tensors
         GGML_FTYPE_MOSTLY_BF16    = 24, // except 1d tensors
         GGML_FTYPE_MOSTLY_MXFP4   = 25, // except 1d tensors
+        GGML_FTYPE_MOSTLY_NVFP4   = 26, // except 1d tensors
     };
 
     // available tensor operations:
index 93ab7ea446e26a33550a323f4c4d14d8266a3293..92cf739e7a7beab7c4ae1717ab270f73d41a01ff 100644 (file)
@@ -102,6 +102,9 @@ typedef sycl::half2 ggml_half2;
 #define QI_MXFP4 (QK_MXFP4 / (4 * QR_MXFP4))
 #define QR_MXFP4 2
 
+#define QI_NVFP4 (QK_NVFP4 / (4 * QR_NVFP4))
+#define QR_NVFP4 2
+
 #define QI5_0 (QK5_0 / (4 * QR5_0))
 #define QR5_0 2
 
@@ -194,6 +197,14 @@ typedef struct {
 } block_mxfp4;
 static_assert(sizeof(block_mxfp4) == sizeof(uint8_t) + QK_MXFP4/2, "wrong mxfp4 block size/padding");
 
+#define QK_NVFP4 64
+#define QK_NVFP4_SUB 16  // sub-block size for per-group scales
+typedef struct {
+    uint8_t d[QK_NVFP4/QK_NVFP4_SUB]; // UE4M3 scales (4 bytes, one per 16-element sub-block)
+    uint8_t qs[QK_NVFP4/2];           // packed 4-bit E2M1 values (32 bytes)
+} block_nvfp4;
+static_assert(sizeof(block_nvfp4) == sizeof(uint8_t)*(QK_NVFP4/QK_NVFP4_SUB) + QK_NVFP4/2, "wrong nvfp4 block size/padding");
+
 #define QK5_0 32
 typedef struct {
     ggml_half d;           // delta
index 48315610f2fc9eb52205f9a46047eed0d9ed779e..175aa4a4bb94f076c924cc1ce08da7a0e5e54e1c 100644 (file)
@@ -15,6 +15,7 @@
 #define ggml_vec_dot_q5_1_q8_1_generic ggml_vec_dot_q5_1_q8_1
 #define ggml_vec_dot_q8_0_q8_0_generic ggml_vec_dot_q8_0_q8_0
 #define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
+#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
 #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
 #define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K
@@ -79,6 +80,8 @@
 #define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0
 #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64)
+// quants.c
+#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
 // repack.cpp
 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
 #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
 // ref: https://github.com/ggml-org/llama.cpp/pull/14146#issuecomment-2972561679
 // quants.c
 #define quantize_row_q8_K_generic quantize_row_q8_K
+#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
 #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
 #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
 #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
 #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
 #define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
+#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
 // repack.cpp
 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
 #define ggml_vec_dot_iq4_nl_q8_0_generic ggml_vec_dot_iq4_nl_q8_0
 #define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K
 #define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
+#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
 // repack.cpp
 #define ggml_quantize_mat_q8_0_4x1_generic ggml_quantize_mat_q8_0_4x1
 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
 #elif defined(__s390x__)
 // quants.c
 #define quantize_row_q8_K_generic quantize_row_q8_K
+#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
 #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
 #define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K
 #define ggml_vec_dot_iq4_nl_q8_0_generic ggml_vec_dot_iq4_nl_q8_0
 #define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K
 #define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
+#define ggml_vec_dot_nvfp4_q8_0_generic ggml_vec_dot_nvfp4_q8_0
 // repack.cpp
 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
index a707d63985e0bd8fa31b89dc65987b99e817f7be..c1856201b317ab71471660cb02b2fe436ceabe16 100644 (file)
@@ -650,6 +650,90 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
     *s = sumf;
 }
 
+void ggml_vec_dot_nvfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+    assert(n % QK_NVFP4 == 0);
+
+    const block_nvfp4 * GGML_RESTRICT x = vx;
+    const block_q8_0 * GGML_RESTRICT y = vy;
+
+    // Each NVFP4 super-block (64 elements) spans 2 q8_0 blocks
+    const int nb = n / QK_NVFP4;
+
+    float sumf = 0;
+
+#if defined __ARM_NEON
+    const int8x16_t values = vld1q_s8(kvalues_mxfp4);
+    const uint8x16_t m4b = vdupq_n_u8(0x0f);
+    float32x4_t acc = vdupq_n_f32(0.0f);
+
+    for (int ib = 0; ib < nb; ++ib) {
+        const uint8x16_t q4bits_0 = vld1q_u8(x[ib].qs);
+        const uint8x16_t q4bits_1 = vld1q_u8(x[ib].qs + 16);
+
+        const int8x16_t q4_lo_0 = ggml_vqtbl1q_s8(values, vandq_u8  (q4bits_0, m4b));
+        const int8x16_t q4_hi_0 = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits_0, 4));
+        const int8x16_t q4_lo_1 = ggml_vqtbl1q_s8(values, vandq_u8  (q4bits_1, m4b));
+        const int8x16_t q4_hi_1 = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits_1, 4));
+
+        const int8x16_t q8_0a = vld1q_s8(y[2*ib].qs);
+        const int8x16_t q8_0b = vld1q_s8(y[2*ib].qs + 16);
+        const int8x16_t q8_lo_0 = vcombine_s8(vget_low_s8(q8_0a), vget_low_s8(q8_0b));
+        const int8x16_t q8_hi_0 = vcombine_s8(vget_high_s8(q8_0a), vget_high_s8(q8_0b));
+
+        const int8x16_t q8_1a = vld1q_s8(y[2*ib+1].qs);
+        const int8x16_t q8_1b = vld1q_s8(y[2*ib+1].qs + 16);
+        const int8x16_t q8_lo_1 = vcombine_s8(vget_low_s8(q8_1a), vget_low_s8(q8_1b));
+        const int8x16_t q8_hi_1 = vcombine_s8(vget_high_s8(q8_1a), vget_high_s8(q8_1b));
+
+        const int32x4_t p0 = vaddq_s32(
+            ggml_vdotq_s32(vdupq_n_s32(0), q4_lo_0, q8_lo_0),
+            ggml_vdotq_s32(vdupq_n_s32(0), q4_hi_0, q8_hi_0));
+        const int32x4_t p1 = vaddq_s32(
+            ggml_vdotq_s32(vdupq_n_s32(0), q4_lo_1, q8_lo_1),
+            ggml_vdotq_s32(vdupq_n_s32(0), q4_hi_1, q8_hi_1));
+
+        const int32x4_t sums = vpaddq_s32(p0, p1);
+
+        // Decode 4 UE4M3 scales to f32 and multiply with q8 scales
+        const float dy0 = GGML_CPU_FP16_TO_FP32(y[2*ib].d);
+        const float dy1 = GGML_CPU_FP16_TO_FP32(y[2*ib+1].d);
+        const float32x4_t nvsc = {
+            ggml_ue4m3_to_fp32(x[ib].d[0]),
+            ggml_ue4m3_to_fp32(x[ib].d[1]),
+            ggml_ue4m3_to_fp32(x[ib].d[2]),
+            ggml_ue4m3_to_fp32(x[ib].d[3])
+        };
+        const float32x4_t scales = vmulq_f32(nvsc, (float32x4_t){dy0, dy0, dy1, dy1});
+
+        acc = vfmaq_f32(acc, vcvtq_f32_s32(sums), scales);
+    }
+    sumf = vaddvq_f32(acc);
+#else
+    for (int ib = 0; ib < nb; ++ib) {
+        for (int si = 0; si < 4; ++si) {
+            const float d = ggml_ue4m3_to_fp32(x[ib].d[si]);
+            const int q8b = si / 2;
+            const int q8o = (si % 2) * QK_NVFP4_SUB;
+            const float dy = GGML_CPU_FP16_TO_FP32(y[2*ib + q8b].d);
+
+            int sumi_lo = 0, sumi_hi = 0;
+            for (int j = 0; j < QK_NVFP4_SUB/2; ++j) {
+                const uint8_t qv = x[ib].qs[si*(QK_NVFP4_SUB/2) + j];
+                sumi_lo += y[2*ib + q8b].qs[q8o + j +               0] * kvalues_mxfp4[qv & 0xf];
+                sumi_hi += y[2*ib + q8b].qs[q8o + j + QK_NVFP4_SUB/2] * kvalues_mxfp4[qv >>  4];
+            }
+            sumf += dy * d * (sumi_lo + sumi_hi);
+        }
+    }
+#endif
+    *s = sumf;
+}
+
 void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     const int qk = QK8_0;
     const int nb = n / qk;
index dc2b5ffaa77310b9656247cd2218e861b45cc17c..8b323bd9b06829c8a29e174b7ef1cc2e656f372a 100644 (file)
@@ -270,6 +270,12 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
         .vec_dot_type             = GGML_TYPE_Q8_0,
         .nrows                    = 1,
     },
+    [GGML_TYPE_NVFP4] = {
+        .from_float               = quantize_row_nvfp4,
+        .vec_dot                  = ggml_vec_dot_nvfp4_q8_0,
+        .vec_dot_type             = GGML_TYPE_Q8_0,
+        .nrows                    = 1,
+    },
     [GGML_TYPE_Q2_K] = {
         .from_float               = quantize_row_q2_K,
         .vec_dot                  = ggml_vec_dot_q2_K_q8_K,
index 331e071a2677a8c8e2bcedcffb59d5a6efbe0288..f9c4ec16e4b78c5e1cca8915ecc0cd8ae66761c3 100644 (file)
@@ -670,6 +670,7 @@ void ggml_compute_forward_add(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_MXFP4:
+        case GGML_TYPE_NVFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
@@ -1119,6 +1120,7 @@ void ggml_compute_forward_add1(
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
         case GGML_TYPE_MXFP4:
+        case GGML_TYPE_NVFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
@@ -1247,6 +1249,7 @@ void ggml_compute_forward_acc(
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
         case GGML_TYPE_MXFP4:
+        case GGML_TYPE_NVFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
@@ -4334,6 +4337,7 @@ void ggml_compute_forward_out_prod(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_MXFP4:
+        case GGML_TYPE_NVFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
@@ -4609,6 +4613,7 @@ void ggml_compute_forward_set(
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
         case GGML_TYPE_MXFP4:
+        case GGML_TYPE_NVFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
@@ -4831,6 +4836,7 @@ void ggml_compute_forward_get_rows(
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
         case GGML_TYPE_MXFP4:
+        case GGML_TYPE_NVFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
@@ -5555,6 +5561,7 @@ void ggml_compute_forward_clamp(
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
         case GGML_TYPE_MXFP4:
+        case GGML_TYPE_NVFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
index 365cb36d2d764353f274882ee1f6323179a4cce4..7ebbb9c6f15bd070f257a1d42188bf2b24aadd36 100644 (file)
@@ -50,6 +50,10 @@ void quantize_row_mxfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, i
     quantize_row_mxfp4_ref(x, y, k);
 }
 
+void quantize_row_nvfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {
+    quantize_row_nvfp4_ref(x, y, k);
+}
+
 //
 // 2-6 bit quantization in super-blocks
 //
@@ -216,6 +220,42 @@ void ggml_vec_dot_mxfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
     *s = sumf;
 }
 
+// NVFP4: super-block of 64 elements = 4 sub-blocks of 16 = 2 q8_0 blocks
+void ggml_vec_dot_nvfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+    assert(n % QK_NVFP4 == 0);
+
+    const block_nvfp4 * GGML_RESTRICT x = vx;
+    const block_q8_0 * GGML_RESTRICT y = vy;
+
+    const int nb = n / QK_NVFP4;
+
+    float sumf = 0;
+
+    for (int ib = 0; ib < nb; ++ib) {
+        for (int s_idx = 0; s_idx < 4; ++s_idx) {
+            const float d = ggml_ue4m3_to_fp32(x[ib].d[s_idx]);
+            const int q8_block = s_idx / 2;
+            const int q8_off   = (s_idx % 2) * QK_NVFP4_SUB;
+            const float dy = GGML_CPU_FP16_TO_FP32(y[2*ib + q8_block].d);
+
+            int sumi_lo = 0, sumi_hi = 0;
+            for (int j = 0; j < QK_NVFP4_SUB/2; ++j) {
+                const uint8_t qv = x[ib].qs[s_idx*(QK_NVFP4_SUB/2) + j];
+                sumi_lo += y[2*ib + q8_block].qs[q8_off + j +               0] * kvalues_mxfp4[qv & 0xf];
+                sumi_hi += y[2*ib + q8_block].qs[q8_off + j + QK_NVFP4_SUB/2] * kvalues_mxfp4[qv >>  4];
+            }
+
+            sumf += dy * d * (sumi_lo + sumi_hi);
+        }
+    }
+    *s = sumf;
+}
+
 void ggml_vec_dot_q5_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     const int qk = QK8_0;
     const int nb = n / qk;
index d83eb1b144d478bcd5f7a8db008e6a4f6fc43ead..3584aaa43e8c60d814bd34c106e5142b287e6d8c 100644 (file)
@@ -20,6 +20,7 @@ void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in
 void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
 
 void quantize_row_mxfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_nvfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
 
 void quantize_row_q2_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
 void quantize_row_q3_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
@@ -42,6 +43,7 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi
 void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 
 void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_nvfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 
 void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
@@ -73,6 +75,7 @@ void ggml_vec_dot_q5_1_q8_1_generic(int n, float * GGML_RESTRICT s, size_t bs, c
 void ggml_vec_dot_q8_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 
 void ggml_vec_dot_mxfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_nvfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 
 void ggml_vec_dot_tq1_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_tq2_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
index e3714b38a6adf1cf49eb5a706e53ea4d757477db..92568655956a7b36e9a685ea33c9ec8445a09dd5 100644 (file)
@@ -491,6 +491,61 @@ static inline float ggml_e8m0_to_fp32_half(uint8_t x) {
 #define GGML_E8M0_TO_FP32(x) ggml_e8m0_to_fp32(x)
 #define GGML_E8M0_TO_FP32_HALF(x) ggml_e8m0_to_fp32_half(x)
 
+// UE4M3: unsigned, 4 exp bits (bias=7), 3 mantissa bits
+// Returns value * 0.5 to match kvalues_mxfp4 convention (kvalues = 2 * E2M1_float)
+static inline float ggml_ue4m3_to_fp32(uint8_t x) {
+    if (x == 0 || x == 0x7F) {
+        return 0.0f;
+    }
+    int   exp = (x >> 3) & 0xF;
+    int   man = x & 0x7;
+    float raw;
+    if (exp == 0) {
+        raw = ldexpf((float) man, -9);
+    } else {
+        raw = ldexpf(1.0f + (float) man / 8.0f, exp - 7);
+    }
+    return raw * 0.5f;
+}
+
+static inline uint8_t ggml_fp32_to_ue4m3(float x) {
+    if (!(x > 0.0f)) {
+        return 0;
+    }
+    if (x > 448.0f) {
+        x = 448.0f;
+    }
+    uint32_t bits;
+    memcpy(&bits, &x, 4);
+    int fp32_exp  = ((bits >> 23) & 0xFF) - 127;
+    int fp32_man  = (bits >> 20) & 0x7;
+    int ue4m3_exp = fp32_exp + 7;
+    if (ue4m3_exp <= 0) {
+        // subnormal: value = man * 2^-9, man = round(x * 2^9)
+        int man = (int) (x * 512.0f + 0.5f);
+        if (man > 7) {
+            man = 7;
+        }
+        if (man < 1) {
+            return 0;
+        }
+        return (uint8_t) man;
+    }
+    if (ue4m3_exp >= 15) {
+        return 0x7E;
+    }
+    int round_bit = (bits >> 19) & 1;
+    int ue4m3_man = fp32_man + round_bit;
+    if (ue4m3_man > 7) {
+        ue4m3_man = 0;
+        ue4m3_exp++;
+        if (ue4m3_exp >= 15) {
+            return 0x7E;
+        }
+    }
+    return (uint8_t) ((ue4m3_exp << 3) | ue4m3_man);
+}
+
 /**
  * Converts brain16 to float32.
  *
index 23bd2b2ab72e620a7192561cb567ef5079cbdfdc..d42b8ab1eb1874a2d0deece92697f7c8d5d36b21 100644 (file)
@@ -1158,7 +1158,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
         case GGML_OP_SOLVE_TRI:
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
-            return has_simdgroup_reduction;
+            return has_simdgroup_reduction && op->src[0]->type != GGML_TYPE_NVFP4;
         case GGML_OP_SET:
         case GGML_OP_CPY:
         case GGML_OP_DUP:
@@ -1216,7 +1216,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
                 };
             }
         case GGML_OP_GET_ROWS:
-            return true;
+            return op->src[0]->type != GGML_TYPE_NVFP4;
         case GGML_OP_SET_ROWS:
             {
                 if (op->src[0]->type != GGML_TYPE_F32) {
index e8e25633fb86992c3e27a886b59f23b1dae001c3..cdaded865b16a5949f13fefe0587668e15afc730 100644 (file)
@@ -304,6 +304,41 @@ void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RE
     }
 }
 
+void quantize_row_nvfp4_ref(const float * GGML_RESTRICT x, block_nvfp4 * GGML_RESTRICT y, int64_t k) {
+    static const int qk = QK_NVFP4;
+    static const int qk_sub = QK_NVFP4_SUB;
+    static const int n_sub = QK_NVFP4 / QK_NVFP4_SUB;
+
+    assert(k % qk == 0);
+
+    const int nb = k / qk;
+
+    for (int i = 0; i < nb; i++) {
+        for (int s = 0; s < n_sub; s++) {
+            const float * xb = x + i*qk + s*qk_sub;
+
+            float amax = 0.0f;
+            for (int j = 0; j < qk_sub; j++) {
+                if (amax < fabsf(xb[j])) {
+                    amax = fabsf(xb[j]);
+                }
+            }
+
+            // UE4M3 scale: amax / 6.0 maps the max E2M1 value (6.0) to amax
+            const uint8_t ue = ggml_fp32_to_ue4m3(amax / 6.0f);
+            y[i].d[s] = ue;
+            const float d = ggml_ue4m3_to_fp32(ue);
+
+            for (int j = 0; j < qk_sub/2; ++j) {
+                const uint8_t x0 = best_index_mxfp4(xb[0        + j], d);
+                const uint8_t x1 = best_index_mxfp4(xb[qk_sub/2 + j], d);
+
+                y[i].qs[s*(qk_sub/2) + j] = x0 | (x1 << 4);
+            }
+        }
+    }
+}
+
 void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
     static const int qk = QK4_0;
 
@@ -434,6 +469,31 @@ void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_REST
     }
 }
 
+void dequantize_row_nvfp4(const block_nvfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
+    static const int qk = QK_NVFP4;
+    static const int qk_sub = QK_NVFP4_SUB;
+    static const int n_sub = QK_NVFP4 / QK_NVFP4_SUB;
+
+    assert(k % qk == 0);
+
+    const int nb = k / qk;
+
+    for (int i = 0; i < nb; i++) {
+        for (int s = 0; s < n_sub; s++) {
+            const float d = ggml_ue4m3_to_fp32(x[i].d[s]);
+            float * yb = y + i*qk + s*qk_sub;
+
+            for (int j = 0; j < qk_sub/2; ++j) {
+                const int8_t v0 = kvalues_mxfp4[x[i].qs[s*(qk_sub/2) + j] & 0x0F];
+                const int8_t v1 = kvalues_mxfp4[x[i].qs[s*(qk_sub/2) + j] >>   4];
+
+                yb[j + 0       ] = v0*d;
+                yb[j + qk_sub/2] = v1*d;
+            }
+        }
+    }
+}
+
 //
 // 2-6 bit quantization in super-blocks
 //
@@ -2098,6 +2158,12 @@ size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst,
     return nrow * ggml_row_size(GGML_TYPE_MXFP4, n_per_row);
 }
 
+size_t quantize_nvfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+    GGML_UNUSED(quant_weights);
+    quantize_row_nvfp4_ref(src, dst, (int64_t)nrow*n_per_row);
+    return nrow * ggml_row_size(GGML_TYPE_NVFP4, n_per_row);
+}
+
 // ====================== Ternary (de)-quantization (BitNet b1.58 and TriLMs)
 
 void quantize_row_tq1_0_ref(const float * GGML_RESTRICT x, block_tq1_0 * GGML_RESTRICT y, int64_t k) {
@@ -5244,6 +5310,12 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
             {
                 VALIDATE_ROW_DATA_E_E8M0_IMPL(block_mxfp4, data, nb);
             } break;
+        case GGML_TYPE_NVFP4:
+            {
+                // UE4M3 scales are uint8_t — all byte values are valid
+                GGML_UNUSED(data);
+                GGML_UNUSED(nb);
+            } break;
         case GGML_TYPE_Q2_K:
             {
                 VALIDATE_ROW_DATA_DM_F16_IMPL(block_q2_K, data, nb, d, dmin);
index 3b688f31c21459bdabe7094f067bd05998d1e37e..00604f75c0e9bc9112db4b4f93abb05978f2ddf2 100644 (file)
@@ -22,6 +22,7 @@ GGML_API void quantize_row_q8_0_ref(const float * GGML_RESTRICT x, block_q8_0 *
 GGML_API void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 * GGML_RESTRICT y, int64_t k);
 
 GGML_API void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k);
+GGML_API void quantize_row_nvfp4_ref(const float * GGML_RESTRICT x, block_nvfp4 * GGML_RESTRICT y, int64_t k);
 
 GGML_API void quantize_row_q2_K_ref(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k);
 GGML_API void quantize_row_q3_K_ref(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t k);
@@ -48,6 +49,7 @@ GGML_API void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GG
 //GGML_API void dequantize_row_q8_1(const block_q8_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
 
 GGML_API void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+GGML_API void dequantize_row_nvfp4(const block_nvfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
 
 GGML_API void dequantize_row_q2_K(const block_q2_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
 GGML_API void dequantize_row_q3_K(const block_q3_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
@@ -95,6 +97,7 @@ GGML_API size_t quantize_q5_1(const float * GGML_RESTRICT src, void * GGML_RESTR
 GGML_API size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
 
 GGML_API size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+GGML_API size_t quantize_nvfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
 
 GGML_API void iq2xs_init_impl(enum ggml_type type);
 GGML_API void iq2xs_free_impl(enum ggml_type type);
index aeafc395d7100977889568753db701deab7732cd..e5b83e144799ee70009c9286c10eb9f363d426ff 100644 (file)
@@ -718,6 +718,14 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = {
         .to_float                 = (ggml_to_float_t) dequantize_row_mxfp4,
         .from_float_ref           = (ggml_from_float_t)quantize_row_mxfp4_ref,
     },
+    [GGML_TYPE_NVFP4] = {
+        .type_name                = "nvfp4",
+        .blck_size                = QK_NVFP4,
+        .type_size                = sizeof(block_nvfp4),
+        .is_quantized             = true,
+        .to_float                 = (ggml_to_float_t) dequantize_row_nvfp4,
+        .from_float_ref           = (ggml_from_float_t)quantize_row_nvfp4_ref,
+    },
     [GGML_TYPE_Q2_K] = {
         .type_name                = "q2_K",
         .blck_size                = QK_K,
@@ -1374,6 +1382,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
         case GGML_FTYPE_MOSTLY_Q5_1:          wtype = GGML_TYPE_Q5_1;  break;
         case GGML_FTYPE_MOSTLY_Q8_0:          wtype = GGML_TYPE_Q8_0;  break;
         case GGML_FTYPE_MOSTLY_MXFP4:         wtype = GGML_TYPE_MXFP4; break;
+        case GGML_FTYPE_MOSTLY_NVFP4:         wtype = GGML_TYPE_NVFP4; break;
         case GGML_FTYPE_MOSTLY_Q2_K:          wtype = GGML_TYPE_Q2_K;  break;
         case GGML_FTYPE_MOSTLY_Q3_K:          wtype = GGML_TYPE_Q3_K;  break;
         case GGML_FTYPE_MOSTLY_Q4_K:          wtype = GGML_TYPE_Q4_K;  break;
@@ -7641,6 +7650,7 @@ size_t ggml_quantize_chunk(
         case GGML_TYPE_Q5_1:    result = quantize_q5_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_Q8_0:    result = quantize_q8_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_MXFP4:   result = quantize_mxfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+        case GGML_TYPE_NVFP4:   result = quantize_nvfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_Q2_K:    result = quantize_q2_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_Q3_K:    result = quantize_q3_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_Q4_K:    result = quantize_q4_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
index 32a83b001d8c9fe68c27ace3f8e0b999d8d33249..66aaddcfffd522622b6cec675adf25e236db787f 100644 (file)
@@ -7854,10 +7854,6 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 1056, 1, 67,  {1,  1}, {4, 1}, {0, 2, 1, 3}));
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 16, 32, 32, { 1,  1}, {1, 1}, {0, 1, 2, 3}, 64, 3));
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 64, 77, 77, {12,1}, {1,1}));
-    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 2, 1, 3, {128, 1024}, {1, 1}));
-    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 2, 3, 4, {128, 1024}, {1, 1}));
-    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 2, 1, 3, {128*1024, 1}, {1, 1}, {0, 2, 1, 3}));
-    test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 2, 1, 3, {128*1024, 1}, {1, 1}, {0, 1, 2, 3}, 64));
 
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_0, GGML_TYPE_F32, 576, 512, 576, {1,1}, {1,1}));
     test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_0, GGML_TYPE_F32, 1, 2048, 8192, {1,  1}, {1, 1}));
index 037c0582bbbf8005cdb18729b0ddbd0adc63bb96..a8fb1926231ff87c428169db22fd3645388408db 100644 (file)
@@ -20,8 +20,10 @@ constexpr float MAX_QUANTIZATION_TOTAL_ERROR_TERNARY = 0.01f;
 constexpr float MAX_QUANTIZATION_TOTAL_ERROR_2BITS = 0.0075f;
 constexpr float MAX_QUANTIZATION_TOTAL_ERROR_3BITS = 0.0040f;
 constexpr float MAX_QUANTIZATION_TOTAL_ERROR_3BITS_XXS = 0.0050f;
+constexpr float MAX_QUANTIZATION_TOTAL_ERROR_FP4 = 0.0030f;
 constexpr float MAX_DOT_PRODUCT_ERROR = 0.02f;
 constexpr float MAX_DOT_PRODUCT_ERROR_LOWBIT = 0.04f;
+constexpr float MAX_DOT_PRODUCT_ERROR_FP4 = 0.03f;
 constexpr float MAX_DOT_PRODUCT_ERROR_TERNARY = 0.15f;
 
 static const char* RESULT_STR[] = {"ok", "FAILED"};
@@ -149,7 +151,8 @@ int main(int argc, char * argv[]) {
                 type == GGML_TYPE_IQ2_S   ? MAX_QUANTIZATION_TOTAL_ERROR_2BITS :
                 type == GGML_TYPE_Q3_K    ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS :
                 type == GGML_TYPE_IQ3_S   ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS :
-                type == GGML_TYPE_IQ3_XXS ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS_XXS : MAX_QUANTIZATION_TOTAL_ERROR;
+                type == GGML_TYPE_IQ3_XXS ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS_XXS :
+                type == GGML_TYPE_NVFP4   ? MAX_QUANTIZATION_TOTAL_ERROR_FP4 : MAX_QUANTIZATION_TOTAL_ERROR;
             failed = !(total_error < max_quantization_error);
             num_failed += failed;
             if (failed || verbose) {
@@ -169,6 +172,8 @@ int main(int argc, char * argv[]) {
                                           ? MAX_DOT_PRODUCT_ERROR_LOWBIT
                                           : type == GGML_TYPE_TQ1_0 || type == GGML_TYPE_TQ2_0
                                           ? MAX_DOT_PRODUCT_ERROR_TERNARY
+                                          : type == GGML_TYPE_NVFP4
+                                          ? MAX_DOT_PRODUCT_ERROR_FP4
                                           : MAX_DOT_PRODUCT_ERROR;
             failed = !(vec_dot_error < max_allowed_error);
             num_failed += failed;