]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
gguf-py : add Numpy MXFP4 de/quantization support (#15111)
authorcompilade <redacted>
Fri, 8 Aug 2025 21:48:26 +0000 (17:48 -0400)
committerGitHub <redacted>
Fri, 8 Aug 2025 21:48:26 +0000 (17:48 -0400)
* gguf-py : add MXFP4 de/quantization support

* ggml-quants : handle zero amax for MXFP4

ggml/src/ggml-quants.c
gguf-py/gguf/quants.py
gguf-py/tests/test_quants.py

index a57d2a16d6c540e8baca40138525d4813246a57b..94f6405ca1e059fb92f70e2a4d675e8083dc15e6 100644 (file)
@@ -288,7 +288,7 @@ void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RE
             }
         }
 
-        const uint8_t e = (uint8_t) (floorf(log2f(amax)) - 2 + 127);
+        const uint8_t e = amax > 0.0f ? (uint8_t) (floorf(log2f(amax)) - 2 + 127) : 0;
 
         const float d = GGML_E8M0_TO_FP32_HALF(e);
 
index 3c8ba82e19d3d9e984ba39caf5cf865b0ee8e72a..31845ea6eebdadd0ee53ce3e8cb104b5c8e5b0df 100644 (file)
@@ -228,8 +228,7 @@ class Q4_0(__Quant, qtype=GGMLQuantizationType.Q4_0):
         d = max / -8
         with np.errstate(divide="ignore"):
             id = np.where(d == 0, 0, 1 / d)
-        # FIXME: Q4_0's reference rounding is cursed and depends on FMA
-        qs = np.trunc((np.float64(blocks) * np.float64(id)) + np.float64(8.5), dtype=np.float32).astype(np.uint8).clip(0, 15)
+        qs = np.trunc((blocks * id) + np.float32(8.5), dtype=np.float32).astype(np.uint8).clip(0, 15)
 
         qs = qs.reshape((n_blocks, 2, cls.block_size // 2))
         qs = qs[..., 0, :] | (qs[..., 1, :] << np.uint8(4))
@@ -300,8 +299,7 @@ class Q5_0(__Quant, qtype=GGMLQuantizationType.Q5_0):
         d = max / -16
         with np.errstate(divide="ignore"):
             id = np.where(d == 0, 0, 1 / d)
-        # FIXME: Q5_0's reference rounding is cursed and depends on FMA
-        q = np.trunc((np.float64(blocks) * np.float64(id)) + np.float64(16.5), dtype=np.float32).astype(np.uint8).clip(0, 31)
+        q = np.trunc((blocks * id) + np.float32(16.5), dtype=np.float32).astype(np.uint8).clip(0, 31)
 
         qs = q.reshape((n_blocks, 2, cls.block_size // 2))
         qs = (qs[..., 0, :] & np.uint8(0x0F)) | (qs[..., 1, :] << np.uint8(4))
@@ -655,6 +653,57 @@ class TQ2_0(__Quant, qtype=GGMLQuantizationType.TQ2_0):
         return (d * qs.astype(np.float32))
 
 
+class MXFP4(__Quant, qtype=GGMLQuantizationType.MXFP4):
+    # e2m1 values (doubled)
+    # ref: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
+    kvalues = (0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12)
+
+    @staticmethod
+    # see ggml_e8m0_to_fp32_half in ggml-impl.h
+    def e8m0_to_fp32_half(x: np.ndarray) -> np.ndarray:
+        bits = np.where(x < 2, np.uint32(0x00200000) << np.uint32(x), np.uint32(x - 1) << np.uint32(23))
+        return bits.view(np.float32)
+
+    @classmethod
+    def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
+        n_blocks = blocks.shape[0]
+
+        d = abs(blocks).max(axis=-1, keepdims=True)
+
+        with np.errstate(divide="ignore"):
+            e = np.where(d > 0, np.floor(np.log2(d)) - 2 + 127, 0).astype(np.uint8)
+
+        d = cls.e8m0_to_fp32_half(e)
+
+        kvalues = np.array(cls.kvalues, dtype=np.int8).reshape((1, 1, 16))
+
+        errs = np.abs(d.reshape((n_blocks, 1, 1)) * kvalues.astype(np.float32) - blocks.reshape((n_blocks, cls.block_size, 1)))
+        best = np.argmin(errs, axis=-1, keepdims=True)
+
+        qs = best.reshape(n_blocks, 2, cls.block_size // 2).astype(np.uint8)
+        qs = qs[:, 0] | (qs[:, 1] << np.uint8(4))
+
+        qs = qs.reshape((n_blocks, cls.block_size // 2))
+
+        return np.concatenate([e, qs], axis=-1)
+
+    @classmethod
+    def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray:
+        n_blocks = blocks.shape[0]
+
+        e, qs = np.hsplit(blocks, [1])
+
+        d = cls.e8m0_to_fp32_half(e)
+
+        qs = qs.reshape((n_blocks, 1, cls.block_size // 2)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 2, 1))
+        qs = (qs & np.uint8(0x0F)).view(np.int8)
+
+        kvalues = np.array(cls.kvalues, dtype=np.int8).reshape(1, 1, 16)
+        qs = np.take_along_axis(kvalues, qs, axis=-1).reshape((n_blocks, cls.block_size))
+
+        return (d * qs.astype(np.float32))
+
+
 class IQ2_XXS(__Quant, qtype=GGMLQuantizationType.IQ2_XXS):
     ksigns: bytes = (
         b"\x00\x81\x82\x03\x84\x05\x06\x87\x88\x09\x0a\x8b\x0c\x8d\x8e\x0f"
index f04d5acce279325e5f45b27339c3166803d94e00..172fa0018ac401c91f62655db1e53bf4009af273 100755 (executable)
@@ -67,6 +67,7 @@ class GGMLQuants:
             "q4_0", "q4_1", "q5_0", "q5_1", "q8_0",
             "q2_K", "q3_K", "q4_K", "q5_K", "q6_K",
             "tq1_0", "tq2_0",
+            "mxfp4",
             "iq2_xxs", "iq2_xs", "iq2_s", "iq3_xxs", "iq3_s", "iq1_s", "iq1_m",
             "iq4_nl", "iq4_xs",
         ):
@@ -140,14 +141,21 @@ def compare_tensors(t1: np.ndarray, t2: np.ndarray, qtype: GGMLQuantizationType)
         return False
 
 
-def do_test(libggml_path: Path, quick: bool = False):
+def do_test(libggml_path: Path, quick: bool = False, user_type: GGMLQuantizationType | None = None):
     ggml_quants = GGMLQuants(libggml_path)
 
     np.set_printoptions(precision=None, threshold=(4 * 256) + 1, formatter={"int": lambda n: "0x%02X" % n})
 
     r = np.random.randn(8, 1024, 1024).astype(np.float32, copy=False)
-
-    for qtype in (GGMLQuantizationType.F16, *gguf.quants._type_traits.keys()):
+    # test zero blocks
+    r[0, 0, :] = 0
+    ## Maybe test infinities? (can make NANs, not really useful in practice)
+    # r[0, 1, 0] = np.inf
+    # r[0, 2, 0] = -np.inf
+    # r[0, 3, 0] = np.inf
+    # r[0, 3, 1] = -np.inf
+
+    for qtype in ((GGMLQuantizationType.F16, *gguf.quants._type_traits.keys()) if user_type is None else (user_type,)):
         has_dequantize = False
         has_quantize = False
 
@@ -228,11 +236,12 @@ def do_test(libggml_path: Path, quick: bool = False):
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Test Python (de)quantization against the reference C implementation")
-    parser.add_argument("--libggml", type=Path, default=Path(__file__).parent.parent.parent / "build" / "ggml" / "src" / "libggml.so", help="The path to libggml.so")
+    parser.add_argument("--libggml", type=Path, default=Path(__file__).parent.parent.parent / "build" / "bin" / "libggml.so", help="The path to libggml.so")
     parser.add_argument("--quick", action="store_true", help="Don't quantize with C when it's not strictly necessary")
+    parser.add_argument("--type", type=str, help="The quant type to test (all by default)")
 
     args = parser.parse_args()
 
     logging.basicConfig(level=logging.DEBUG)
 
-    do_test(args.libggml, args.quick)
+    do_test(args.libggml, args.quick, GGMLQuantizationType[args.type.upper()] if args.type is not None else None)