]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ggml-cpu: add repack for mxfp4 (llama/19738)
authorAman Gupta <redacted>
Fri, 27 Feb 2026 10:15:09 +0000 (18:15 +0800)
committerGeorgi Gerganov <redacted>
Mon, 16 Mar 2026 11:10:15 +0000 (13:10 +0200)
ggml/src/ggml-cpu/arch-fallback.h
ggml/src/ggml-cpu/arch/arm/repack.cpp
ggml/src/ggml-cpu/arch/x86/repack.cpp
ggml/src/ggml-cpu/repack.cpp
ggml/src/ggml-cpu/repack.h

index 4dfe28e1d649db8f1ed7478e046485000f52c914..ebbd4b47e0511367b66edd17c0ed2677bc94253d 100644 (file)
@@ -48,6 +48,8 @@
 #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
+#define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0
+#define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0
 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
@@ -62,6 +64,8 @@
 #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
+#define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0
+#define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0
 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
 #elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) || defined(_M_ARM64)
 #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
+#define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0
 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
+#define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0
 #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
 #elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64)
 // repack.cpp
@@ -84,6 +90,7 @@
 #define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
 #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
+#define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0
 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
 #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
+#define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0
 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
 #elif defined(__POWERPC__) || defined(__powerpc__)
 #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
+#define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0
+#define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0
 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
+#define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0
+#define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0
 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
 #elif defined(__loongarch64)
 #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
+#define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0
+#define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0
 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
+#define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0
+#define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0
 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
 #elif defined(__riscv)
 #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
+#define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0
+#define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0
 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
+#define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0
+#define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0
 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
 #elif defined(__s390x__)
 #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
+#define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0
+#define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0
 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
+#define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0
+#define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0
 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
 #elif defined(__wasm__)
 #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
 #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
+#define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0
+#define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0
 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0
 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
 #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
 #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
+#define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0
+#define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0
 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0
 #endif
index c2e4623f37111390111dd70df449f7b70178e7df..3eed0105bf1ed9e8b3ca025e3401b6929708cdba 100644 (file)
@@ -498,6 +498,81 @@ void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
     ggml_gemv_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
 }
 
+void ggml_gemv_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 4;
+
+    assert (n % qk == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    const int8x16_t kvalues = vld1q_s8(kvalues_mxfp4);
+    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+    float * res_ptr = s;
+
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_mxfp4x4 * b_ptr = (const block_mxfp4x4 *) vx + (x * nb);
+
+        float32x4_t sumf = vdupq_n_f32(0);
+        for (int l = 0; l < nb; l++) {
+            uint8x16_t b_0 = vld1q_u8(b_ptr[l].qs + 0);
+            uint8x16_t b_1 = vld1q_u8(b_ptr[l].qs + 16);
+            uint8x16_t b_2 = vld1q_u8(b_ptr[l].qs + 32);
+            uint8x16_t b_3 = vld1q_u8(b_ptr[l].qs + 48);
+
+            int8x16_t b_0_hi = vqtbl1q_s8(kvalues, b_0 >> 4);
+            int8x16_t b_0_lo = vqtbl1q_s8(kvalues, b_0 & 0x0F);
+            int8x16_t b_1_hi = vqtbl1q_s8(kvalues, b_1 >> 4);
+            int8x16_t b_1_lo = vqtbl1q_s8(kvalues, b_1 & 0x0F);
+            int8x16_t b_2_hi = vqtbl1q_s8(kvalues, b_2 >> 4);
+            int8x16_t b_2_lo = vqtbl1q_s8(kvalues, b_2 & 0x0F);
+            int8x16_t b_3_hi = vqtbl1q_s8(kvalues, b_3 >> 4);
+            int8x16_t b_3_lo = vqtbl1q_s8(kvalues, b_3 & 0x0F);
+
+            int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 0);
+            int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16);
+
+            int32x4_t sumi = vdupq_n_s32(0);
+            sumi = vdotq_laneq_s32(sumi, b_0_lo, a_0, 0);
+            sumi = vdotq_laneq_s32(sumi, b_0_hi, a_1, 0);
+            sumi = vdotq_laneq_s32(sumi, b_1_lo, a_0, 1);
+            sumi = vdotq_laneq_s32(sumi, b_1_hi, a_1, 1);
+            sumi = vdotq_laneq_s32(sumi, b_2_lo, a_0, 2);
+            sumi = vdotq_laneq_s32(sumi, b_2_hi, a_1, 2);
+            sumi = vdotq_laneq_s32(sumi, b_3_lo, a_0, 3);
+            sumi = vdotq_laneq_s32(sumi, b_3_hi, a_1, 3);
+
+            float32x4_t a_d = vcvt_f32_f16(vld1_dup_f16((const float16_t *)&a_ptr[l].d));
+            float32x4_t b_d = {
+                GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[0]),
+                GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[1]),
+                GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[2]),
+                GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[3]),
+            };
+            float32x4_t d = a_d * b_d;
+
+            sumf = vmlaq_f32(sumf, d, vcvtq_f32_s32(sumi));
+        }
+
+        vst1q_f32(res_ptr + x * 4, sumf);
+    }
+    return;
+#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
+    ggml_gemv_mxfp4_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
+}
+
 void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
     constexpr int qk = QK_K;
     const int     nb = n / qk;
@@ -3164,6 +3239,87 @@ void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
     ggml_gemm_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
 }
 
+void ggml_gemm_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 4;
+
+    assert (n % qk == 0);
+    assert (nr % 4 == 0);
+    assert (nc % ncols_interleaved == 0);
+
+    UNUSED(s);
+    UNUSED(bs);
+    UNUSED(vx);
+    UNUSED(vy);
+    UNUSED(nr);
+    UNUSED(nc);
+    UNUSED(nb);
+    UNUSED(ncols_interleaved);
+    UNUSED(blocklen);
+
+#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
+    const int8x16_t kvalues = vld1q_s8(kvalues_mxfp4);
+
+    for (int y = 0; y < nr / 4; y++) {
+        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_mxfp4x4 * b_ptr = (const block_mxfp4x4 *) vx + (x * nb);
+
+            float32x4_t sumf[4];
+            for (int m = 0; m < 4; m++) {
+                sumf[m] = vdupq_n_f32(0);
+            }
+
+            for (int l = 0; l < nb; l++) {
+                float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *)a_ptr[l].d));
+                float32x4_t b_d = {
+                    GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[0]),
+                    GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[1]),
+                    GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[2]),
+                    GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[3]),
+                };
+
+                int32x4_t sumi_0 = vdupq_n_s32(0);
+                int32x4_t sumi_1 = vdupq_n_s32(0);
+                int32x4_t sumi_2 = vdupq_n_s32(0);
+                int32x4_t sumi_3 = vdupq_n_s32(0);
+
+                for (int k = 0; k < 4; k++) {
+                    int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 16 * k + 0);
+                    int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16 * k + 64);
+
+                    uint8x16_t b = vld1q_u8(b_ptr[l].qs + 16 * k);
+                    int8x16_t b_hi = vqtbl1q_s8(kvalues, b >> 4);
+                    int8x16_t b_lo = vqtbl1q_s8(kvalues, b & 0xF);
+
+                    sumi_0 = vdotq_laneq_s32(sumi_0, b_lo, a_0, 0);
+                    sumi_1 = vdotq_laneq_s32(sumi_1, b_lo, a_0, 1);
+                    sumi_2 = vdotq_laneq_s32(sumi_2, b_lo, a_0, 2);
+                    sumi_3 = vdotq_laneq_s32(sumi_3, b_lo, a_0, 3);
+                    sumi_0 = vdotq_laneq_s32(sumi_0, b_hi, a_1, 0);
+                    sumi_1 = vdotq_laneq_s32(sumi_1, b_hi, a_1, 1);
+                    sumi_2 = vdotq_laneq_s32(sumi_2, b_hi, a_1, 2);
+                    sumi_3 = vdotq_laneq_s32(sumi_3, b_hi, a_1, 3);
+                }
+
+                sumf[0] = vmlaq_f32(sumf[0], vmulq_laneq_f32(b_d, a_d, 0), vcvtq_f32_s32(sumi_0));
+                sumf[1] = vmlaq_f32(sumf[1], vmulq_laneq_f32(b_d, a_d, 1), vcvtq_f32_s32(sumi_1));
+                sumf[2] = vmlaq_f32(sumf[2], vmulq_laneq_f32(b_d, a_d, 2), vcvtq_f32_s32(sumi_2));
+                sumf[3] = vmlaq_f32(sumf[3], vmulq_laneq_f32(b_d, a_d, 3), vcvtq_f32_s32(sumi_3));
+            }
+
+            for (int m = 0; m < 4; m++) {
+                vst1q_f32(s + (y * 4 + m) * bs + x * 4, sumf[m]);
+            }
+        }
+    }
+    return;
+#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON)
+    ggml_gemm_mxfp4_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
+}
+
 void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
     constexpr int qk = QK_K;
     const int     nb = n / qk;
index 7dda9eea0c5a49624da07fbca84173462dbe0113..bd6906c4159e7243d9c9cce548931f1372607638 100644 (file)
@@ -522,7 +522,8 @@ template<typename block_tx8>
 static void gemv_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, __m256i signextendlut) {
     static_assert(
             std::is_same_v<block_tx8, block_q4_0x8> ||
-            std::is_same_v<block_tx8, block_iq4_nlx8>,
+            std::is_same_v<block_tx8, block_iq4_nlx8> ||
+            std::is_same_v<block_tx8, block_mxfp4x8>,
             "Unsupported block type");
 
     const int qk = QK8_0;
@@ -580,6 +581,18 @@ static void gemv_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t
                         std::is_same_v<block_tx8, block_q4_0x8> ||
                         std::is_same_v<block_tx8, block_iq4_nlx8>) {
                     col_scale_f32 = GGML_F32Cx8_REARRANGE_LOAD(b_ptr[b].d, changemask);
+                } else if constexpr (std::is_same_v<block_tx8, block_mxfp4x8>) {
+                    // Load 8 E8M0 exponents and convert to float via LUT
+                    // Rearranged to match changemask order: 0,4,1,5,2,6,3,7
+                    col_scale_f32 = _mm256_set_ps(
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[7]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[3]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[6]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[2]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[5]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[1]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[4]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[0]));
                 }
 
                 // Load and convert to FP32 scale from block_q8_0
@@ -628,7 +641,8 @@ template<typename block_tx8>
 static void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, __m256i signextendlut) {
     static_assert(
             std::is_same_v<block_tx8, block_q4_0x8> ||
-            std::is_same_v<block_tx8, block_iq4_nlx8>,
+            std::is_same_v<block_tx8, block_iq4_nlx8> ||
+            std::is_same_v<block_tx8, block_mxfp4x8>,
             "Unsupported block type");
 
     const int qk = QK8_0;
@@ -749,6 +763,25 @@ static void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t
                         std::is_same_v<block_tx8, block_q4_0x8> ||
                         std::is_same_v<block_tx8, block_iq4_nlx8>) {
                     col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d);
+                } else if constexpr (std::is_same_v<block_tx8, block_mxfp4x8>) {
+                    //TODO: simd-ify
+                    col_scale_f32 = _mm512_set_ps(
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[7]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[6]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[5]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[4]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[3]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[2]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[1]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[0]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[7]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[6]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[5]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[4]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[3]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[2]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[1]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[0]));
                 }
 
                 // Process LHS in pairs of rows
@@ -941,6 +974,25 @@ static void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t
                         std::is_same_v<block_tx8, block_q4_0x8> ||
                         std::is_same_v<block_tx8, block_iq4_nlx8>) {
                     col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d);
+                } else if constexpr (std::is_same_v<block_tx8, block_mxfp4x8>) {
+                    //TODO: simd-ify
+                    col_scale_f32 = _mm512_set_ps(
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[7]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[6]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[5]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[4]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[3]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[2]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[1]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[0]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[7]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[6]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[5]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[4]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[3]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[2]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[1]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[0]));
                 }
 
                 // Load the four blocks of quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3
@@ -1123,6 +1175,16 @@ static void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t
                         std::is_same_v<block_tx8, block_q4_0x8> ||
                         std::is_same_v<block_tx8, block_iq4_nlx8>) {
                     col_scale_f32 = GGML_F32Cx8_LOAD(b_ptr[b].d);
+                } else if constexpr (std::is_same_v<block_tx8, block_mxfp4x8>) {
+                    col_scale_f32 = _mm256_set_ps(
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[7]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[6]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[5]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[4]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[3]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[2]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[1]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[0]));
                 }
 
                 // Process LHS in groups of four
@@ -1283,6 +1345,16 @@ static void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t
                         std::is_same_v<block_tx8, block_q4_0x8> ||
                         std::is_same_v<block_tx8, block_iq4_nlx8>) {
                     col_scale_f32 = GGML_F32Cx8_LOAD(b_ptr[b].d);
+                } else if constexpr (std::is_same_v<block_tx8, block_mxfp4x8>) {
+                    col_scale_f32 = _mm256_set_ps(
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[7]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[6]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[5]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[4]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[3]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[2]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[1]),
+                        GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[0]));
                 }
 
                 // Load the four blocks of quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3
@@ -1625,6 +1697,19 @@ void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
     ggml_gemv_iq4_nl_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
 }
 
+void ggml_gemv_mxfp4_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+#if defined(__AVX2__)
+    __m256i signextendlut = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i*)kvalues_mxfp4));
+    signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0);
+
+    gemv_q4_b32_8x8_q8_0_lut_avx<block_mxfp4x8>(n, s, bs, vx, vy, nr, nc, signextendlut);
+
+    return;
+#endif
+
+    ggml_gemv_mxfp4_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
+}
+
 void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
     const int qk = QK_K;
     const int nb = n / qk;
@@ -3423,6 +3508,21 @@ void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
     ggml_gemm_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
 }
 
+void ggml_gemm_mxfp4_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+#if defined(__AVX2__) || defined(__AVX512F__)
+    {
+        __m256i signextendlut = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i*)kvalues_mxfp4));
+        signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0);
+
+        gemm_q4_b32_8x8_q8_0_lut_avx<block_mxfp4x8>(n, s, bs, vx, vy, nr, nc, signextendlut);
+
+        return;
+    }
+#endif // defined(__AVX2__) || defined(__AVX512F__)
+
+    ggml_gemm_mxfp4_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc);
+}
+
 void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
     const int qk = QK_K;
     const int nb = n / qk;
index 1b3d23cbedcdc2f95c23bf607ce46ac5dce78f10..5edba4212f617ee124d734f8d6f406c9557dc858 100644 (file)
@@ -1098,6 +1098,82 @@ void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs
     }
 }
 
+void ggml_gemv_mxfp4_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 4;
+
+    assert(nr == 1);
+    assert(n % qk == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(bs);
+    UNUSED(nr);
+
+    float sumf[4];
+    int sumi;
+
+    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_mxfp4x4 * b_ptr = (const block_mxfp4x4 *) vx + (x * nb);
+
+        for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
+        for (int l = 0; l < nb; l++) {
+            for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    sumi = 0;
+                    for (int i = 0; i < blocklen; ++i) {
+                        const int v0 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
+                        const int v1 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
+                        sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2]));
+                    }
+                    sumf[j] += sumi * GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
+                }
+            }
+        }
+        for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
+    }
+}
+
+void ggml_gemv_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 8;
+    const int blocklen = 8;
+
+    assert(nr == 1);
+    assert(n % qk == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    UNUSED(bs);
+    UNUSED(nr);
+
+    float sumf[8];
+    int sumi;
+
+    const block_q8_0 * a_ptr = (const block_q8_0 *) vy;
+    for (int x = 0; x < nc / ncols_interleaved; x++) {
+        const block_mxfp4x8 * b_ptr = (const block_mxfp4x8 *) vx + (x * nb);
+
+        for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0;
+        for (int l = 0; l < nb; l++) {
+            for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                for (int j = 0; j < ncols_interleaved; j++) {
+                    sumi = 0;
+                    for (int i = 0; i < blocklen; ++i) {
+                        const int v0 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
+                        const int v1 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
+                        sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2]));
+                    }
+                    sumf[j] += sumi * GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d);
+                }
+            }
+        }
+        for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j];
+    }
+}
+
 void ggml_gemv_q8_0_4x4_q8_0_generic(int                        n,
                                      float * GGML_RESTRICT      s,
                                      size_t                     bs,
@@ -1726,6 +1802,94 @@ void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs
     }
 }
 
+void ggml_gemm_mxfp4_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 4;
+    const int blocklen = 4;
+
+    assert(n % qk == 0);
+    assert(nr % 4 == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    float sumf[4][4];
+    int sumi;
+
+    for (int y = 0; y < nr / 4; y++) {
+        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_mxfp4x4 * b_ptr = (const block_mxfp4x4 *) vx + (x * nb);
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
+            }
+            for (int l = 0; l < nb; l++) {
+                for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                    for (int m = 0; m < 4; m++) {
+                        for (int j = 0; j < ncols_interleaved; j++) {
+                            sumi = 0;
+                            for (int i = 0; i < blocklen; ++i) {
+                                const int v0 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
+                                const int v1 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
+                                sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
+                                         (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4]));
+                            }
+                            sumf[m][j] += sumi * GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);
+                        }
+                    }
+                }
+            }
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++)
+                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
+            }
+        }
+    }
+}
+
+void ggml_gemm_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
+    const int qk = QK8_0;
+    const int nb = n / qk;
+    const int ncols_interleaved = 8;
+    const int blocklen = 8;
+
+    assert(n % qk == 0);
+    assert(nr % 4 == 0);
+    assert(nc % ncols_interleaved == 0);
+
+    float sumf[4][8];
+    int sumi;
+
+    for (int y = 0; y < nr / 4; y++) {
+        const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb);
+        for (int x = 0; x < nc / ncols_interleaved; x++) {
+            const block_mxfp4x8 * b_ptr = (const block_mxfp4x8 *) vx + (x * nb);
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0;
+            }
+            for (int l = 0; l < nb; l++) {
+                for (int k = 0; k < (qk / (2 * blocklen)); k++) {
+                    for (int m = 0; m < 4; m++) {
+                        for (int j = 0; j < ncols_interleaved; j++) {
+                            sumi = 0;
+                            for (int i = 0; i < blocklen; ++i) {
+                                const int v0 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F];
+                                const int v1 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4];
+                                sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) +
+                                         (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4]));
+                            }
+                            sumf[m][j] += sumi * GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]);
+                        }
+                    }
+                }
+            }
+            for (int m = 0; m < 4; m++) {
+                for (int j = 0; j < ncols_interleaved; j++)
+                    s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
+            }
+        }
+    }
+}
+
 void ggml_gemm_q8_0_4x4_q8_0_generic(int                        n,
                                      float * GGML_RESTRICT      s,
                                      size_t                     bs,
@@ -2510,6 +2674,121 @@ static int repack_iq4_nl_to_iq4_nl_8_bl(struct ggml_tensor * t, int interleave_b
     GGML_UNUSED(data_size);
 }
 
+
+static block_mxfp4x4 make_block_mxfp4x4(block_mxfp4 * in, unsigned int blck_size_interleave) {
+    block_mxfp4x4 out;
+
+    for (int i = 0; i < 4; i++) {
+        out.e[i] = in[i].e;
+    }
+
+    const int end = QK_MXFP4 * 2 / blck_size_interleave;
+
+    if (blck_size_interleave == 4) {
+        for (int i = 0; i < end; ++i) {
+            int src_id = i % 4;
+            int src_offset = (i / 4) * blck_size_interleave;
+            int dst_offset = i * blck_size_interleave;
+
+            memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint32_t));
+        }
+    } else {
+        GGML_ASSERT(false);
+    }
+
+    return out;
+}
+
+static int repack_mxfp4_to_mxfp4_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
+    GGML_ASSERT(t->type == GGML_TYPE_MXFP4);
+    GGML_ASSERT(interleave_block == 4);
+
+    const block_mxfp4   * src = (const block_mxfp4   *)data;
+          block_mxfp4x4 * dst = (      block_mxfp4x4 *)t->data;
+
+    block_mxfp4 dst_tmp[4];
+
+    int nrow = ggml_nrows(t);
+    int nrows_interleaved = 4;
+    int nblocks = t->ne[0] / QK_MXFP4;
+
+    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_mxfp4));
+
+    if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) {
+        return -1;
+    }
+
+    for (int b = 0; b < nrow; b += nrows_interleaved) {
+        for (int64_t x = 0; x < nblocks; x++) {
+            for (int i = 0; i < nrows_interleaved; i++) {
+                dst_tmp[i] = src[x + i * nblocks];
+            }
+            *dst++ = make_block_mxfp4x4(dst_tmp, interleave_block);
+        }
+        src += nrows_interleaved * nblocks;
+    }
+    return 0;
+
+    GGML_UNUSED(data_size);
+}
+
+static block_mxfp4x8 make_block_mxfp4x8(block_mxfp4 * in, unsigned int blck_size_interleave) {
+    block_mxfp4x8 out;
+
+    for (int i = 0; i < 8; i++) {
+        out.e[i] = in[i].e;
+    }
+
+    const int end = QK_MXFP4 * 4 / blck_size_interleave;
+
+    if (blck_size_interleave == 8) {
+        for (int i = 0; i < end; ++i) {
+            int src_id = i % 8;
+            int src_offset = (i / 8) * blck_size_interleave;
+            int dst_offset = i * blck_size_interleave;
+
+            memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t));
+        }
+    } else {
+        GGML_ASSERT(false);
+    }
+
+    return out;
+}
+
+static int repack_mxfp4_to_mxfp4_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
+    GGML_ASSERT(t->type == GGML_TYPE_MXFP4);
+    GGML_ASSERT(interleave_block == 8);
+
+    const block_mxfp4   * src = (const block_mxfp4   *)data;
+          block_mxfp4x8 * dst = (      block_mxfp4x8 *)t->data;
+
+    block_mxfp4 dst_tmp[8];
+
+    int nrow = ggml_nrows(t);
+    int nrows_interleaved = 8;
+    int nblocks = t->ne[0] / QK_MXFP4;
+
+    GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_mxfp4));
+
+    if (t->ne[1] % nrows_interleaved != 0) {
+        return -1;
+    }
+
+    for (int b = 0; b < nrow; b += nrows_interleaved) {
+        for (int64_t x = 0; x < nblocks; x++) {
+            for (int i = 0; i < nrows_interleaved; i++) {
+                dst_tmp[i] = src[x + i * nblocks];
+            }
+            *dst++ = make_block_mxfp4x8(dst_tmp, interleave_block);
+        }
+        src += nrows_interleaved * nblocks;
+    }
+    return 0;
+
+    GGML_UNUSED(data_size);
+}
+
 namespace ggml::cpu::repack {
 // repack
 template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS>
@@ -2569,6 +2848,14 @@ template <> int repack<block_iq4_nl, 8, 8>(struct ggml_tensor * t, const void *
     return repack_iq4_nl_to_iq4_nl_8_bl(t, 8, data, data_size);
 }
 
+template <> int repack<block_mxfp4, 4, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
+    return repack_mxfp4_to_mxfp4_4_bl(t, 4, data, data_size);
+}
+
+template <> int repack<block_mxfp4, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
+    return repack_mxfp4_to_mxfp4_8_bl(t, 8, data, data_size);
+}
+
 template <> int repack<block_q8_0, 4, 4>(struct ggml_tensor * t, const void * data, size_t data_size) {
     return repack_q8_0_to_q8_0_4_bl(t, 4, data, data_size);
 }
@@ -2636,6 +2923,14 @@ template <> void gemv<block_iq4_nl, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size
     ggml_gemv_iq4_nl_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
 }
 
+template <> void gemv<block_mxfp4, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemv_mxfp4_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemv<block_mxfp4, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemv_mxfp4_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
 template <> void gemv<block_q8_0, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
     ggml_gemv_q8_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
 }
@@ -2703,6 +2998,14 @@ template <> void gemm<block_iq4_nl, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size
     ggml_gemm_iq4_nl_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
 }
 
+template <> void gemm<block_mxfp4, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemm_mxfp4_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
+template <> void gemm<block_mxfp4, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
+    ggml_gemm_mxfp4_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
+}
+
 template <> void gemm<block_q8_0, 4, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
     ggml_gemm_q8_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc);
 }
@@ -3111,6 +3414,10 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
     static const ggml::cpu::repack::tensor_traits<block_iq4_nl, 4, 4, GGML_TYPE_Q8_0> iq4_nl_4x4_q8_0;
     static const ggml::cpu::repack::tensor_traits<block_iq4_nl, 8, 8, GGML_TYPE_Q8_0> iq4_nl_8x8_q8_0;
 
+    // instance for MXFP4
+    static const ggml::cpu::repack::tensor_traits<block_mxfp4, 4, 4, GGML_TYPE_Q8_0> mxfp4_4x4_q8_0;
+    static const ggml::cpu::repack::tensor_traits<block_mxfp4, 8, 8, GGML_TYPE_Q8_0> mxfp4_8x8_q8_0;
+
     // instance for Q8_0
     static const ggml::cpu::repack::tensor_traits<block_q8_0, 4, 4, GGML_TYPE_Q8_0> q8_0_4x4_q8_0;
     static const ggml::cpu::repack::tensor_traits<block_q8_0, 8, 4, GGML_TYPE_Q8_0> q8_0_4x8_q8_0;
@@ -3187,6 +3494,17 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
                 return &iq4_nl_4x4_q8_0;
             }
         }
+    } else if (cur->type == GGML_TYPE_MXFP4) {
+        if (ggml_cpu_has_avx2()) {
+            if (cur->ne[1] % 8 == 0) {
+                return &mxfp4_8x8_q8_0;
+            }
+        }
+        if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
+            if (cur->ne[1] % 4 == 0) {
+                return &mxfp4_4x4_q8_0;
+            }
+        }
     } else if (cur->type == GGML_TYPE_Q8_0) {
         if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
             if (cur->ne[1] % 4 == 0) {
index ddf03d7642dba1dc2569726a4efcd1d1c3f95b13..b9f821630c46baa46424a9593b3729fd84714375 100644 (file)
@@ -97,6 +97,19 @@ struct block_iq4_nlx8 {
 
 static_assert(sizeof(block_iq4_nlx8) == 8 * sizeof(ggml_half) + QK4_NL * 4, "wrong iq4_nlx8 block size/padding");
 
+struct block_mxfp4x4 {
+    uint8_t e[4];
+    uint8_t qs[QK_MXFP4 * 2];
+};
+static_assert(sizeof(block_mxfp4x4) == 4 + QK_MXFP4 * 2, "wrong mxfp4x4 block size/padding");
+
+struct block_mxfp4x8 {
+    uint8_t e[8];
+    uint8_t qs[QK_MXFP4 * 4];
+};
+static_assert(sizeof(block_mxfp4x8) == 8 + QK_MXFP4 * 4, "wrong mxfp4x8 block size/padding");
+
+
 #if defined(__cplusplus)
 extern "C" {
 #endif
@@ -117,6 +130,8 @@ void ggml_gemv_q6_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
 void ggml_gemv_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_mxfp4_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
@@ -129,6 +144,8 @@ void ggml_gemm_q6_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
 void ggml_gemm_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_mxfp4_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
@@ -151,6 +168,8 @@ void ggml_gemv_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
 void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_mxfp4_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
@@ -163,6 +182,8 @@ void ggml_gemm_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
 void ggml_gemm_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_mxfp4_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemv_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
 void ggml_gemm_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);