]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
llama : add gpt-oss (llama/15091)
authorGeorgi Gerganov <redacted>
Tue, 5 Aug 2025 19:10:36 +0000 (22:10 +0300)
committerGeorgi Gerganov <redacted>
Thu, 14 Aug 2025 11:17:28 +0000 (14:17 +0300)
* oai moe

* compat with new checkpoint

* add attn sink impl

* add rope scaling yarn

* logits match with latest transformers code

* wip chat template

* rm trailing space

* use ggml_scale_bias

* rm redundant is_swa_all

* convert interleaved gate_up

* graph : fix activation function to match reference (llama/7)

* vocab : handle o200k_harmony special tokens

* ggml : add attention sinks support (llama/1)

* llama : add attn sinks

* ggml : add attn sinks

* cuda : add attn sinks

* vulkan : add support for sinks in softmax

remove unnecessary return

* ggml : add fused swiglu_oai op (llama/11)

* ggml : add fused swiglu_oai op

* Update src/ggml-cpu/ops.cpp

Co-authored-by: Georgi Gerganov <redacted>
* update CUDA impl

* cont : metal impl

* add vulkan impl

* test-backend-ops : more test cases, clean up

* llama : remove unfused impl

* remove extra lines

---------

Co-authored-by: Georgi Gerganov <redacted>
---------

Co-authored-by: slaren <redacted>
* repack mxfp4 upon conversion

* clean up a bit

* enable thinking

* add quick hack to render only some special tokens

* fix bf16 conversion

* remove vocab hack

* webui ok

* support chat parsing for gpt-oss

* fix webui

* direct mapping mxfp4, FINALLY

* force using mxfp4

* properly use lazy tensor

* ggml : add mxfp4

ggml : use e8m0 conversion instead of powf

Co-authored-by: Diego Devesa <redacted>
change kvalues_mxfp4 table to match e2m1 (llama/6)

metal : remove quantization for now (not used)

cuda : fix disabled CUDA graphs due to ffn moe bias

vulkan : add support for mxfp4

cont : add cm2 dequant

* ggml : add ggml_add_id (llama/13)

* ggml : add ggml_add_id

* add cuda impl

* llama : add weight support check for add_id

* perf opt

* add vulkan impl

* rename cuda files

* add metal impl

* allow in-place ggml_add_id

* llama : keep biases on CPU with --cpu-moe

* llama : fix compile error

ggml-ci

* cuda : add fallback for __nv_cvt_e8m0_to_bf16raw

ggml-ci

* cleanup

ggml-ci

* sycl : fix supports_op for MXFP4

ggml-ci

* fix Unknown reasoning format

* ggml-cpu : fix AVX build

ggml-ci

* fix hip build

ggml-ci

* cuda : add mxfp4 dequantization support for cuBLAS

ggml-ci

* ggml-cpu : fix mxfp4 fallback definitions for some architectures

ggml-ci

* cuda : fix version required for __nv_cvt_e8m0_to_bf16raw

---------

Co-authored-by: Xuan Son Nguyen <redacted>
Co-authored-by: slaren <redacted>
59 files changed:
include/ggml.h
src/ggml-alloc.c
src/ggml-cann/ggml-cann.cpp
src/ggml-common.h
src/ggml-cpu/arch-fallback.h
src/ggml-cpu/arch/arm/quants.c
src/ggml-cpu/arch/x86/quants.c
src/ggml-cpu/ggml-cpu.c
src/ggml-cpu/ops.cpp
src/ggml-cpu/ops.h
src/ggml-cpu/quants.c
src/ggml-cpu/quants.h
src/ggml-cpu/vec.h
src/ggml-cuda/add-id.cu [new file with mode: 0644]
src/ggml-cuda/add-id.cuh [new file with mode: 0644]
src/ggml-cuda/common.cuh
src/ggml-cuda/convert.cu
src/ggml-cuda/fattn-common.cuh
src/ggml-cuda/fattn-mma-f16.cuh
src/ggml-cuda/fattn-tile-f16.cu
src/ggml-cuda/fattn-tile-f32.cu
src/ggml-cuda/fattn-vec-f16.cuh
src/ggml-cuda/fattn-vec-f32.cuh
src/ggml-cuda/fattn-wmma-f16.cu
src/ggml-cuda/fattn.cu
src/ggml-cuda/ggml-cuda.cu
src/ggml-cuda/im2col.cu
src/ggml-cuda/mmq.cu
src/ggml-cuda/mmq.cuh
src/ggml-cuda/mmvq.cu
src/ggml-cuda/softmax.cu
src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu [new file with mode: 0644]
src/ggml-cuda/unary.cu
src/ggml-cuda/unary.cuh
src/ggml-cuda/vecdotq.cuh
src/ggml-cuda/vendors/cuda.h
src/ggml-impl.h
src/ggml-metal/ggml-metal-impl.h
src/ggml-metal/ggml-metal.m
src/ggml-metal/ggml-metal.metal
src/ggml-opencl/ggml-opencl.cpp
src/ggml-quants.c
src/ggml-quants.h
src/ggml-sycl/ggml-sycl.cpp
src/ggml-vulkan/ggml-vulkan.cpp
src/ggml-vulkan/vulkan-shaders/add_id.comp [new file with mode: 0644]
src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp
src/ggml-vulkan/vulkan-shaders/dequant_funcs.comp
src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp
src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp [new file with mode: 0644]
src/ggml-vulkan/vulkan-shaders/glu_head.comp
src/ggml-vulkan/vulkan-shaders/mul_mm.comp
src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp
src/ggml-vulkan/vulkan-shaders/soft_max.comp
src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp [new file with mode: 0644]
src/ggml-vulkan/vulkan-shaders/types.comp
src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
src/ggml.c
tests/test-backend-ops.cpp

index 8a8775be365833d4c2b6faf7a9f5c8e8d68626db..2f06e1e39b225a022f0d4e813844d1326503ee13 100644 (file)
     GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne) \
     GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
 
+#define GGML_TENSOR_TERNARY_OP_LOCALS \
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb) \
+    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb) \
+    GGML_TENSOR_LOCALS(int64_t, ne2, src2, ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb2, src2, nb) \
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst,  ne) \
+    GGML_TENSOR_LOCALS(size_t,  nb,  dst,  nb)
+
 #define GGML_TENSOR_BINARY_OP_LOCALS01 \
     GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
     GGML_TENSOR_LOCALS(size_t,  nb0, src0, nb) \
@@ -395,7 +405,8 @@ extern "C" {
         // GGML_TYPE_IQ4_NL_4_4 = 36,
         // GGML_TYPE_IQ4_NL_4_8 = 37,
         // GGML_TYPE_IQ4_NL_8_8 = 38,
-        GGML_TYPE_COUNT   = 39,
+        GGML_TYPE_MXFP4   = 39, // MXFP4 (1 block)
+        GGML_TYPE_COUNT   = 40,
     };
 
     // precision
@@ -430,6 +441,7 @@ extern "C" {
         GGML_FTYPE_MOSTLY_IQ4_XS  = 22, // except 1d tensors
         GGML_FTYPE_MOSTLY_IQ1_M   = 23, // except 1d tensors
         GGML_FTYPE_MOSTLY_BF16    = 24, // except 1d tensors
+        GGML_FTYPE_MOSTLY_MXFP4   = 25, // except 1d tensors
     };
 
     // available tensor operations:
@@ -438,6 +450,7 @@ extern "C" {
 
         GGML_OP_DUP,
         GGML_OP_ADD,
+        GGML_OP_ADD_ID,
         GGML_OP_ADD1,
         GGML_OP_ACC,
         GGML_OP_SUB,
@@ -557,6 +570,7 @@ extern "C" {
         GGML_GLU_OP_REGLU,
         GGML_GLU_OP_GEGLU,
         GGML_GLU_OP_SWIGLU,
+        GGML_GLU_OP_SWIGLU_OAI,
         GGML_GLU_OP_GEGLU_ERF,
         GGML_GLU_OP_GEGLU_QUICK,
 
@@ -831,6 +845,13 @@ extern "C" {
             struct ggml_tensor  * b,
             enum   ggml_type      type);
 
+    // dst[i0, i1, i2] = a[i0, i1, i2] + b[i0, ids[i1, i2]]
+    GGML_API struct ggml_tensor * ggml_add_id(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            struct ggml_tensor  * ids);
+
     GGML_API struct ggml_tensor * ggml_add1(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
@@ -1198,6 +1219,13 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);
 
+    GGML_API struct ggml_tensor * ggml_swiglu_oai(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            float                 alpha,
+            float                 limit);
+
     // normalize along rows
     GGML_API struct ggml_tensor * ggml_norm(
             struct ggml_context * ctx,
@@ -1570,6 +1598,10 @@ extern "C" {
             float                 scale,
             float                 max_bias);
 
+    GGML_API void ggml_soft_max_add_sinks(
+            struct ggml_tensor * a,
+            struct ggml_tensor * sinks);
+
     GGML_API struct ggml_tensor * ggml_soft_max_ext_back(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
@@ -2052,6 +2084,10 @@ extern "C" {
     GGML_API enum ggml_prec ggml_flash_attn_ext_get_prec(
             const struct ggml_tensor * a);
 
+    GGML_API void ggml_flash_attn_ext_add_sinks(
+            struct ggml_tensor * a,
+            struct ggml_tensor * sinks);
+
     // TODO: needs to be adapted to ggml_flash_attn_ext
     GGML_API struct ggml_tensor * ggml_flash_attn_back(
            struct ggml_context * ctx,
index fcc552da519b135ab465dfcdd9a0aebe6a365e19..8b6e6028361d01b06c71d2e0f89994704ebe16c5 100644 (file)
@@ -29,6 +29,7 @@ static bool ggml_op_can_inplace(enum ggml_op op) {
         case GGML_OP_DIAG_MASK_ZERO:
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_ADD:
+        case GGML_OP_ADD_ID:
         case GGML_OP_ADD1:
         case GGML_OP_SUB:
         case GGML_OP_MUL:
index 8eb8b1470bb701b09067f3185ce0d0be29051290..551c1f7673abaf60068b3bf03ae08f0fcd191e5a 100755 (executable)
@@ -2340,6 +2340,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
             memcpy(&bias, (float*)op->op_params + 1, sizeof(float));
             return bias == 0.0f; // TODO: support bias != 0.0f
         case GGML_OP_SOFT_MAX:
+            // TODO: support attention sinks [TAG_ATTN_SINKS]
+            if (op->src[2]) {
+                return false;
+            }
             // TODO: support broadcast
             // ref: https://github.com/ggml-org/llama.cpp/pull/14435
             return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1);
@@ -2354,6 +2358,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
             if(op->type != GGML_TYPE_F16 && op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_BF16){
                 return false;
             }
+            // TODO: support attention sinks [TAG_ATTN_SINKS]
+            if (op->src[4]) {
+                return false;
+            }
             if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
                 // different head sizes of K and V are not supported yet
                 return false;
index fbb04426abe7e600f8533a9d89e38b6ef38bc465..93ab7ea446e26a33550a323f4c4d14d8266a3293 100644 (file)
@@ -99,6 +99,9 @@ typedef sycl::half2 ggml_half2;
 #define QI4_1 (QK4_1 / (4 * QR4_1))
 #define QR4_1 2
 
+#define QI_MXFP4 (QK_MXFP4 / (4 * QR_MXFP4))
+#define QR_MXFP4 2
+
 #define QI5_0 (QK5_0 / (4 * QR5_0))
 #define QR5_0 2
 
@@ -184,6 +187,13 @@ typedef struct {
 } block_q4_1;
 static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_half) + QK4_1 / 2, "wrong q4_1 block size/padding");
 
+#define QK_MXFP4 32
+typedef struct {
+    uint8_t e; // E8M0
+    uint8_t qs[QK_MXFP4/2];
+} block_mxfp4;
+static_assert(sizeof(block_mxfp4) == sizeof(uint8_t) + QK_MXFP4/2, "wrong mxfp4 block size/padding");
+
 #define QK5_0 32
 typedef struct {
     ggml_half d;           // delta
@@ -1074,10 +1084,17 @@ GGML_TABLE_BEGIN(uint32_t, iq3s_grid, 512)
     0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101,
 GGML_TABLE_END()
 
+// TODO: fix name to kvalues_iq4_nl
 GGML_TABLE_BEGIN(int8_t, kvalues_iq4nl, 16)
     -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113,
 GGML_TABLE_END()
 
+// e2m1 values (doubled)
+// ref: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
+GGML_TABLE_BEGIN(int8_t, kvalues_mxfp4, 16)
+    0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12,
+GGML_TABLE_END()
+
 #define NGRID_IQ1S 2048
 #define IQ1S_DELTA 0.125f
 #define IQ1M_DELTA 0.125f
index f02cfe8fa5c0651b5753a3bfcd68d9c17052b34b..b62e3158ee74114c8b5bb87bf7e958ee8f64aea2 100644 (file)
@@ -13,6 +13,7 @@
 #define ggml_vec_dot_q5_0_q8_0_generic ggml_vec_dot_q5_0_q8_0
 #define ggml_vec_dot_q5_1_q8_1_generic ggml_vec_dot_q5_1_q8_1
 #define ggml_vec_dot_q8_0_q8_0_generic ggml_vec_dot_q8_0_q8_0
+#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
 #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
 #define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K
@@ -68,6 +69,7 @@
 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
 #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
 #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
+#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
 // repack.cpp
 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
@@ -90,6 +92,7 @@
 #define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
 #define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
 #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
+#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
 // repack.cpp
 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
 #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
 #define ggml_vec_dot_iq4_nl_q8_0_generic ggml_vec_dot_iq4_nl_q8_0
 #define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K
+#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
 // repack.cpp
 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
 #define ggml_vec_dot_iq3_s_q8_K_generic ggml_vec_dot_iq3_s_q8_K
 #define ggml_vec_dot_iq1_s_q8_K_generic ggml_vec_dot_iq1_s_q8_K
 #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
+#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
 // repack.cpp
 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
 #define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
 #define ggml_vec_dot_iq4_nl_q8_0_generic ggml_vec_dot_iq4_nl_q8_0
 #define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K
+#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
 // repack.cpp
 #define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
 #define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
index c6d1d852e9ad44fe06ec7b1e8a68af4b0e710f00..aadbb487ec0e4e63ff2abbec49b1618397438858 100644 (file)
@@ -589,6 +589,67 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi
     *s = sumf;
 }
 
+void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+    assert(n % QK_MXFP4 == 0);
+    static_assert(QK_MXFP4 == QK8_0, "QK_MXFP4 and QK8_0 must be the same");
+
+    const block_mxfp4 * GGML_RESTRICT x = vx;
+    const block_q8_0 * GGML_RESTRICT y = vy;
+
+    const int nb = n / QK_MXFP4;
+
+    int ib = 0;
+    float sumf = 0;
+
+#if defined __ARM_NEON
+    const int8x16_t values = vld1q_s8(kvalues_mxfp4);
+    const uint8x16_t m4b = vdupq_n_u8(0x0f);
+    uint8x16x2_t q4bits;
+    int8x16x4_t q4b;
+    int8x16x4_t q8b;
+    int32x4_t prod_1;
+    int32x4_t prod_2;
+
+    for (; ib + 1 < nb; ib += 2) {
+        q4bits.val[0] = vld1q_u8(x[ib + 0].qs);
+        q4bits.val[1] = vld1q_u8(x[ib + 1].qs);
+        q8b.val[0]    = vld1q_s8(y[ib + 0].qs);
+        q8b.val[1]    = vld1q_s8(y[ib + 0].qs + 16);
+        q8b.val[2]    = vld1q_s8(y[ib + 1].qs);
+        q8b.val[3]    = vld1q_s8(y[ib + 1].qs + 16);
+
+        q4b.val[0] = ggml_vqtbl1q_s8(values, vandq_u8  (q4bits.val[0], m4b));
+        q4b.val[1] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4));
+        q4b.val[2] = ggml_vqtbl1q_s8(values, vandq_u8  (q4bits.val[1], m4b));
+        q4b.val[3] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[1], 4));
+
+        prod_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[0], q8b.val[0]), q4b.val[1], q8b.val[1]);
+        prod_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]);
+
+        sumf +=
+            GGML_E8M0_TO_FP32_HALF(x[ib + 0].e) * GGML_CPU_FP16_TO_FP32(y[ib + 0].d) * vaddvq_s32(prod_1) +
+            GGML_E8M0_TO_FP32_HALF(x[ib + 1].e) * GGML_CPU_FP16_TO_FP32(y[ib + 1].d) * vaddvq_s32(prod_2);
+    }
+
+#endif
+    for (; ib < nb; ++ib) {
+        const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_E8M0_TO_FP32_HALF(x[ib].e);
+        int sumi1 = 0;
+        int sumi2 = 0;
+        for (int j = 0; j < QK_MXFP4/2; ++j) {
+            sumi1 += y[ib].qs[j +          0] * kvalues_mxfp4[x[ib].qs[j] & 0xf];
+            sumi2 += y[ib].qs[j + QK_MXFP4/2] * kvalues_mxfp4[x[ib].qs[j] >>  4];
+        }
+        sumf += d * (sumi1 + sumi2);
+    }
+    *s = sumf;
+}
+
 void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     const int qk = QK8_0;
     const int nb = n / qk;
index 30fd59f7028a970bf8a8cb2edf1eb43ebea2c42a..cb49320a67f12a04f7879872f11dabdf1c89adcb 100644 (file)
@@ -66,6 +66,12 @@ static inline int hsum_i32_4(const __m128i a) {
 }
 
 #if defined(__AVX2__) || defined(__AVX512F__)
+static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
+    const __m256i ax = _mm256_sign_epi8(x, x);
+    const __m256i sy = _mm256_sign_epi8(y, x);
+    return _mm256_maddubs_epi16(ax, sy);
+}
+
 // spread 32 bits to 32 bytes { 0x00, 0xFF }
 static inline __m256i bytes_from_bits_32(const uint8_t * x) {
     uint32_t x32;
@@ -261,6 +267,11 @@ static inline __m256 quad_fp16_delta_float(const float x0, const float y0, const
     return _mm256_set_m128(_mm_set1_ps(GGML_CPU_FP16_TO_FP32(x1) * GGML_CPU_FP16_TO_FP32(y1)),
                            _mm_set1_ps(GGML_CPU_FP16_TO_FP32(x0) * GGML_CPU_FP16_TO_FP32(y0)));
 }
+
+static inline __m256 quad_mx_delta_float(const int8_t x0, const float y0, const int8_t x1, const float y1) {
+    return _mm256_set_m128(_mm_set1_ps(GGML_E8M0_TO_FP32_HALF(x1) * GGML_CPU_FP16_TO_FP32(y1)),
+                           _mm_set1_ps(GGML_E8M0_TO_FP32_HALF(x0) * GGML_CPU_FP16_TO_FP32(y0)));
+}
 #endif
 #elif defined(__SSSE3__)
 // horizontally add 4x4 floats
@@ -746,6 +757,91 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi
 #endif
 }
 
+void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+    assert(n % QK_MXFP4 == 0);
+    static_assert(QK_MXFP4 == QK8_0, "QK_MXFP4 and QK8_0 must be the same");
+
+    const block_mxfp4 * GGML_RESTRICT x = vx;
+    const block_q8_0 * GGML_RESTRICT y = vy;
+
+    const int nb = n / QK_MXFP4;
+
+    int ib = 0;
+    float sumf = 0;
+
+#if defined __AVX2__
+
+    const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_mxfp4);
+    const __m128i m4b  = _mm_set1_epi8(0x0f);
+    const __m256i mone = _mm256_set1_epi16(1);
+
+    __m256 accum1 = _mm256_setzero_ps();
+    __m256 accum2 = _mm256_setzero_ps();
+    for (; ib + 1 < nb; ib += 2) {
+        const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)x[ib + 0].qs);
+        const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)x[ib + 1].qs);
+        const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)y[ib + 0].qs);
+        const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)y[ib + 1].qs);
+        const __m256i q4b_1 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)),
+                                              _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)));
+        const __m256i q4b_2 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)),
+                                              _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)));
+        const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
+        const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
+        const __m256i p_1 = _mm256_madd_epi16(p16_1, mone);
+        const __m256i p_2 = _mm256_madd_epi16(p16_2, mone);
+        accum1 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 0].d)*GGML_E8M0_TO_FP32_HALF(x[ib + 0].e)),
+                _mm256_cvtepi32_ps(p_1), accum1);
+        accum2 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 1].d)*GGML_E8M0_TO_FP32_HALF(x[ib + 1].e)),
+                _mm256_cvtepi32_ps(p_2), accum2);
+    }
+
+    sumf = hsum_float_8(_mm256_add_ps(accum1, accum2));
+
+#elif defined __AVX__
+    const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_mxfp4);
+    const __m128i m4b  = _mm_set1_epi8(0x0f);
+
+    __m256 accum = _mm256_setzero_ps();
+    for (; ib + 1 < nb; ib += 2) {
+        const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)x[ib + 0].qs);
+        const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs);
+        const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs);
+        const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs + 1);
+        const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs);
+        const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1);
+
+        const __m128i q4b_1_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b));
+        const __m128i q4b_1_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b));
+        const __m128i q4b_2_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b));
+        const __m128i q4b_2_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b));
+
+        const __m256 p = mul_sum_i8_quad_float(q4b_1_0, q4b_1_1, q4b_2_0, q4b_2_1, q8b_1_0, q8b_1_1, q8b_2_0, q8b_2_1);
+        const __m256 deltas = quad_mx_delta_float(x[ib].e, y[ib].d, x[ib + 1].e, y[ib + 1].d);
+        accum = _mm256_add_ps(_mm256_mul_ps(deltas, p), accum);
+    }
+
+    sumf = hsum_float_8(accum);
+
+#endif
+    for (; ib < nb; ++ib) {
+        const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_E8M0_TO_FP32_HALF(x[ib].e);
+        int sumi1 = 0;
+        int sumi2 = 0;
+        for (int j = 0; j < QK_MXFP4/2; ++j) {
+            sumi1 += y[ib].qs[j +          0] * kvalues_mxfp4[x[ib].qs[j] & 0xf];
+            sumi2 += y[ib].qs[j + QK_MXFP4/2] * kvalues_mxfp4[x[ib].qs[j] >>  4];
+        }
+        sumf += d * (sumi1 + sumi2);
+    }
+    *s = sumf;
+}
+
 void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     const int qk = QK8_0;
     const int nb = n / qk;
@@ -3206,14 +3302,6 @@ void ggml_vec_dot_iq3_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
 #endif
 }
 
-#if defined(__AVX2__)
-static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
-    const __m256i ax = _mm256_sign_epi8(x, x);
-    const __m256i sy = _mm256_sign_epi8(y, x);
-    return _mm256_maddubs_epi16(ax, sy);
-}
-#endif
-
 void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     assert(n % QK_K == 0);
     assert(nrc == 1);
index c5271b77572289a9dd670aac7014112882de9409..d89cd8f4ef6526540f038ddb03e51b24068c3bcb 100644 (file)
@@ -253,6 +253,12 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
         .vec_dot_type             = GGML_TYPE_Q8_1,
         .nrows                    = 1,
     },
+    [GGML_TYPE_MXFP4] = {
+        .from_float               = quantize_row_mxfp4,
+        .vec_dot                  = ggml_vec_dot_mxfp4_q8_0,
+        .vec_dot_type             = GGML_TYPE_Q8_0,
+        .nrows                    = 1,
+    },
     [GGML_TYPE_Q2_K] = {
         .from_float               = quantize_row_q2_K,
         .vec_dot                  = ggml_vec_dot_q2_K_q8_K,
@@ -1670,6 +1676,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_add(params, tensor);
             } break;
+        case GGML_OP_ADD_ID:
+            {
+                ggml_compute_forward_add_id(params, tensor);
+            } break;
         case GGML_OP_ADD1:
             {
                 ggml_compute_forward_add1(params, tensor);
@@ -1924,7 +1934,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_FLASH_ATTN_EXT:
             {
-                ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
+                ggml_compute_forward_flash_attn_ext(params, tensor);
             } break;
         case GGML_OP_FLASH_ATTN_BACK:
             {
@@ -2111,6 +2121,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_DUP:
         case GGML_OP_CONT:
         case GGML_OP_ADD:
+        case GGML_OP_ADD_ID:
         case GGML_OP_ADD1:
         case GGML_OP_ACC:
             {
@@ -2172,6 +2183,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                 case GGML_GLU_OP_REGLU:
                 case GGML_GLU_OP_GEGLU:
                 case GGML_GLU_OP_SWIGLU:
+                case GGML_GLU_OP_SWIGLU_OAI:
                 case GGML_GLU_OP_GEGLU_ERF:
                 case GGML_GLU_OP_GEGLU_QUICK:
                     {
@@ -2673,6 +2685,7 @@ struct ggml_cplan ggml_graph_plan(
                         }
                     } break;
                 case GGML_OP_ADD:
+                case GGML_OP_ADD_ID:
                 case GGML_OP_ADD1:
                     {
                         if (ggml_is_quantized(node->src[0]->type)) {
index 6581d27adde2e077376956ad604ab5723ea18b30..854f1c2b49647b8ff7a42e27d738c8bd6e6d7164 100644 (file)
@@ -8,6 +8,7 @@
 #include "vec.h"
 
 #include <float.h>
+#include <algorithm>
 
 // ggml_compute_forward_dup
 
@@ -1283,6 +1284,7 @@ void ggml_compute_forward_add(
         case GGML_TYPE_Q5_0:
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
@@ -1309,6 +1311,77 @@ void ggml_compute_forward_add(
     }
 }
 
+// ggml_compute_forward_add_id
+
+static void ggml_compute_forward_add_id_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * src2 = dst->src[2];
+
+    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(src2->type == GGML_TYPE_I32);
+
+    GGML_ASSERT(src0->nb[0] == sizeof(float));
+    GGML_ASSERT(src1->nb[0] == sizeof(float));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr  = ggml_nrows(src0);
+
+    GGML_TENSOR_TERNARY_OP_LOCALS
+
+    GGML_ASSERT( nb0 == sizeof(float));
+    GGML_ASSERT(nb10 == sizeof(float));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 indices
+        const int i3 = ir/(ne2*ne1);
+        const int i2 = (ir - i3*ne2*ne1)/ne1;
+        const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+        // src1 indices
+        const int i11 = *(int32_t *) ((char *) src2->data + i1*nb20 + i2*nb21);
+
+        GGML_ASSERT(i11 >= 0 && i11 < ne11);
+
+        ggml_vec_add_f32(ne0,
+                (float *) ((char *) dst->data  + i3*nb3  + i2*nb2  + i1*nb1 ),
+                (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
+                (float *) ((char *) src1->data + i11*nb11));
+    }
+}
+
+void ggml_compute_forward_add_id(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_add_id_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("unsupported type for ggml_compute_forward_add_id: %s", ggml_type_name(src0->type));
+            }
+    }
+}
+
 // ggml_compute_forward_add1
 
 static void ggml_compute_forward_add1_f32(
@@ -1660,6 +1733,7 @@ void ggml_compute_forward_add1(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
@@ -1787,6 +1861,7 @@ void ggml_compute_forward_acc(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
@@ -3614,6 +3689,93 @@ static void ggml_compute_forward_swiglu(
     }
 }
 
+// ggml_compute_forward_swiglu_oai
+
+static void ggml_compute_forward_swiglu_oai_f32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    char * src0_d = (char *) src0->data;
+    char * src1_d = (char *) (src1 ? src1->data : src0->data);
+    const size_t src0_o = src0->nb[1];
+    const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    const int nr = ggml_nrows(src0);
+
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == nr);
+
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+    const float alpha = ggml_get_op_params_f32(dst, 2);
+    const float limit = ggml_get_op_params_f32(dst, 3);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int i1 = ir0; i1 < ir1; i1++) {
+        float * src0_p = (float *) (src0_d + i1*src0_o);
+        float * src1_p = (float *) (src1_d + i1*src1_o);
+        float * dst_p  = (float *) ((char *) dst->data + i1*(dst->nb[1]));
+
+        if (!src1) {
+            src0_p += swapped ? nc : 0;
+            src1_p += swapped ? 0 : nc;
+        }
+
+        for (int k = 0; k < nc; k++) {
+            const float x = std::min(src0_p[k], limit);
+            const float y = std::clamp(src1_p[k], -limit, limit);
+            const float out_glu = x / (1.f + expf(alpha * (-x)));
+            dst_p[k] = out_glu * (y + 1.f);
+        }
+
+#ifndef NDEBUG
+        for (int k = 0; k < nc; k++) {
+            const float x = dst_p[k];
+            GGML_UNUSED(x);
+            assert(!isnan(x));
+            assert(!isinf(x));
+        }
+#endif
+    }
+}
+
+static void ggml_compute_forward_swiglu_oai(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_swiglu_oai_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_geglu_erf
 
 static void ggml_compute_forward_geglu_erf_f32(
@@ -4599,6 +4761,7 @@ void ggml_compute_forward_out_prod(
         case GGML_TYPE_Q5_0:
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
@@ -4873,6 +5036,7 @@ void ggml_compute_forward_set(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
@@ -5134,6 +5298,7 @@ void ggml_compute_forward_get_rows(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
@@ -5523,6 +5688,7 @@ static void ggml_compute_forward_soft_max_f32(
 
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * src2 = dst->src[2];
 
     assert(ggml_is_contiguous(dst));
     assert(ggml_are_same_shape(src0, dst));
@@ -5557,6 +5723,9 @@ static void ggml_compute_forward_soft_max_f32(
 
     const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
 
+    // sinks
+    const float * sk = src2 ? (float *)((char *) src2->data) : nullptr;
+
     for (int64_t i03 = 0; i03 < ne03; i03++) {
         for (int64_t i02 = 0; i02 < ne02; i02++) {
             for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
@@ -5599,9 +5768,18 @@ static void ggml_compute_forward_soft_max_f32(
                 float max = -INFINITY;
                 ggml_vec_max_f32(ne00, &max, wp);
 
+                // if we have sinks, make a correction as if they were included in the softmax
+                if (sk) {
+                    max = MAX(max, sk[i02]);
+                }
+
                 ggml_float sum = ggml_vec_soft_max_f32(ne00, dp, wp, max);
                 assert(sum > 0.0);
 
+                if (sk) {
+                    sum += (ggml_float) expf(sk[i02] - max);
+                }
+
                 sum = 1.0/sum;
                 ggml_vec_scale_f32(ne00, dp, sum);
 
@@ -5836,6 +6014,7 @@ void ggml_compute_forward_clamp(
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_Q8_1:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
@@ -7989,12 +8168,14 @@ void ggml_compute_forward_argsort(
 
 static void ggml_compute_forward_flash_attn_ext_f16(
         const ggml_compute_params * params,
-        const ggml_tensor * q,
-        const ggml_tensor * k,
-        const ggml_tensor * v,
-        const ggml_tensor * mask,
         ggml_tensor * dst) {
 
+    const ggml_tensor * q     = dst->src[0];
+    const ggml_tensor * k     = dst->src[1];
+    const ggml_tensor * v     = dst->src[2];
+    const ggml_tensor * mask  = dst->src[3];
+    const ggml_tensor * sinks = dst->src[4];
+
     GGML_TENSOR_LOCALS(int64_t, neq, q,   ne)
     GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb)
     GGML_TENSOR_LOCALS(int64_t, nek, k,   ne)
@@ -8189,6 +8370,23 @@ static void ggml_compute_forward_flash_attn_ext_f16(
             }
         }
 
+        // sinks
+        if (sinks) {
+            const float s = ((float *)((char *) sinks->data))[h];
+
+            float ms = 1.0f;
+            float vs = 1.0f;
+
+            if (s > M) {
+                ms = expf(M - s);
+                ggml_vec_scale_f32(DV, VKQ32, ms);
+            } else {
+                vs = expf(s - M);
+            }
+
+            S = S*ms + vs;
+        }
+
         // V /= S
         const float S_inv = 1.0f/S;
         ggml_vec_scale_f32(DV, VKQ32, S_inv);
@@ -8208,17 +8406,13 @@ static void ggml_compute_forward_flash_attn_ext_f16(
 
 void ggml_compute_forward_flash_attn_ext(
         const ggml_compute_params * params,
-        const ggml_tensor * q,
-        const ggml_tensor * k,
-        const ggml_tensor * v,
-        const ggml_tensor * mask,
         ggml_tensor * dst) {
     switch (dst->op_params[3]) {
         case GGML_PREC_DEFAULT:
         case GGML_PREC_F32:
             {
                 // uses F32 accumulators
-                ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst);
+                ggml_compute_forward_flash_attn_ext_f16(params, dst);
             } break;
         default:
             {
@@ -9080,6 +9274,10 @@ void ggml_compute_forward_glu(
             {
                 ggml_compute_forward_swiglu(params, dst);
             } break;
+        case GGML_GLU_OP_SWIGLU_OAI:
+            {
+                ggml_compute_forward_swiglu_oai(params, dst);
+            } break;
         case GGML_GLU_OP_GEGLU_ERF:
             {
                 ggml_compute_forward_geglu_erf(params, dst);
index 3a32ec20dba2b7b29cbfe03759725fe293e2ed5c..f154afb4624980f3d8924aa5a6c47a1d555b2b55 100644 (file)
@@ -29,6 +29,7 @@ extern "C" {
 
 void ggml_compute_forward_dup(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_add(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_add_id(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_add1(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_acc(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_sum(const struct ggml_compute_params * params, struct ggml_tensor * dst);
@@ -82,13 +83,7 @@ void ggml_compute_forward_arange(const struct ggml_compute_params * params, stru
 void ggml_compute_forward_timestep_embedding(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_argsort(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_leaky_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
-void ggml_compute_forward_flash_attn_ext(
-    const struct ggml_compute_params * params,
-    const struct ggml_tensor * q,
-    const struct ggml_tensor * k,
-    const struct ggml_tensor * v,
-    const struct ggml_tensor * mask,
-    struct ggml_tensor * dst);
+void ggml_compute_forward_flash_attn_ext(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_flash_attn_back(
         const struct ggml_compute_params * params,
         const bool masked,
index ee35ab42fda07a7e9f9cb9348b8f93455a5b2424..365cb36d2d764353f274882ee1f6323179a4cce4 100644 (file)
@@ -46,6 +46,10 @@ void quantize_row_q8_1_generic(const float * GGML_RESTRICT x, void * GGML_RESTRI
     quantize_row_q8_1_ref(x, y, k);
 }
 
+void quantize_row_mxfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {
+    quantize_row_mxfp4_ref(x, y, k);
+}
+
 //
 // 2-6 bit quantization in super-blocks
 //
@@ -181,6 +185,37 @@ void ggml_vec_dot_q4_1_q8_1_generic(int n, float * GGML_RESTRICT s, size_t bs, c
     *s = sumf;
 }
 
+void ggml_vec_dot_mxfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+    assert(nrc == 1);
+    UNUSED(nrc);
+    UNUSED(bx);
+    UNUSED(by);
+    UNUSED(bs);
+    assert(n % QK_MXFP4 == 0);
+    static_assert(QK_MXFP4 == QK8_0, "QK_MXFP4 and QK8_0 must be the same");
+
+    const block_mxfp4 * GGML_RESTRICT x = vx;
+    const block_q8_0 * GGML_RESTRICT y = vy;
+
+    const int nb = n / QK_MXFP4;
+
+    int ib = 0;
+    float sumf = 0;
+
+    for (; ib < nb; ++ib) {
+        const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_E8M0_TO_FP32_HALF(x[ib].e);
+
+        int sumi1 = 0;
+        int sumi2 = 0;
+        for (int j = 0; j < QK_MXFP4/2; ++j) {
+            sumi1 += y[ib].qs[j +          0] * kvalues_mxfp4[x[ib].qs[j] & 0xf];
+            sumi2 += y[ib].qs[j + QK_MXFP4/2] * kvalues_mxfp4[x[ib].qs[j] >>  4];
+        }
+        sumf += d * (sumi1 + sumi2);
+    }
+    *s = sumf;
+}
+
 void ggml_vec_dot_q5_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
     const int qk = QK8_0;
     const int nb = n / qk;
index dc4342c87f592c5b7e25775e57ae1d02f73e3098..d83eb1b144d478bcd5f7a8db008e6a4f6fc43ead 100644 (file)
@@ -19,6 +19,8 @@ void quantize_row_q5_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, in
 void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
 void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
 
+void quantize_row_mxfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+
 void quantize_row_q2_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
 void quantize_row_q3_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
 void quantize_row_q4_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
@@ -39,6 +41,8 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
 void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 
+void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
 void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
@@ -67,8 +71,12 @@ void ggml_vec_dot_q4_1_q8_1_generic(int n, float * GGML_RESTRICT s, size_t bs, c
 void ggml_vec_dot_q5_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_q5_1_q8_1_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_q8_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
+void ggml_vec_dot_mxfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
 void ggml_vec_dot_tq1_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_tq2_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
 void ggml_vec_dot_q2_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_q3_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
 void ggml_vec_dot_q4_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
index d18783a00a1a5790f80af40a9a8c032e190f607b..2250d93cb00d13ad18b5b9c32e1bea8a514faefa 100644 (file)
@@ -55,7 +55,22 @@ inline static void ggml_vec_cpy_i32(const int n, int32_t * y, const int32_t * x)
 
 inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const ggml_fp16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
 inline static void ggml_vec_set_bf16(const int n, ggml_bf16_t * x, const ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
-inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i] + y[i]; }
+
+inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) {
+    int i = 0;
+#if defined(__AVX2__)
+    for (; i + 7 < n; i += 8) {
+        __m256 vx = _mm256_loadu_ps(x + i);
+        __m256 vy = _mm256_loadu_ps(y + i);
+        __m256 vz = _mm256_add_ps(vx, vy);
+        _mm256_storeu_ps(z + i, vz);
+    }
+#endif
+    for (; i < n; ++i) {
+        z[i] = x[i] + y[i];
+    }
+}
+
 inline static void ggml_vec_add_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) {
     for (int i = 0; i < n; ++i) {
         z[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(x[i]) + GGML_CPU_FP16_TO_FP32(y[i]));
@@ -992,9 +1007,9 @@ void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float *
 
 inline static void ggml_vec_swiglu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
     for (int i = 0; i < n; ++i) {
-        float v = GGML_CPU_FP16_TO_FP32(x[i]);
-        float w = GGML_CPU_FP16_TO_FP32(g[i]);
-        y[i] = GGML_CPU_FP32_TO_FP16((v/(1.0f + expf(-v))) * w);
+        float xi = GGML_CPU_FP16_TO_FP32(x[i]);
+        float gi = GGML_CPU_FP16_TO_FP32(g[i]);
+        y[i] = GGML_CPU_FP32_TO_FP16((xi/(1.0f + expf(-xi))) * gi);
     }
 }
 
diff --git a/src/ggml-cuda/add-id.cu b/src/ggml-cuda/add-id.cu
new file mode 100644 (file)
index 0000000..8bed62a
--- /dev/null
@@ -0,0 +1,58 @@
+#include "add-id.cuh"
+
+static __global__ void add_id_kernel(
+        const float * src0, const float * src1, const int32_t * src2, float * dst,
+        int64_t ne0, int64_t ne1,
+        size_t nb01, size_t nb02,
+        size_t nb11,
+        size_t nb21
+    ) {
+
+    const int64_t i1 = blockIdx.x;
+    const int64_t i2 = blockIdx.y;
+
+    const int i11 = *(int32_t *) ((char *) src2 + i1*sizeof(int32_t) + i2*nb21);
+
+    const size_t nb1 = ne0 * sizeof(float);
+    const size_t nb2 = ne1 * nb1;
+
+    float * dst_row = (float *)((char *)dst + i1*nb1 + i2*nb2);
+    const float * src0_row = (const float *)((char *)src0 +  i1*nb01 + i2*nb02);
+    const float * src1_row = (const float *)((char *)src1 + i11*nb11);
+
+    for (int64_t i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) {
+        dst_row[i0] = src0_row[i0] + src1_row[i0];
+    }
+}
+
+void ggml_cuda_op_add_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * src2 = dst->src[2];
+
+    GGML_TENSOR_TERNARY_OP_LOCALS
+
+    GGML_ASSERT(dst->type  == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT(src2->type == GGML_TYPE_I32);
+
+    GGML_ASSERT(nb00 == sizeof(float));
+    GGML_ASSERT(nb10 == sizeof(float));
+    GGML_ASSERT(nb20 == sizeof(int32_t));
+
+    const float * src0_d = (const float *)src0->data;
+    const float * src1_d = (const float *)src1->data;
+    const int32_t * src2_d = (const int32_t *)src2->data;
+    float * dst_d = (float *)dst->data;
+
+    int threads = std::min((int)ne00, 768); // cols
+    dim3 blocks(ne01, ne02); // n_experts_used, n_tokens
+    add_id_kernel<<<blocks, threads, 0, ctx.stream()>>>(
+        src0_d, src1_d, src2_d, dst_d,
+        ne0, ne1,
+        nb01, nb02,
+        nb11,
+        nb21
+    );
+}
diff --git a/src/ggml-cuda/add-id.cuh b/src/ggml-cuda/add-id.cuh
new file mode 100644 (file)
index 0000000..30b1721
--- /dev/null
@@ -0,0 +1,3 @@
+#include "common.cuh"
+
+void ggml_cuda_op_add_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
index 7fb04d51b770f9dcaeff3b6fb97d07a347152e4a..8f27255476d5981acc07d397cfb1df7bb82eae95 100644 (file)
@@ -1,6 +1,7 @@
 #pragma once
 
 #include "ggml.h"
+#include "ggml-impl.h"
 #include "ggml-cuda.h"
 
 #include <cstdint>
@@ -549,6 +550,24 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i
 #endif // defined(GGML_USE_HIP)
 }
 
+static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
+#if CUDART_VERSION >= 12080
+    const nv_bfloat16 e = __nv_cvt_e8m0_to_bf16raw(x);
+    return (float) e;
+#else
+    uint32_t bits;
+    if (x == 0) {
+        bits = 0x00400000;
+    } else {
+        bits = (uint32_t) x << 23;
+    }
+
+    float result;
+    memcpy(&result, &bits, sizeof(float));
+    return result;
+#endif // CUDART_VERSION >= 12050
+}
+
 typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);
 
 static __device__ __forceinline__ float get_alibi_slope(
@@ -607,6 +626,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_Q8_0> {
     static constexpr int qi = QI8_0;
 };
 
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_MXFP4> {
+    static constexpr int qk = QK_MXFP4;
+    static constexpr int qr = QR_MXFP4;
+    static constexpr int qi = QI_MXFP4;
+};
+
 template<>
 struct ggml_cuda_type_traits<GGML_TYPE_Q2_K> {
     static constexpr int qk = QK_K;
index 15c927861f03d65e01ccf01837135d7f772a3286..e3beddbc1b23b6c1831ba9b9232572b265f92d27 100644 (file)
@@ -465,6 +465,24 @@ static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst
     }
 }
 
+template<typename dst_t>
+static __global__ void dequantize_block_mxfp4(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+    const int64_t i   = blockIdx.x;
+    const block_mxfp4 * x = (const block_mxfp4 *) vx + i*(QK_K/QK_MXFP4);
+
+    const int64_t tid = threadIdx.x;
+    const int64_t il = tid/8; // 0...3
+    const int64_t ib = tid%8; // 0...7
+    dst_t * y = yy + i*QK_K + 32*ib + 4*il;
+    const uint8_t  * q4 = x[ib].qs + 4*il;
+    const float d = ggml_cuda_e8m0_to_fp32(x[ib].e);
+    for (int j = 0; j < 4; ++j) {
+        y[j+ 0] = d * kvalues_mxfp4[q4[j] & 0xf]*0.5f;
+        y[j+16] = d * kvalues_mxfp4[q4[j] >>  4]*0.5f;
+    }
+}
+
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
 static void dequantize_block_cuda(const void * vx, dst_t * y,
         const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
@@ -588,6 +606,12 @@ static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int64_t
     dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y);
 }
 
+template<typename dst_t>
+static void dequantize_row_mxfp4_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+    const int nb = (k + QK_K - 1) / QK_K;
+    dequantize_block_mxfp4<<<nb, 32, 0, stream>>>(vx, y);
+}
+
 template <typename src_t, typename dst_t>
 static __global__ void convert_unary(
         const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, const int64_t ne02,
@@ -677,6 +701,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
             return dequantize_row_iq4_xs_cuda;
         case GGML_TYPE_IQ3_S:
             return dequantize_row_iq3_s_cuda;
+        case GGML_TYPE_MXFP4:
+            return dequantize_row_mxfp4_cuda;
         case GGML_TYPE_F32:
             return convert_unary_cont_cuda<float>;
         case GGML_TYPE_BF16:
@@ -726,6 +752,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
             return dequantize_row_iq4_xs_cuda;
         case GGML_TYPE_IQ3_S:
             return dequantize_row_iq3_s_cuda;
+        case GGML_TYPE_MXFP4:
+            return dequantize_row_mxfp4_cuda;
         case GGML_TYPE_F16:
             return convert_unary_cont_cuda<half>;
         case GGML_TYPE_BF16:
index b6db446c6feaf2e81d2ec5c593613b6284dc5f91..e46f0e2081bdfb73c5bfe6582a5fff485c8d2c70 100644 (file)
@@ -15,6 +15,7 @@ typedef void (* fattn_kernel_t)(
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
+        const char * __restrict__ sinks,
         const int  * __restrict__ KV_max,
         float      * __restrict__ dst,
         float2     * __restrict__ dst_meta,
@@ -736,7 +737,8 @@ void launch_fattn(
 
     GGML_ASSERT(V || is_mla);
 
-    const ggml_tensor * mask = dst->src[3];
+    const ggml_tensor * mask  = dst->src[3];
+    const ggml_tensor * sinks = dst->src[4];
 
     ggml_tensor * KQV = dst;
 
@@ -940,6 +942,7 @@ void launch_fattn(
         K_data,
         V_data,
         mask ? ((const char *) mask->data) : nullptr,
+        sinks ? ((const char *) sinks->data) : nullptr,
         KV_max.ptr,
         !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
         scale, max_bias, m0, m1, n_head_log2, logit_softcap,
index a86b95428f6ff6baedad28f09b4cd0374450e346..e7570f9d3b830340ed113a4c52429d304324127a 100644 (file)
@@ -1206,6 +1206,7 @@ static __global__ void flash_attn_ext_f16(
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
+        const char * __restrict__ sinks,
         const int  * __restrict__ KV_max,
         float      * __restrict__ dst,
         float2     * __restrict__ dst_meta,
@@ -1267,6 +1268,7 @@ static __global__ void flash_attn_ext_f16(
     // kb0 == k start index when in the output tile.
     int kb0_start = kbc % iter_k;
     int kb0_stop  = min(iter_k, kb0_start + kbc_stop - kbc);
+
     while (kbc < kbc_stop && kb0_stop == iter_k) {
         const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
         const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
@@ -1340,7 +1342,7 @@ static __global__ void flash_attn_ext_f16(
         (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
          ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
 #else
-    GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
+    GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks);
     GGML_UNUSED(dst); GGML_UNUSED(dst_meta);
     GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
     GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
index 9d0b24ae7ec73fbb125b102c4ebc4959393c3de6..0fcfaa32ea4645ef924cb2b961845fbd9e51cfec 100644 (file)
@@ -13,6 +13,7 @@ static __global__ void flash_attn_tile_ext_f16(
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
+        const char * __restrict__ sinks,
         const int  * __restrict__ KV_max,
         float      * __restrict__ dst,
         float2     * __restrict__ dst_meta,
@@ -272,7 +273,7 @@ static __global__ void flash_attn_tile_ext_f16(
         }
     }
 #else
-    GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
+    GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks);
     GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
     GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
     GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
index be72f76fb6538698b8c26caf623089a128a509cc..23550cbbd9736a9410cd137d66b057d953928529 100644 (file)
@@ -13,6 +13,7 @@ static __global__ void flash_attn_tile_ext_f32(
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
+        const char * __restrict__ sinks,
         const int  * __restrict__ KV_max,
         float      * __restrict__ dst,
         float2     * __restrict__ dst_meta,
@@ -37,7 +38,7 @@ static __global__ void flash_attn_tile_ext_f32(
     return;
 #endif // FP16_MMA_AVAILABLE
     if (use_logit_softcap && !(D == 128 || D == 256)) {
-        GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
+        GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks);
         GGML_UNUSED(dst); GGML_UNUSED(dst_meta);
         GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
         GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
index a2df2f66be0c41602af9651e3146201defdf581f..b05f682cd3b4db8385abfb5eed134b92165270cc 100644 (file)
@@ -16,6 +16,7 @@ static __global__ void flash_attn_vec_ext_f16(
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
+        const char * __restrict__ sinks,
         const int  * __restrict__ KV_max,
         float      * __restrict__ dst,
         float2     * __restrict__ dst_meta,
@@ -61,7 +62,8 @@ static __global__ void flash_attn_vec_ext_f16(
     K += nb13*sequence + nb12*(head / gqa_ratio);
     V += nb23*sequence + nb22*(head / gqa_ratio);
 
-    const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
+    const half  * maskh  = (const half  *) (mask + nb33*(sequence % ne33) + nb31*ic0);
+    const float * sinksf = (const float *) (sinks);
 
     const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
     const half  slopeh = __float2half(slopef);
@@ -75,11 +77,12 @@ static __global__ void flash_attn_vec_ext_f16(
     half2 * KQ2 = (half2 *) KQ;
 
     half kqmax[ncols];
+    half kqsum[ncols];
 #pragma unroll
     for (int j = 0; j < ncols; ++j) {
         kqmax[j] = -HALF_MAX_HALF;
+        kqsum[j] = 0.0f;
     }
-    half kqsum[ncols] = {0.0f};
 
     __shared__ half kqmax_shared[ncols][WARP_SIZE];
     __shared__ half kqsum_shared[ncols][WARP_SIZE];
@@ -283,6 +286,39 @@ static __global__ void flash_attn_vec_ext_f16(
         __syncthreads();
     }
 
+    if (sinksf && blockIdx.y == 0) {
+        const half sink = __float2half(sinksf[head]);
+
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            if (threadIdx.x == 0) {
+                kqmax_shared[j][threadIdx.y] = fmaxf(kqmax[j], sink);
+            }
+        }
+
+        __syncthreads();
+
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            half kqmax_new_j = kqmax_shared[j][threadIdx.x];
+            kqmax_new_j = warp_reduce_max(kqmax_new_j);
+
+            const half KQ_max_scale = hexp(kqmax[j] - kqmax_new_j);
+            kqmax[j] = kqmax_new_j;
+
+            const half val = hexp(sink - kqmax[j]);
+            kqsum[j] = kqsum[j]*KQ_max_scale;
+
+            if (tid == 0) {
+                kqsum[j] += val;
+            }
+
+            VKQ[j] *= __half2half2(KQ_max_scale);
+        }
+
+        __syncthreads();
+    }
+
 #pragma unroll
     for (int j = 0; j < ncols; ++j) {
         kqsum[j] = warp_reduce_sum((float)kqsum[j]);
@@ -313,7 +349,7 @@ static __global__ void flash_attn_vec_ext_f16(
         dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
     }
 #else
-    GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
+    GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks);
     GGML_UNUSED(dst); GGML_UNUSED(dst_meta);
     GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
     GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
index 9ab0fc133b7a201ddf93d3667b8d0525628f249b..d6d0bfb744b74d21587d6e705eeb7a0b5fd7fd00 100644 (file)
@@ -16,6 +16,7 @@ static __global__ void flash_attn_vec_ext_f32(
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
+        const char * __restrict__ sinks,
         const int  * __restrict__ KV_max,
         float      * __restrict__ dst,
         float2     * __restrict__ dst_meta,
@@ -72,7 +73,8 @@ static __global__ void flash_attn_vec_ext_f32(
     K += nb13*sequence + nb12*(head / gqa_ratio);
     V += nb23*sequence + nb22*(head / gqa_ratio);
 
-    const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
+    const half  * maskh  = (const half  *) (mask + nb33*(sequence % ne33) + nb31*ic0);
+    const float * sinksf = (const float *) (sinks);
 
     const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
 
@@ -88,11 +90,12 @@ static __global__ void flash_attn_vec_ext_f32(
     }
 
     float kqmax[ncols];
+    float kqsum[ncols];
 #pragma unroll
     for (int j = 0; j < ncols; ++j) {
         kqmax[j] = -FLT_MAX/2.0f;
+        kqsum[j] = 0.0f;
     }
-    float kqsum[ncols] = {0.0f};
 
     __shared__ float kqmax_shared[ncols][WARP_SIZE];
     __shared__ float kqsum_shared[ncols][WARP_SIZE];
@@ -279,6 +282,39 @@ static __global__ void flash_attn_vec_ext_f32(
         __syncthreads();
     }
 
+    if (sinksf && blockIdx.y == 0) {
+        const float sink = sinksf[head];
+
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            if (threadIdx.x == 0) {
+                kqmax_shared[j][threadIdx.y] = fmaxf(kqmax[j], sink);
+            }
+        }
+
+        __syncthreads();
+
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            float kqmax_new_j = kqmax_shared[j][threadIdx.x];
+            kqmax_new_j = warp_reduce_max(kqmax_new_j);
+
+            const float KQ_max_scale = expf(kqmax[j] - kqmax_new_j);
+            kqmax[j] = kqmax_new_j;
+
+            const float val = expf(sink - kqmax[j]);
+            kqsum[j] = kqsum[j]*KQ_max_scale;
+
+            if (tid == 0) {
+                kqsum[j] += val;
+            }
+
+            VKQ[j] *= KQ_max_scale;
+        }
+
+        __syncthreads();
+    }
+
 #pragma unroll
     for (int j = 0; j < ncols; ++j) {
         kqsum[j] = warp_reduce_sum(kqsum[j]);
index 40554204b62f3a4da6c5ccfa3119c2c3178a93ad..93d4d810218d78a6560b5b9f7a97663fb806d833 100644 (file)
@@ -29,6 +29,7 @@ static __global__ void flash_attn_ext_f16(
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
+        const char * __restrict__ sinks,
         const int  * __restrict__ KV_max,
         float      * __restrict__ dst,
         float2     * __restrict__ dst_meta,
@@ -423,7 +424,7 @@ static __global__ void flash_attn_ext_f16(
         dst_meta[j_dst_unrolled] = dst_meta_val;
     }
 #else
-    GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
+    GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks);
     GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
     GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
     GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
index 039c54e015ea67e69e5a7a2f5333185f393f2cec..656e04a47352e09e2796aacdd345dbf460564844 100644 (file)
@@ -269,17 +269,28 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
 }
 
 void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * KQV  = dst;
-    const ggml_tensor * Q    = dst->src[0];
-    const ggml_tensor * K    = dst->src[1];
-    const ggml_tensor * V    = dst->src[2];
-    const ggml_tensor * mask = dst->src[3];
+    const ggml_tensor * KQV   = dst;
+    const ggml_tensor * Q     = dst->src[0];
+    const ggml_tensor * K     = dst->src[1];
+    const ggml_tensor * V     = dst->src[2];
+    const ggml_tensor * mask  = dst->src[3];
+    const ggml_tensor * sinks = dst->src[4];
 
     ggml_cuda_set_device(ctx.device);
     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
     const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
 
+    // TODO: currently only vec implementation for sinks is supported [TAG_ATTN_SINKS]
+    if (sinks) {
+        if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
+            ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
+        } else {
+            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
+        }
+        return;
+    }
+
 #if defined(GGML_HIP_ROCWMMA_FATTN)
     if (GGML_CUDA_CC_IS_AMD(cc) && fp16_mma_available(cc)) {
         ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
index 8885fb7fbdd2f03b55b8dc84d239182fbddd48a6..60e481b95af0340128ecf94231fd4f16a60a33d7 100644 (file)
@@ -4,6 +4,7 @@
 
 #include "ggml-cuda/common.cuh"
 #include "ggml-cuda/acc.cuh"
+#include "ggml-cuda/add-id.cuh"
 #include "ggml-cuda/arange.cuh"
 #include "ggml-cuda/argmax.cuh"
 #include "ggml-cuda/argsort.cuh"
@@ -2259,6 +2260,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_ADD1: // TODO: more efficient implementation
             ggml_cuda_op_add(ctx, dst);
             break;
+        case GGML_OP_ADD_ID:
+            ggml_cuda_op_add_id(ctx, dst);
+            break;
         case GGML_OP_SUB:
             ggml_cuda_op_sub(ctx, dst);
             break;
@@ -2333,6 +2337,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
                 case GGML_GLU_OP_SWIGLU:
                     ggml_cuda_op_swiglu(ctx, dst);
                     break;
+                case GGML_GLU_OP_SWIGLU_OAI:
+                    ggml_cuda_op_swiglu_oai(ctx, dst);
+                    break;
                 case GGML_GLU_OP_GEGLU_ERF:
                     ggml_cuda_op_geglu_erf(ctx, dst);
                     break;
@@ -2607,6 +2614,9 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
 
     const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected";
     const std::string gemma3n_per_layer_proj_src1_name = "per_layer_proj";
+    const std::string ffn_moe_gate_bias_prefix = "ffn_moe_gate_biased";
+    const std::string ffn_moe_up_bias_prefix = "ffn_moe_up_biased";
+    const std::string ffn_moe_down_bias_prefix = "ffn_moe_down_biased";
 
     for (int i = 0; i < cgraph->n_nodes; i++) {
         ggml_tensor * node = cgraph->nodes[i];
@@ -2629,7 +2639,13 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud
 #endif
         }
 
-        if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1 && (node->src[0] ? node->src[0]->name != gemma3n_per_layer_proj_src0_name : true) && (node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true)) {
+        if (node->op == GGML_OP_ADD &&
+            node->src[1] && node->src[1]->ne[1] > 1 &&
+            (node->src[0] ? node->src[0]->name != gemma3n_per_layer_proj_src0_name : true) &&
+            (node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true) &&
+            strncmp(node->name, ffn_moe_gate_bias_prefix.c_str(), ffn_moe_gate_bias_prefix.size()) != 0 &&
+            strncmp(node->name, ffn_moe_up_bias_prefix.c_str(), ffn_moe_up_bias_prefix.size()) != 0 &&
+            strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0) {
             // disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation
             // by means of matching node names. See
             // https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and
@@ -3227,6 +3243,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 case GGML_GLU_OP_REGLU:
                 case GGML_GLU_OP_GEGLU:
                 case GGML_GLU_OP_SWIGLU:
+                case GGML_GLU_OP_SWIGLU_OAI:
                 case GGML_GLU_OP_GEGLU_ERF:
                 case GGML_GLU_OP_GEGLU_QUICK:
                     return ggml_is_contiguous_1(op->src[0]);
@@ -3277,6 +3294,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                     case GGML_TYPE_Q5_0:
                     case GGML_TYPE_Q5_1:
                     case GGML_TYPE_Q8_0:
+                    case GGML_TYPE_MXFP4:
                     case GGML_TYPE_Q2_K:
                     case GGML_TYPE_Q3_K:
                     case GGML_TYPE_Q4_K:
@@ -3423,6 +3441,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_PERMUTE:
         case GGML_OP_TRANSPOSE:
         case GGML_OP_ADD:
+        case GGML_OP_ADD_ID:
         case GGML_OP_ADD1:
         case GGML_OP_SUB:
         case GGML_OP_MUL:
@@ -3503,6 +3522,10 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2];
                 return op->src[1]->ne[0] == 576 && op->src[2]->ne[0] == 512 && op->src[3] && gqa_ratio % 16 == 0;
             }
+            // TODO: more general-purpose attention sink support [TAG_ATTN_SINKS]
+            if (op->src[4] && op->src[0]->ne[0] != 64 && op->src[0]->ne[0] != 128) { // currently only sinks for head_size 64 and 128 are supported
+                return false;
+            }
             if (op->src[0]->ne[0] == 192) {
                 return false;
             }
index 73b98133438fc18f4d739cb384a8eca755a400d4..16bb9bec97d25afb1b659ea1249fb16d40d34e07 100644 (file)
@@ -1,7 +1,5 @@
 #include "im2col.cuh"
 
-#define MIN(a, b) (a) < (b) ? (a) : (b)
-
 #define MAX_GRIDDIM_Z 65535
 
 template <typename T>
@@ -38,6 +36,9 @@ static  __global__ void im2col_kernel(
             dst[offset_dst] = x[offset_src + iih * IW + iiw];
         }
     }
+
+    GGML_UNUSED(IC);
+    GGML_UNUSED(KH);
 }
 
 // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
index d4954fbe69e11b7db243ee286e71f9b245e223c7..8954a3831045630692344d7d9745dc8418ae928d 100644 (file)
@@ -20,6 +20,9 @@ static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, con
         case GGML_TYPE_Q8_0:
             mul_mat_q_case<GGML_TYPE_Q8_0>(ctx, args, stream);
             break;
+        case GGML_TYPE_MXFP4:
+            mul_mat_q_case<GGML_TYPE_MXFP4>(ctx, args, stream);
+            break;
         case GGML_TYPE_Q2_K:
             mul_mat_q_case<GGML_TYPE_Q2_K>(ctx, args, stream);
             break;
@@ -282,6 +285,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
         case GGML_TYPE_Q5_0:
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
index 04a8d80e1211b1d22ca1f356cab40bb21b635ee2..1634725c20a5aa630986e32af07aa154db3387bf 100644 (file)
@@ -58,6 +58,8 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
             return MMQ_Q8_1_DS_LAYOUT_DS4;
         case GGML_TYPE_Q8_0:
             return MMQ_Q8_1_DS_LAYOUT_D4;
+        case GGML_TYPE_MXFP4:
+            return MMQ_Q8_1_DS_LAYOUT_D4;
         case GGML_TYPE_Q2_K:
             return MMQ_Q8_1_DS_LAYOUT_D2S6;
         case GGML_TYPE_Q3_K:
@@ -170,6 +172,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
         case GGML_TYPE_Q5_0:    return MMQ_DP4A_TXS_Q8_0;
         case GGML_TYPE_Q5_1:    return MMQ_DP4A_TXS_Q8_1;
         case GGML_TYPE_Q8_0:    return MMQ_DP4A_TXS_Q8_0;
+        case GGML_TYPE_MXFP4:   return MMQ_DP4A_TXS_Q8_1;
         case GGML_TYPE_Q2_K:    return MMQ_DP4A_TXS_Q2_K;
         case GGML_TYPE_Q3_K:    return MMQ_DP4A_TXS_Q3_K;
         case GGML_TYPE_Q4_K:    return MMQ_DP4A_TXS_Q4_K;
@@ -206,6 +209,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
         case GGML_TYPE_Q5_0:    return MMQ_MMA_TILE_X_K_Q8_0;
         case GGML_TYPE_Q5_1:    return MMQ_MMA_TILE_X_K_Q8_1;
         case GGML_TYPE_Q8_0:    return MMQ_MMA_TILE_X_K_Q8_0;
+        case GGML_TYPE_MXFP4:   return MMQ_MMA_TILE_X_K_Q8_1;
         case GGML_TYPE_Q2_K:    return MMQ_MMA_TILE_X_K_Q2_K;
         case GGML_TYPE_Q3_K:    return MMQ_MMA_TILE_X_K_Q3_K;
         case GGML_TYPE_Q4_K:    return MMQ_MMA_TILE_X_K_Q8_1;
@@ -692,6 +696,71 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
     }
 }
 
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_mxfp4(
+    const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_MXFP4, mmq_y);
+    int   * x_qs = (int   *)  x_tile;
+    float * x_df = (float *) (x_qs + txs.qs);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+
+    constexpr int threads_per_row = MMQ_ITER_K / (4 * QR_MXFP4);
+    constexpr int nrows = warp_size / threads_per_row;
+    const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
+    const int kbx  = txi / QI_MXFP4;
+    const int kqsx = txi % QI_MXFP4;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+        int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbx;
+
+        const int aux_q4 = get_int_b1(bxi->qs, kqsx);
+        const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4);
+        const int k0 = kbx * (2 * QI_MXFP4) + kqsx;
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + 0]        = v.x;
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + QI_MXFP4] = v.y;
+#else
+        x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0]        = v.x;
+        x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI_MXFP4] = v.y;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+    }
+
+    constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI_MXFP4;
+    constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
+    const int kbxd = threadIdx.x % blocks_per_tile_x_row;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
+        int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbxd;
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_1                 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f;
+#else
+        x_df[i*(MMQ_TILE_NE_K/QI_MXFP4) + i/QI_MXFP4 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+    }
+}
+
 template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
@@ -2268,7 +2337,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
         const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbx;
 
         const int aux_q4 = get_int_b2(bxi->qs, kqsx);
-        const int2 v = get_int_from_table_16(aux_q4);
+        const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
         const int k0 = kbx * (2 * QI4_NL) + kqsx;
 
 #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
@@ -2707,7 +2776,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
         const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride;
 
         const int aux_q4 = get_int_b4(bxi->qs, kqsx);
-        const int2 v = get_int_from_table_16(aux_q4);
+        const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
         const int k0 = 8 * (kqsx / 4) + kqsx % 4;
 
 #if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
@@ -2863,6 +2932,14 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q8_0> {
     static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
 };
 
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_MXFP4> {
+    static constexpr int              vdr          = VDR_MXFP4_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_mxfp4<mmq_y, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
+};
+
 template <int mmq_x, int mmq_y, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q2_K> {
     static constexpr int              vdr          = VDR_Q2_K_Q8_1_MMQ;
@@ -3642,6 +3719,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_Q4_1);
 extern DECL_MMQ_CASE(GGML_TYPE_Q5_0);
 extern DECL_MMQ_CASE(GGML_TYPE_Q5_1);
 extern DECL_MMQ_CASE(GGML_TYPE_Q8_0);
+extern DECL_MMQ_CASE(GGML_TYPE_MXFP4);
 extern DECL_MMQ_CASE(GGML_TYPE_Q2_K);
 extern DECL_MMQ_CASE(GGML_TYPE_Q3_K);
 extern DECL_MMQ_CASE(GGML_TYPE_Q4_K);
index dc7adf509fac086ef4d4a2dad33fec8261afea02..5c8e5c4a7eeb589184137b37467327c06f94a8ac 100644 (file)
@@ -13,6 +13,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type)
         case GGML_TYPE_Q5_0:    return vec_dot_q5_0_q8_1;
         case GGML_TYPE_Q5_1:    return vec_dot_q5_1_q8_1;
         case GGML_TYPE_Q8_0:    return vec_dot_q8_0_q8_1;
+        case GGML_TYPE_MXFP4:   return vec_dot_mxfp4_q8_1;
         case GGML_TYPE_Q2_K:    return vec_dot_q2_K_q8_1;
         case GGML_TYPE_Q3_K:    return vec_dot_q3_K_q8_1;
         case GGML_TYPE_Q4_K:    return vec_dot_q4_K_q8_1;
@@ -38,6 +39,7 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
         case GGML_TYPE_Q5_0:    return VDR_Q5_0_Q8_1_MMVQ;
         case GGML_TYPE_Q5_1:    return VDR_Q5_1_Q8_1_MMVQ;
         case GGML_TYPE_Q8_0:    return VDR_Q8_0_Q8_1_MMVQ;
+        case GGML_TYPE_MXFP4:   return VDR_MXFP4_Q8_1_MMVQ;
         case GGML_TYPE_Q2_K:    return VDR_Q2_K_Q8_1_MMVQ;
         case GGML_TYPE_Q3_K:    return VDR_Q3_K_Q8_1_MMVQ;
         case GGML_TYPE_Q4_K:    return VDR_Q4_K_Q8_1_MMVQ;
@@ -384,6 +386,13 @@ static void mul_mat_vec_q_switch_type(
                  nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
                  stream);
             break;
+        case GGML_TYPE_MXFP4:
+            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_MXFP4>
+                (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+                 stream);
+            break;
         case GGML_TYPE_Q2_K:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q2_K>
                 (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
index 14543e978cf0f9f5a7d59a9741d65add2ebbaa02..eeacde0bdb126c173231b4bf748ab2da1a0d9cd6 100644 (file)
@@ -45,7 +45,7 @@ struct soft_max_params {
 #endif // __clang__
 template <bool use_shared, int ncols_template, int block_size_template, typename T>
 static __global__ void soft_max_f32(
-        const float * x, const T * mask, float * dst, const soft_max_params p) {
+        const float * x, const T * mask, const float * sinks, float * dst, const soft_max_params p) {
     const int ncols = ncols_template == 0 ? p.ncols : ncols_template;
 
     const int tid  = threadIdx.x;
@@ -77,7 +77,7 @@ static __global__ void soft_max_f32(
     // shared memory buffer to cache values between iterations:
     float * vals = use_shared ? buf_iw + WARP_SIZE : dst;
 
-    float max_val = -INFINITY;
+    float max_val = sinks ? sinks[i02] : -INFINITY;
 
 #pragma unroll
     for (int col0 = 0; col0 < ncols; col0 += block_size) {
@@ -143,6 +143,10 @@ static __global__ void soft_max_f32(
         tmp = warp_reduce_sum(tmp);
     }
 
+    if (sinks) {
+        tmp += expf(sinks[i02] - max_val);
+    }
+
     const float inv_sum = 1.0f / tmp;
 
 #pragma unroll
@@ -183,7 +187,7 @@ static __global__ void soft_max_back_f32(
 }
 
 template<int... Ns, typename T>
-static void launch_soft_max_kernels(const float * x, const T * mask, float * dst,
+static void launch_soft_max_kernels(const float * x, const T * mask, const float * sinks, float * dst,
                              const soft_max_params & p, cudaStream_t stream, dim3 block_dims, dim3 block_nums, size_t nbytes_shared)
 {
     const int id       = ggml_cuda_get_device();
@@ -196,7 +200,7 @@ static void launch_soft_max_kernels(const float * x, const T * mask, float * dst
         if (p.ncols == ncols) {
             CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, ncols, block, T>), smpbo);
             soft_max_f32<true, ncols, block><<<block_nums, block_dims, nbytes_shared, stream>>>
-                (x, mask, dst, p);
+                (x, mask, sinks, dst, p);
             return true;
         }
         return false;
@@ -209,12 +213,12 @@ static void launch_soft_max_kernels(const float * x, const T * mask, float * dst
 
     //default case
     CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, 0, 0, T>), smpbo);
-    soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, dst, p);
+    soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, sinks, dst, p);
 }
 
 
 template<typename T>
-static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const soft_max_params & params, cudaStream_t stream) {
+static void soft_max_f32_cuda(const float * x, const T * mask, const float * sinks, float * dst, const soft_max_params & params, cudaStream_t stream) {
     int nth = WARP_SIZE;
     const int64_t ncols_x = params.ncols;
 
@@ -230,10 +234,10 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons
 
 
     if (nbytes_shared <= smpbo) {
-        launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, dst, params, stream, block_dims, block_nums, nbytes_shared);
+        launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, sinks, dst, params, stream, block_dims, block_nums, nbytes_shared);
     } else {
         const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
-        soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, params);
+        soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, sinks, dst, params);
     }
 }
 
@@ -249,9 +253,11 @@ static void soft_max_back_f32_cuda(
 void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * src2 = dst->src[2];
 
     const float * src0_d = (const float *) src0->data;
     const void  * src1_d = src1 ? (const void *) src1->data : nullptr;
+    const void  * src2_d = src2 ? (const void *) src2->data : nullptr;
     float       *  dst_d = (float *) dst->data;
 
     cudaStream_t stream = ctx.stream();
@@ -309,9 +315,9 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     params.m1 = m1;
 
     if (use_f16) {
-        soft_max_f32_cuda(src0_d, (const half  *) src1_d, dst_d, params, stream);
+        soft_max_f32_cuda(src0_d, (const half  *) src1_d, (const float *) src2_d, dst_d, params, stream);
     } else {
-        soft_max_f32_cuda(src0_d, (const float *) src1_d, dst_d, params, stream);
+        soft_max_f32_cuda(src0_d, (const float *) src1_d, (const float *) src2_d, dst_d, params, stream);
     }
 }
 
diff --git a/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu b/src/ggml-cuda/template-instances/mmq-instance-mxfp4.cu
new file mode 100644 (file)
index 0000000..c14624c
--- /dev/null
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_MXFP4);
index 91c830c4dacc3816778650620fcc5eaae8fb5462..5aff8a876af2ce1f73f603f2ab81854e22f1a9de 100644 (file)
@@ -300,6 +300,81 @@ void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     ggml_cuda_op_unary_gated<op_gelu_quick>(ctx, dst);
 }
 
+// swiglu_oai
+
+template <typename T>
+static __global__ void swiglu_oai_kernel(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, float alpha, float limit) {
+    const int64_t i = int64_t(blockDim.x)*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+
+    // perform base op and multiply with gate (either offset in same tensor or a separate one)
+    const int64_t j0 = (i / n) * o0 + (i % n);
+    const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
+
+    float xi = x[j0];
+    float gi = g[j1];
+    xi = fminf(xi, limit);
+    gi = fmaxf(fminf(gi, limit), -limit);
+
+    float out_glu = xi / (1.0f + expf(-xi * alpha));
+    out_glu = out_glu * (1.0f + gi);
+
+    dst[i] = out_glu;
+}
+
+template <typename T>
+static void swiglu_oai_cuda(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, const float alpha, const float limit, cudaStream_t stream) {
+    const int64_t num_blocks = (k + CUDA_GLU_BLOCK_SIZE - 1) / CUDA_GLU_BLOCK_SIZE;
+    swiglu_oai_kernel<<<num_blocks, CUDA_GLU_BLOCK_SIZE, 0, stream>>>(x, g, dst, k, n, o0, o1, alpha, limit);
+}
+
+void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    void * src0_d = src0->data;
+    void * src1_d = src1 ? src1->data : src0->data;
+    const int64_t src0_o = src0->nb[1];
+    const int64_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+    void * dst_d = dst->data;
+    const int64_t nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(ggml_is_contiguous_1(src0));
+    GGML_ASSERT(src0->nb[0] == ggml_element_size(src0));
+    GGML_ASSERT(ggml_is_contiguous(dst));
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == dst->type);
+    GGML_ASSERT(dst->ne[0] == nc);
+    GGML_ASSERT(ggml_nrows(dst) == ggml_nrows(src0));
+
+    if (src1) {
+        GGML_ASSERT(ggml_is_contiguous_1(src1));
+        GGML_ASSERT(src1->nb[0] == ggml_element_size(src1));
+        GGML_ASSERT(src1->ne[0] == nc);
+        GGML_ASSERT(src0->type == src1->type);
+    }
+
+    //const int32_t swapped = ((const int32_t *) dst->op_params)[1];
+    const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+    const float alpha = ggml_get_op_params_f32(dst, 2);
+    const float limit = ggml_get_op_params_f32(dst, 3);
+
+    float * src0_p = (float *) src0_d;
+    float * src1_p = (float *) src1_d;
+
+    if (!src1) {
+        src0_p += swapped ? nc : 0;
+        src1_p += swapped ? 0 : nc;
+    }
+
+    swiglu_oai_cuda(src0_p, src1_p, (float *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(float), src1_o / sizeof(float), alpha, limit, stream);
+}
+
 /* silu_back */
 
 static __device__ __forceinline__ float op_silu_back(float grad, float x) {
index cb14d16f8f3f56b82701b4598970581725bf86ee..da3caf1d8962e35f23fa9802d88f9e20050ecb39 100644 (file)
@@ -67,6 +67,8 @@ void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
+void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
 void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
index ba195e1d100d320298cdf12442e8d5de2995e616..d8f9aa5ba62242041835e1a4425e27210f227903 100644 (file)
@@ -1,8 +1,20 @@
 #pragma once
 
 #include "common.cuh"
+
 #include <cstdint>
 
+static __device__ __forceinline__ int get_int_b1(const void * x, const int & i32) {
+    const uint8_t * x8 = (const uint8_t *) x;
+
+    int x32  = x8[4*i32 + 0] <<  0;
+    x32     |= x8[4*i32 + 1] <<  8;
+    x32     |= x8[4*i32 + 2] << 16;
+    x32     |= x8[4*i32 + 3] << 24;
+
+    return x32;
+}
+
 static __device__ __forceinline__ int get_int_b2(const void * x, const int & i32) {
     const uint16_t * x16 = (const uint16_t *) x; // assume at least 2 byte alignment
 
@@ -16,6 +28,20 @@ static __device__ __forceinline__ int get_int_b4(const void * x, const int & i32
     return ((const int *) x)[i32]; // assume at least 4 byte alignment
 }
 
+static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, const int8_t * table) {
+    const int      q0_32  = (q4 >> 0) & 0x0F0F0F0F;
+    const int8_t * q0_8   = (const int8_t *) &q0_32;
+    const char4    val0_8 = make_char4(
+        table[q0_8[0]], table[q0_8[1]], table[q0_8[2]], table[q0_8[3]]);
+
+    const int      q1_32  = (q4 >> 4) & 0x0F0F0F0F;
+    const int8_t * q1_8   = (const int8_t *) &q1_32;
+    const char4    val1_8 = make_char4(
+        table[q1_8[0]], table[q1_8[1]], table[q1_8[2]], table[q1_8[3]]);
+
+    return make_int2(*((const int *) &val0_8), *((const int *) &val1_8));
+}
+
 // VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
 // MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
 
@@ -211,6 +237,30 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_0_16_q8_1_
     return d8_1*sumf;
 }
 
+#define VDR_MXFP4_Q8_1_MMVQ 2
+#define VDR_MXFP4_Q8_1_MMQ  4
+
+static __device__ __forceinline__ float vec_dot_mxfp4_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+    const block_mxfp4 * bq4 = (const block_mxfp4 *) vbq + kbx;
+
+    const int * q8 = (const int *) bq8_1->qs + iqs;
+
+    int sumi = 0;
+#pragma unroll
+    for (int l = 0; l < VDR_MXFP4_Q8_1_MMVQ; ++l) {
+        const int aux_q4 = get_int_b1(bq4->qs, iqs + l);
+        const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4);
+
+        sumi = ggml_cuda_dp4a(v.x, q8[l + 0], sumi);
+        sumi = ggml_cuda_dp4a(v.y, q8[l + 4], sumi);
+    }
+
+    const float d = ggml_cuda_e8m0_to_fp32(bq4->e) * 0.5f * __low2float(bq8_1->ds);
+    return d * sumi;
+}
+
 #define VDR_Q2_K_Q8_1_MMVQ 1
 #define VDR_Q2_K_Q8_1_MMQ  4
 
@@ -1068,20 +1118,6 @@ static __device__ __forceinline__ float vec_dot_iq1_m_q8_1(
     return d * ((sumi[0] + sumf[0]) * sc0 + (sumi[1] + sumf[1]) * sc1);
 }
 
-static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4) {
-    const int      q0_32  = (q4 >> 0) & 0x0F0F0F0F;
-    const int8_t * q0_8   = (const int8_t *) &q0_32;
-    const char4    val0_8 = make_char4(
-        kvalues_iq4nl[q0_8[0]], kvalues_iq4nl[q0_8[1]], kvalues_iq4nl[q0_8[2]], kvalues_iq4nl[q0_8[3]]);
-
-    const int      q1_32  = (q4 >> 4) & 0x0F0F0F0F;
-    const int8_t * q1_8   = (const int8_t *) &q1_32;
-    const char4    val1_8 = make_char4(
-        kvalues_iq4nl[q1_8[0]], kvalues_iq4nl[q1_8[1]], kvalues_iq4nl[q1_8[2]], kvalues_iq4nl[q1_8[3]]);
-
-    return make_int2(*((const int *) &val0_8), *((const int *) &val1_8));
-}
-
 #define VDR_IQ4_NL_Q8_1_MMVQ 2
 #define VDR_IQ4_NL_Q8_1_MMQ  4
 
@@ -1096,7 +1132,7 @@ static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1(
 #pragma unroll
     for (int l = 0; l < VDR_Q4_0_Q8_1_MMVQ; ++l) {
         const int aux_q4 = get_int_b2(bq4->qs, iqs + l);
-        const int2 v = get_int_from_table_16(aux_q4);
+        const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
 
         sumi = ggml_cuda_dp4a(v.x, q8[l + 0], sumi);
         sumi = ggml_cuda_dp4a(v.y, q8[l + 4], sumi);
@@ -1118,7 +1154,7 @@ static __device__ __forceinline__ float vec_dot_iq4_xs_q8_1(
 #pragma unroll
     for (int j = 0; j < 4; ++j) {
         const int aux_q4 = get_int_b4(bq4->qs, iqs + j);
-        const int2 v = get_int_from_table_16(aux_q4);
+        const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
 
         const int u0 = get_int_b4(bq8_1[iqs/4].qs, j + 0);
         const int u1 = get_int_b4(bq8_1[iqs/4].qs, j + 4);
index 1746b073203e3e134faadc618474b43efbe36745..3b3086778eed8f6773b9baed97d57fa58d19ddc3 100644 (file)
@@ -6,6 +6,10 @@
 #include <cuda_bf16.h>
 #include <cuda_fp16.h>
 
+#if CUDART_VERSION >= 12050
+#include <cuda_fp8.h>
+#endif // CUDART_VERSION >= 12050
+
 #if CUDART_VERSION < 11020
 #define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
 #define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH
index a2e30994c4669fbfd60e1230df983dc43b63f913..19a7adb2d101b3ef4ad654c242a20bdecfe56759 100644 (file)
@@ -410,6 +410,67 @@ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
 #define GGML_FP16_TO_FP32(x) GGML_COMPUTE_FP16_TO_FP32(x)
 #define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
 
+static inline float ggml_e8m0_to_fp32(uint8_t x) {
+    uint32_t bits;  // Stores the raw bit representation of the float
+
+    // Handle special case for minimum exponent (denormalized float)
+    if (x == 0) {
+        // Bit pattern for 2^(-127):
+        // - Sign bit: 0 (positive)
+        // - Exponent: 0 (denormalized number)
+        // - Mantissa: 0x400000 (0.5 in fractional form)
+        // Value = 0.5 * 2^(-126) = 2^(-127)
+        bits = 0x00400000;
+    }
+    // note: disabled as we don't need to handle NaNs
+    //// Handle special case for NaN (all bits set)
+    //else if (x == 0xFF) {
+    //    // Standard quiet NaN pattern:
+    //    // - Sign bit: 0
+    //    // - Exponent: all 1s (0xFF)
+    //    // - Mantissa: 0x400000 (quiet NaN flag)
+    //    bits = 0x7FC00000;
+    //}
+    // Normalized values (most common case)
+    else {
+        // Construct normalized float by shifting exponent into position:
+        // - Exponent field: 8 bits (positions 30-23)
+        // - Mantissa: 0 (implicit leading 1)
+        // Value = 2^(x - 127)
+        bits = (uint32_t) x << 23;
+    }
+
+    float result;  // Final float value
+                   // Safely reinterpret bit pattern as float without type-punning issues
+    memcpy(&result, &bits, sizeof(float));
+    return result;
+}
+
+// Equal to ggml_e8m0_to_fp32/2
+// Useful with MXFP4 quantization since the E0M2 values are doubled
+static inline float ggml_e8m0_to_fp32_half(uint8_t x) {
+    uint32_t bits;
+
+    // For x < 2: use precomputed denormal patterns
+    if (x < 2) {
+        // 0x00200000 = 2^(-128), 0x00400000 = 2^(-127)
+        bits = 0x00200000 << x;
+    }
+    // For x >= 2: normalized exponent adjustment
+    else {
+        // 0.5 * 2^(x-127) = 2^(x-128) = normalized with exponent (x-1)
+        bits = (uint32_t)(x - 1) << 23;
+    }
+    // Note: NaNs are not handled here
+
+    float result;
+    memcpy(&result, &bits, sizeof(float));
+    return result;
+}
+
+#define GGML_E8M0_TO_FP32(x) ggml_e8m0_to_fp32(x)
+#define GGML_E8M0_TO_FP32_HALF(x) ggml_e8m0_to_fp32_half(x)
+
 /**
  * Converts brain16 to float32.
  *
index 8424464d8cadca4926c992a6142d185574b5ea37..fc6526d6d5dc6eec1ea0b31438650fa62d17fc70 100644 (file)
@@ -23,6 +23,9 @@
 #define N_R0_Q8_0 4
 #define N_SG_Q8_0 2
 
+#define N_R0_MXFP4 2
+#define N_SG_MXFP4 2
+
 #define N_R0_Q2_K 4
 #define N_SG_Q2_K 2
 
@@ -129,6 +132,15 @@ typedef struct {
     uint64_t o1[8];
 } ggml_metal_kargs_bin;
 
+typedef struct {
+    int64_t ne0;
+    int64_t ne1;
+    size_t nb01;
+    size_t nb02;
+    size_t nb11;
+    size_t nb21;
+} ggml_metal_kargs_add_id;
+
 typedef struct {
     int32_t  ne00;
     int32_t  ne01;
@@ -444,6 +456,8 @@ typedef struct{
     uint64_t nb1;
     int32_t  i00;
     int32_t  i10;
+    float    alpha;
+    float    limit;
 } ggml_metal_kargs_glu;
 
 typedef struct {
index 337f7985badf328ed928282509725c94ce2e5caa..cb8eff4a772929638616bfe02fee58401ab1df5c 100644 (file)
@@ -195,6 +195,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_ROW_C4,
     GGML_METAL_KERNEL_TYPE_DIV,
     GGML_METAL_KERNEL_TYPE_DIV_ROW_C4,
+    GGML_METAL_KERNEL_TYPE_ADD_ID,
     GGML_METAL_KERNEL_TYPE_REPEAT_F32,
     GGML_METAL_KERNEL_TYPE_REPEAT_F16,
     GGML_METAL_KERNEL_TYPE_REPEAT_I32,
@@ -234,6 +235,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0,
+    GGML_METAL_KERNEL_TYPE_GET_ROWS_MXFP4,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K,
@@ -286,6 +288,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2,
     GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3,
     GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4,
@@ -310,6 +313,10 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3,
     GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4,
     GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_2,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_3,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_4,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_5,
     GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2,
     GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3,
     GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4,
@@ -351,6 +358,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32,
@@ -373,6 +381,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32,
@@ -397,6 +406,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16,
+    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F16,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16,
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16,
@@ -579,6 +589,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_REGLU,
     GGML_METAL_KERNEL_TYPE_GEGLU,
     GGML_METAL_KERNEL_TYPE_SWIGLU,
+    GGML_METAL_KERNEL_TYPE_SWIGLU_OAI,
     GGML_METAL_KERNEL_TYPE_GEGLU_ERF,
     GGML_METAL_KERNEL_TYPE_GEGLU_QUICK,
     GGML_METAL_KERNEL_TYPE_SUM_ROWS,
@@ -1199,6 +1210,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW_C4,                      mul_row_c4,                      true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV,                             div,                             true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW_C4,                      div_row_c4,                      true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ID,                          add_id,                          true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32,                      repeat_f32,                      true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16,                      repeat_f16,                      true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32,                      repeat_i32,                      true);
@@ -1238,6 +1250,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0,                   get_rows_q5_0,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1,                   get_rows_q5_1,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0,                   get_rows_q8_0,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_MXFP4,                  get_rows_mxfp4,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K,                   get_rows_q2_K,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K,                   get_rows_q3_K,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K,                   get_rows_q4_K,                   true);
@@ -1290,6 +1303,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,                 mul_mv_q5_0_f32,                 has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,                 mul_mv_q5_1_f32,                 has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,                 mul_mv_q8_0_f32,                 has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32,                mul_mv_mxfp4_f32,                has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2,         mul_mv_ext_f16_f32_r1_2,         has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3,         mul_mv_ext_f16_f32_r1_3,         has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4,         mul_mv_ext_f16_f32_r1_4,         has_simdgroup_reduction);
@@ -1314,6 +1328,10 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3,        mul_mv_ext_q8_0_f32_r1_3,        has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4,        mul_mv_ext_q8_0_f32_r1_4,        has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5,        mul_mv_ext_q8_0_f32_r1_5,        has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_2,       mul_mv_ext_mxfp4_f32_r1_2,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_3,       mul_mv_ext_mxfp4_f32_r1_3,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_4,       mul_mv_ext_mxfp4_f32_r1_4,       has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_5,       mul_mv_ext_mxfp4_f32_r1_5,       has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2,        mul_mv_ext_q4_K_f32_r1_2,        has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3,        mul_mv_ext_q4_K_f32_r1_3,        has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4,        mul_mv_ext_q4_K_f32_r1_4,        has_simdgroup_reduction);
@@ -1355,6 +1373,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32,              mul_mv_id_q5_0_f32,              has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32,              mul_mv_id_q5_1_f32,              has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32,              mul_mv_id_q8_0_f32,              has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32,             mul_mv_id_mxfp4_f32,             has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32,              mul_mv_id_q2_K_f32,              has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32,              mul_mv_id_q3_K_f32,              has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32,              mul_mv_id_q4_K_f32,              has_simdgroup_reduction);
@@ -1377,6 +1396,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,                 mul_mm_q5_0_f32,                 has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32,                 mul_mm_q5_1_f32,                 has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32,                 mul_mm_q8_0_f32,                 has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32,                mul_mm_mxfp4_f32,                has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32,                mul_mm_mxfp4_f32,                has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32,                 mul_mm_q2_K_f32,                 has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32,                 mul_mm_q3_K_f32,                 has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32,                 mul_mm_q4_K_f32,                 has_simdgroup_mm);
@@ -1401,6 +1422,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16,              mul_mm_id_q5_0_f16,              has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16,              mul_mm_id_q5_1_f16,              has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16,              mul_mm_id_q8_0_f16,              has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F16,             mul_mm_id_mxfp4_f16,             has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16,              mul_mm_id_q2_K_f16,              has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16,              mul_mm_id_q3_K_f16,              has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16,              mul_mm_id_q4_K_f16,              has_simdgroup_mm);
@@ -1583,6 +1605,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REGLU,                           reglu,                           true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU,                           geglu,                           true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU,                          swiglu,                          true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU_OAI,                      swiglu_oai,                      true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_ERF,                       geglu_erf,                       true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_QUICK,                     geglu_quick,                     true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS,                        sum_rows,                        true);
@@ -1774,6 +1797,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
                 case GGML_GLU_OP_REGLU:
                 case GGML_GLU_OP_GEGLU:
                 case GGML_GLU_OP_SWIGLU:
+                case GGML_GLU_OP_SWIGLU_OAI:
                 case GGML_GLU_OP_GEGLU_ERF:
                 case GGML_GLU_OP_GEGLU_QUICK:
                     return ggml_is_contiguous_1(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
@@ -1791,6 +1815,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         case GGML_OP_SUB:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
+        case GGML_OP_ADD_ID:
             return op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_ACC:
         case GGML_OP_REPEAT:
@@ -2042,6 +2067,7 @@ static int ggml_metal_encode_node(
 
     const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
     const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
+    const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT;
     const enum ggml_type dstt  = dst  ? dst->type  : GGML_TYPE_COUNT;
 
     size_t offs_src0 = 0;
@@ -2291,6 +2317,38 @@ static int ggml_metal_encode_node(
                     [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                 }
             } break;
+        case GGML_OP_ADD_ID:
+            {
+                GGML_ASSERT(src0t == GGML_TYPE_F32);
+                GGML_ASSERT(src1t == GGML_TYPE_F32);
+                GGML_ASSERT(src2t == GGML_TYPE_I32);
+                GGML_ASSERT(dstt  == GGML_TYPE_F32);
+
+                GGML_ASSERT(ggml_is_contiguous_rows(src0));
+
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ID].pipeline;
+
+                ggml_metal_kargs_add_id args = {
+                    /*.ne0  =*/ ne0,
+                    /*.ne1  =*/ ne1,
+                    /*.nb01 =*/ nb01,
+                    /*.nb02 =*/ nb02,
+                    /*.nb11 =*/ nb11,
+                    /*.nb21 =*/ nb21,
+
+                };
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBytes:&args length:sizeof(args) atIndex:0];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
+                [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
+                [encoder setBuffer:id_dst  offset:offs_dst  atIndex:4];
+
+                const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
+
+                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+            } break;
         case GGML_OP_REPEAT:
             {
                 id<MTLComputePipelineState> pipeline;
@@ -2710,6 +2768,9 @@ static int ggml_metal_encode_node(
                     case GGML_GLU_OP_SWIGLU:
                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline;
                         break;
+                    case GGML_GLU_OP_SWIGLU_OAI:
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU_OAI].pipeline;
+                        break;
                     case GGML_GLU_OP_GEGLU_ERF:
                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU_ERF].pipeline;
                         break;
@@ -2720,7 +2781,9 @@ static int ggml_metal_encode_node(
                         GGML_ABORT("fatal error");
                 }
 
-                const int32_t swp = ((const int32_t *) dst->op_params)[1];
+                const int32_t swp = ggml_get_op_params_i32(dst, 1);
+                const float alpha = ggml_get_op_params_f32(dst, 2);
+                const float limit = ggml_get_op_params_f32(dst, 3);
 
                 const int32_t i00 = swp ? ne0 : 0;
                 const int32_t i10 = swp ? 0 : ne0;
@@ -2734,6 +2797,8 @@ static int ggml_metal_encode_node(
                     /*.nb1  =*/ nb1,
                     /*.i00  =*/ src1 ? 0 : i00,
                     /*.i10  =*/ src1 ? 0 : i10,
+                    /*.alpha=*/ alpha,
+                    /*.limit=*/ limit
                 };
 
                 [encoder setComputePipelineState:pipeline];
@@ -2992,8 +3057,13 @@ static int ggml_metal_encode_node(
                 } else {
                     [encoder setBuffer:h_src0 offset:offs_src0  atIndex:1];
                 }
-                [encoder setBuffer:id_dst offset:offs_dst       atIndex:2];
-                [encoder setBytes:&args   length:sizeof(args)   atIndex:3];
+                if (id_src2) {
+                    [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
+                } else {
+                    [encoder setBuffer:h_src0 offset:offs_src0  atIndex:2];
+                }
+                [encoder setBuffer:id_dst offset:offs_dst       atIndex:3];
+                [encoder setBytes:&args   length:sizeof(args)   atIndex:4];
 
                 [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
 
@@ -3291,6 +3361,7 @@ static int ggml_metal_encode_node(
                        src0t == GGML_TYPE_Q5_0 ||
                        src0t == GGML_TYPE_Q5_1 ||
                        src0t == GGML_TYPE_Q8_0 ||
+                       src0t == GGML_TYPE_MXFP4 ||
                        src0t == GGML_TYPE_IQ4_NL ||
                        false) && (ne11 >= 2 && ne11 <= 8)
                      ) ||
@@ -3383,6 +3454,14 @@ static int ggml_metal_encode_node(
                                 case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5].pipeline; break;
                                 default: GGML_ABORT("not implemented");
                             } break;
+                        case GGML_TYPE_MXFP4:
+                            switch (r1ptg) {
+                                case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_2].pipeline; break;
+                                case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_3].pipeline; break;
+                                case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_4].pipeline; break;
+                                case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_5].pipeline; break;
+                                default: GGML_ABORT("not implemented");
+                            } break;
                         case GGML_TYPE_Q4_K:
                             switch (r1ptg) {
                                 case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2].pipeline; break;
@@ -3481,6 +3560,7 @@ static int ggml_metal_encode_node(
                         case GGML_TYPE_Q5_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32   ].pipeline; break;
                         case GGML_TYPE_Q5_1:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32   ].pipeline; break;
                         case GGML_TYPE_Q8_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32   ].pipeline; break;
+                        case GGML_TYPE_MXFP4:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32   ].pipeline; break;
                         case GGML_TYPE_Q2_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32   ].pipeline; break;
                         case GGML_TYPE_Q3_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32   ].pipeline; break;
                         case GGML_TYPE_Q4_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32   ].pipeline; break;
@@ -3623,6 +3703,13 @@ static int ggml_metal_encode_node(
                                 nr0 = N_R0_Q8_0;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline;
                             } break;
+                        case GGML_TYPE_MXFP4:
+                            {
+                                nsg = N_SG_MXFP4;
+                                nr0 = N_R0_MXFP4;
+                                smem = 32*sizeof(float);
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32].pipeline;
+                            } break;
                         case GGML_TYPE_Q2_K:
                             {
                                 nsg = N_SG_Q2_K;
@@ -3756,8 +3843,6 @@ static int ggml_metal_encode_node(
         case GGML_OP_MUL_MAT_ID:
             {
                 // src2 = ids
-                const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t);
-
                 GGML_ASSERT(src2t == GGML_TYPE_I32);
 
                 GGML_ASSERT(!ggml_is_transposed(src0));
@@ -3883,6 +3968,7 @@ static int ggml_metal_encode_node(
                             case GGML_TYPE_Q5_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16   ].pipeline; break;
                             case GGML_TYPE_Q5_1:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16   ].pipeline; break;
                             case GGML_TYPE_Q8_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16   ].pipeline; break;
+                            case GGML_TYPE_MXFP4:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F16  ].pipeline; break;
                             case GGML_TYPE_Q2_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16   ].pipeline; break;
                             case GGML_TYPE_Q3_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16   ].pipeline; break;
                             case GGML_TYPE_Q4_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16   ].pipeline; break;
@@ -4018,6 +4104,13 @@ static int ggml_metal_encode_node(
                                 nr0 = N_R0_Q8_0;
                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline;
                             } break;
+                        case GGML_TYPE_MXFP4:
+                            {
+                                nsg = N_SG_MXFP4;
+                                nr0 = N_R0_MXFP4;
+                                smem = 32*sizeof(float);
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32].pipeline;
+                            } break;
                         case GGML_TYPE_Q2_K:
                             {
                                 nsg = N_SG_Q2_K;
@@ -4170,6 +4263,7 @@ static int ggml_metal_encode_node(
                     case GGML_TYPE_Q5_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0   ].pipeline; break;
                     case GGML_TYPE_Q5_1:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1   ].pipeline; break;
                     case GGML_TYPE_Q8_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0   ].pipeline; break;
+                    case GGML_TYPE_MXFP4:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_MXFP4  ].pipeline; break;
                     case GGML_TYPE_Q2_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K   ].pipeline; break;
                     case GGML_TYPE_Q3_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K   ].pipeline; break;
                     case GGML_TYPE_Q4_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K   ].pipeline; break;
@@ -4980,11 +5074,14 @@ static int ggml_metal_encode_node(
                 GGML_ASSERT(ne11 == ne21);
                 GGML_ASSERT(ne12 == ne22);
 
-                struct ggml_tensor * src3 = node->src[3];
+                struct ggml_tensor * src3 = node->src[3]; // mask
+                struct ggml_tensor * src4 = node->src[4]; // sinks
 
                 size_t offs_src3 = 0;
+                size_t offs_src4 = 0;
 
                 id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
+                id<MTLBuffer> id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil;
 
                 GGML_ASSERT(!src3 || src3->type == GGML_TYPE_F16);
                 GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) &&
@@ -5000,8 +5097,6 @@ static int ggml_metal_encode_node(
                 const uint64_t nb32 = src3 ? src3->nb[2] : 0; GGML_UNUSED(nb32);
                 const uint64_t nb33 = src3 ? src3->nb[3] : 0; GGML_UNUSED(nb33);
 
-                const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
-
                 float scale;
                 float max_bias;
                 float logit_softcap;
@@ -5389,7 +5484,12 @@ static int ggml_metal_encode_node(
                 } else {
                     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:4];
                 }
-                [encoder setBuffer:id_dst offset:offs_dst       atIndex:5];
+                if (id_src4) {
+                    [encoder setBuffer:id_src4 offset:offs_src4 atIndex:5];
+                } else {
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:5];
+                }
+                [encoder setBuffer:id_dst offset:offs_dst       atIndex:6];
 
                 if (!use_vec_kernel) {
                     // half8x8 kernel
index 99a453090f6b0ab58636a2ba53378555342d69ca..b35a3bbdc317fa239e01712fe57a0f420a51f318 100644 (file)
@@ -35,6 +35,10 @@ constexpr constant static float kvalues_iq4nl_f[16] = {
     -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
 };
 
+constexpr constant static float kvalues_mxfp4_f[16] = {
+    0, .5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f, -0, -.5f, -1.f, -1.5f, -2.f, -3.f, -4.f, -6.f
+};
+
 static inline int best_index_int8(int n, constant float * val, float x) {
     if (x <= val[0]) return 0;
     if (x >= val[n-1]) return n-1;
@@ -46,6 +50,18 @@ static inline int best_index_int8(int n, constant float * val, float x) {
     return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
 }
 
+static inline float e8m0_to_fp32(uint8_t x) {
+    uint32_t bits;
+
+    if (x == 0) {
+        bits = 0x00400000;
+    } else {
+        bits = (uint32_t) x << 23;
+    }
+
+    return as_type<float>(bits);
+}
+
 // NOTE: this is not dequantizing - we are simply fitting the template
 template <typename type4x4>
 void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
@@ -242,6 +258,27 @@ void quantize_q5_1(device const float * src, device block_q5_1 & dst) {
     }
 }
 
+void quantize_q8_0(device const float * src, device block_q8_0 & dst) {
+#pragma METAL fp math_mode(safe)
+    float amax = 0.0f; // absolute max
+
+    for (int j = 0; j < QK8_0; j++) {
+        const float v = src[j];
+        amax = MAX(amax, fabs(v));
+    }
+
+    const float d = amax / ((1 << 7) - 1);
+    const float id = d ? 1.0f/d : 0.0f;
+
+    dst.d = d;
+
+    for (int j = 0; j < QK8_0; ++j) {
+        const float x0 = src[j]*id;
+
+        dst.qs[j] = round(x0);
+    }
+}
+
 void quantize_iq4_nl(device const float * src, device block_iq4_nl & dst) {
 #pragma METAL fp math_mode(safe)
     float amax = 0.0f; // absolute max
@@ -462,25 +499,34 @@ void dequantize_q8_0_t4(device const block_q8_0 *xb, short il, thread type4 & re
     }
 }
 
-void quantize_q8_0(device const float * src, device block_q8_0 & dst) {
-#pragma METAL fp math_mode(safe)
-    float amax = 0.0f; // absolute max
+template <typename type4x4>
+void dequantize_mxfp4(device const block_mxfp4 * xb, short il, thread type4x4 & reg) {
+    device const uint8_t * q2 = (device const uint8_t *)xb->qs;
 
-    for (int j = 0; j < QK8_0; j++) {
-        const float v = src[j];
-        amax = MAX(amax, fabs(v));
+    const float d = e8m0_to_fp32(xb->e);
+    const uint8_t shr = il >= 1 ? 4 : 0;
+
+    for (int i = 0; i < 4; ++i) {
+        reg[i][0] = d * kvalues_mxfp4_f[(q2[4*i + 0] >> shr) & 0x0F];
+        reg[i][1] = d * kvalues_mxfp4_f[(q2[4*i + 1] >> shr) & 0x0F];
+        reg[i][2] = d * kvalues_mxfp4_f[(q2[4*i + 2] >> shr) & 0x0F];
+        reg[i][3] = d * kvalues_mxfp4_f[(q2[4*i + 3] >> shr) & 0x0F];
     }
+}
 
-    const float d = amax / ((1 << 7) - 1);
-    const float id = d ? 1.0f/d : 0.0f;
+template <typename type4>
+void dequantize_mxfp4_t4(device const block_mxfp4 * xb, short il, thread type4 & reg) {
+    device const uint8_t * q2 = (device const uint8_t *)xb->qs;
 
-    dst.d = d;
+    const float d = e8m0_to_fp32(xb->e);
+    const short il4 = il%4;
 
-    for (int j = 0; j < QK8_0; ++j) {
-        const float x0 = src[j]*id;
+    const uint8_t shr = il >= 4 ? 4 : 0;
 
-        dst.qs[j] = round(x0);
-    }
+    reg[0] = d * kvalues_mxfp4_f[(q2[4*il4 + 0] >> shr) & 0x0F];
+    reg[1] = d * kvalues_mxfp4_f[(q2[4*il4 + 1] >> shr) & 0x0F];
+    reg[2] = d * kvalues_mxfp4_f[(q2[4*il4 + 2] >> shr) & 0x0F];
+    reg[3] = d * kvalues_mxfp4_f[(q2[4*il4 + 3] >> shr) & 0x0F];
 }
 
 template <typename type4x4>
@@ -960,6 +1006,32 @@ kernel void kernel_div(
     }
 }
 
+kernel void kernel_add_id(
+        constant ggml_metal_kargs_add_id & args,
+        device const char * src0,
+        device const char * src1,
+        device const char * src2,
+        device       char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int i1 = tgpig.x;
+    const int i2 = tgpig.y;
+
+    const int i11 = *((device const int32_t *) (src2 + i1*sizeof(int32_t) + i2*args.nb21));
+
+    const size_t nb1 = args.ne0 * sizeof(float);
+    const size_t nb2 = args.ne1 * nb1;
+
+    device       float * dst_row  = (device       float *)((device char *)dst + i1*nb1 + i2*nb2);
+    device const float * src0_row = (device const float *)((device char *)src0 +  i1*args.nb01 + i2*args.nb02);
+    device const float * src1_row = (device const float *)((device char *)src1 + i11*args.nb11);
+
+    for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+        dst_row[i0] = src0_row[i0] + src1_row[i0];
+    }
+}
+
 template<typename T>
 kernel void kernel_repeat(
         constant ggml_metal_kargs_repeat & args,
@@ -1431,6 +1503,32 @@ kernel void kernel_swiglu(
     }
 }
 
+kernel void kernel_swiglu_oai(
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        constant ggml_metal_kargs_glu & args,
+        uint tgpig[[threadgroup_position_in_grid]],
+        uint tpitg[[thread_position_in_threadgroup]],
+        uint   ntg[[threads_per_threadgroup]]) {
+    device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
+    device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
+    device       float * dst_row  = (device       float *) ((device       char *) dst  + tgpig*args.nb1);
+
+    for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
+        float x0 = src0_row[i0];
+        float x1 = src1_row[i0];
+
+        x0 = min(x0, args.limit);
+        x1 = max(min(x1, args.limit), -args.limit);
+
+        float out_glu = x0 / (1.0f + exp(-x0 * args.alpha));
+        out_glu = out_glu * (1.0f + x1);
+
+        dst_row[i0] = out_glu;
+    }
+}
+
 kernel void kernel_geglu_erf(
         device const char * src0,
         device const char * src1,
@@ -1534,6 +1632,7 @@ template<typename T>
 kernel void kernel_soft_max(
         device const  char * src0,
         device const  char * src1,
+        device const  char * src2,
         device        char * dst,
         constant ggml_metal_kargs_soft_max & args,
         threadgroup  float * buf [[threadgroup(0)]],
@@ -1552,6 +1651,7 @@ kernel void kernel_soft_max(
 
     device const float * psrc0 =                (device const float *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
     device const     T * pmask = src1 != src0 ? (device const T *    ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
+    device const float * psrc2 = src2 != src0 ? (device const float *) (src2)                                                 : nullptr;
     device       float * pdst  =                (device       float *) (dst  + i01*args.nb1  + i02*args.nb2  + i03*args.nb3);
 
     float slope = 1.0f;
@@ -1567,7 +1667,7 @@ kernel void kernel_soft_max(
     }
 
     // parallel max
-    float lmax = -INFINITY;
+    float lmax = psrc2 ? psrc2[i02] : -INFINITY;
 
     for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
         lmax = MAX(lmax, psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f));
@@ -1623,6 +1723,10 @@ kernel void kernel_soft_max(
         sum = simd_sum(sum);
     }
 
+    if (psrc2) {
+        sum += exp(psrc2[i02] - max_val);
+    }
+
     const float inv_sum = 1.0f/sum;
 
     for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
@@ -1634,6 +1738,7 @@ template<typename T>
 kernel void kernel_soft_max_4(
         device const  char * src0,
         device const  char * src1,
+        device const  char * src2,
         device        char * dst,
         constant ggml_metal_kargs_soft_max & args,
         threadgroup  float * buf [[threadgroup(0)]],
@@ -1652,6 +1757,7 @@ kernel void kernel_soft_max_4(
 
     device const float4 * psrc4 =                (device const float4 *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
     device const      T * pmask = src1 != src0 ? (device const T *     ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
+    device const float *  psrc2 = src2 != src0 ? (device const float * ) (src2)                                                 : nullptr;
     device       float4 * pdst4 =                (device       float4 *) (dst  + i01*args.nb1  + i02*args.nb2  + i03*args.nb3);
 
     float slope = 1.0f;
@@ -1666,7 +1772,7 @@ kernel void kernel_soft_max_4(
     }
 
     // parallel max
-    float4 lmax4 = -INFINITY;
+    float4 lmax4 = psrc2 ? psrc2[i02] : -INFINITY;
 
     for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
         lmax4 = fmax(lmax4, psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
@@ -1725,6 +1831,10 @@ kernel void kernel_soft_max_4(
         sum = simd_sum(sum);
     }
 
+    if (psrc2) {
+        sum += exp(psrc2[i02] - max_val);
+    }
+
     const float inv_sum = 1.0f/sum;
 
     for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
@@ -3106,6 +3216,11 @@ template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_3")]]   kernel mul_mv_ext_q4
 template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_4")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q8_0,   32, dequantize_q8_0_t4>;
 template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_5")]]   kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q8_0,   32, dequantize_q8_0_t4>;
 
+template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_2")]]  kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_mxfp4,  32, dequantize_mxfp4_t4>;
+template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_3")]]  kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_mxfp4,  32, dequantize_mxfp4_t4>;
+template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_4")]]  kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_mxfp4,  32, dequantize_mxfp4_t4>;
+template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_5")]]  kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_mxfp4,  32, dequantize_mxfp4_t4>;
+
 template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
 template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
 template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
@@ -4092,6 +4207,7 @@ kernel void kernel_flash_attn_ext(
         device const char * k,
         device const char * v,
         device const char * mask,
+        device const char * sinks,
         device       char * dst,
         threadgroup  half * shmem_f16 [[threadgroup(0)]],
         uint3   tgpig[[threadgroup_position_in_grid]],
@@ -4407,6 +4523,35 @@ kernel void kernel_flash_attn_ext(
             }
         }
 
+        if (sinks != q && sgitg == 0) {
+            for (ushort j = 0; j < Q; ++j) {
+                const float m = M[j];
+                const float s = tiisg == 0 ? ((device const float *) sinks)[iq2] : -FLT_MAX/2;
+
+                M[j] = simd_max(max(M[j], s));
+
+                const float ms = exp(m - M[j]);
+                const float vs = exp(s - M[j]);
+
+                S[j] = S[j]*ms + simd_sum(vs);
+
+                if (tiisg == j) {
+                    ss[j*TS + 2*C + j] = ms;
+                }
+            }
+
+            // O = diag(ms)*O
+            {
+                s8x8_t ms;
+                simdgroup_load(ms, ss + 2*C, TS, 0, false);
+
+                #pragma unroll(DV8)
+                for (short i = 0; i < DV8; ++i) {
+                    simdgroup_multiply(lo[i], ms, lo[i]);
+                }
+            }
+        }
+
         // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
         for (short j = tiisg; j < Q; j += NW) {
             ss[j*TS + 0] = S[j];
@@ -4618,6 +4763,7 @@ kernel void kernel_flash_attn_ext_vec(
         device const char * k,
         device const char * v,
         device const char * mask,
+        device const char * sinks,
         device       char * dst,
         threadgroup  half * shmem_f16 [[threadgroup(0)]],
         uint3   tgpig[[threadgroup_position_in_grid]],
@@ -4835,6 +4981,23 @@ kernel void kernel_flash_attn_ext_vec(
             }
         }
 
+        if (sinks != q && sgitg == 0) {
+            const float m = M;
+            const float s = tiisg == 0 ? ((device const float *) sinks)[iq2] : -FLT_MAX/2;
+
+            M = simd_max(max(M, s));
+
+            const float ms = exp(m - M);
+            const float vs = exp(s - M);
+
+            S = S*ms + simd_sum(vs);
+
+#pragma unroll(DV4/NL)
+            for (short ii = 0; ii < DV4; ii += NL) {
+                lo[ii/NL] *= ms;
+            }
+        }
+
         // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
         if (tiisg == 0) {
             ss[0] = (s_t) S;
@@ -6940,6 +7103,95 @@ kernel void kernel_mul_mv_iq4_xs_f32(
     kernel_mul_mv_iq4_xs_f32_impl<N_R0_IQ4_XS, N_SG_IQ4_XS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
+template<int nr0, int nsg, int nw, typename args_t>
+void kernel_mul_mv_mxfp4_f32_impl(
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiisg,
+        ushort sgitg) {
+
+    threadgroup float * shmem_f32 = (threadgroup float *) shmem;
+    const int nb = args.ne00/QK_MXFP4;
+
+    const int r0 = tgpig.x;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+
+    const int first_row = (r0 * nsg + sgitg) * nr0;
+
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
+
+    const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 =        r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
+
+    device const block_mxfp4 * x = (device const block_mxfp4 *) (src0 + offset0);
+    device const float       * y = (device const float       *) (src1 + offset1);
+
+    const short ix = tiisg/2;  // 0...15
+    const short it = tiisg%2;  // 0 or 1
+
+    shmem_f32[tiisg] = kvalues_mxfp4_f[tiisg%16];
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    float4 yl[4];
+    float sumf[nr0]={0.f};
+
+    device const float * yb = y + ix * QK_MXFP4 + it * 8;
+
+    for (int ib = ix; ib < nb; ib += 16) {
+        device const float4 * y4 = (device const float4 *)yb;
+        yl[0] = y4[0];
+        yl[1] = y4[4];
+        yl[2] = y4[1];
+        yl[3] = y4[5];
+
+#pragma unroll(nr0)
+        for (short row = 0; row < nr0; row++) {
+            device const block_mxfp4 & xb = x[row*nb + ib];
+            device const uint8_t     * q2 = (device const uint8_t *)(xb.qs + 8*it);
+
+            float4 acc1 = yl[0]*float4(shmem_f32[q2[0] &  0x0F], shmem_f32[q2[1] &  0x0F], shmem_f32[q2[2] &  0x0F], shmem_f32[q2[3] &  0x0F]);
+            float4 acc2 = yl[1]*float4(shmem_f32[q2[0] >> 4   ], shmem_f32[q2[1] >> 4   ], shmem_f32[q2[2] >> 4   ], shmem_f32[q2[3] >> 4   ]);
+            float4 acc3 = yl[2]*float4(shmem_f32[q2[4] &  0x0F], shmem_f32[q2[5] &  0x0F], shmem_f32[q2[6] &  0x0F], shmem_f32[q2[7] &  0x0F]);
+            float4 acc4 = yl[3]*float4(shmem_f32[q2[4] >> 4   ], shmem_f32[q2[5] >> 4   ], shmem_f32[q2[6] >> 4   ], shmem_f32[q2[7] >> 4   ]);
+
+            acc1 = (acc1 + acc3) + (acc2 + acc4);
+
+            sumf[row] += e8m0_to_fp32(xb.e) * ((acc1[0] + acc1[1]) + (acc1[2] + acc1[3]));
+        }
+
+        yb += 16 * QK_MXFP4;
+    }
+
+    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
+    for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+        float sum_all = simd_sum(sumf[row]);
+        if (tiisg == 0) {
+            dst_f32[first_row + row] = sum_all;
+        }
+    }
+}
+
+[[host_name("kernel_mul_mv_mxfp4_f32")]]
+kernel void kernel_mul_mv_mxfp4_f32(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem [[threadgroup(0)]],
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    kernel_mul_mv_mxfp4_f32_impl<N_R0_MXFP4, N_SG_MXFP4, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+}
+
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
 kernel void kernel_get_rows_q(
         constant ggml_metal_kargs_get_rows & args,
@@ -7475,6 +7727,7 @@ template [[host_name("kernel_get_rows_q4_1")]]    kernel get_rows_q_t kernel_get
 template [[host_name("kernel_get_rows_q5_0")]]    kernel get_rows_q_t kernel_get_rows_q<block_q5_0,    2, dequantize_q5_0>;
 template [[host_name("kernel_get_rows_q5_1")]]    kernel get_rows_q_t kernel_get_rows_q<block_q5_1,    2, dequantize_q5_1>;
 template [[host_name("kernel_get_rows_q8_0")]]    kernel get_rows_q_t kernel_get_rows_q<block_q8_0,    2, dequantize_q8_0>;
+template [[host_name("kernel_get_rows_mxfp4")]]   kernel get_rows_q_t kernel_get_rows_q<block_mxfp4,   2, dequantize_mxfp4>;
 template [[host_name("kernel_get_rows_q2_K")]]    kernel get_rows_q_t kernel_get_rows_q<block_q2_K,    QK_NL, dequantize_q2_K>;
 template [[host_name("kernel_get_rows_q3_K")]]    kernel get_rows_q_t kernel_get_rows_q<block_q3_K,    QK_NL, dequantize_q3_K>;
 template [[host_name("kernel_get_rows_q4_K")]]    kernel get_rows_q_t kernel_get_rows_q<block_q4_K,    QK_NL, dequantize_q4_K>;
@@ -7527,6 +7780,7 @@ template [[host_name("kernel_mul_mm_q4_1_f32")]]    kernel mul_mm_t kernel_mul_m
 template [[host_name("kernel_mul_mm_q5_0_f32")]]    kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_q5_0,    2,     dequantize_q5_0>;
 template [[host_name("kernel_mul_mm_q5_1_f32")]]    kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_q5_1,    2,     dequantize_q5_1>;
 template [[host_name("kernel_mul_mm_q8_0_f32")]]    kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_q8_0,    2,     dequantize_q8_0>;
+template [[host_name("kernel_mul_mm_mxfp4_f32")]]   kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_mxfp4,   2,     dequantize_mxfp4>;
 template [[host_name("kernel_mul_mm_q2_K_f32")]]    kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_q2_K,    QK_NL, dequantize_q2_K>;
 template [[host_name("kernel_mul_mm_q3_K_f32")]]    kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_q3_K,    QK_NL, dequantize_q3_K>;
 template [[host_name("kernel_mul_mm_q4_K_f32")]]    kernel mul_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_q4_K,    QK_NL, dequantize_q4_K>;
@@ -7558,6 +7812,7 @@ template [[host_name("kernel_mul_mm_id_q4_1_f16")]]    kernel mul_mm_id kernel_m
 template [[host_name("kernel_mul_mm_id_q5_0_f16")]]    kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   block_q5_0,    2,     dequantize_q5_0>;
 template [[host_name("kernel_mul_mm_id_q5_1_f16")]]    kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   block_q5_1,    2,     dequantize_q5_1>;
 template [[host_name("kernel_mul_mm_id_q8_0_f16")]]    kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   block_q8_0,    2,     dequantize_q8_0>;
+template [[host_name("kernel_mul_mm_id_mxfp4_f16")]]   kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   block_mxfp4,   2,     dequantize_mxfp4>;
 template [[host_name("kernel_mul_mm_id_q2_K_f16")]]    kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   block_q2_K,    QK_NL, dequantize_q2_K>;
 template [[host_name("kernel_mul_mm_id_q3_K_f16")]]    kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   block_q3_K,    QK_NL, dequantize_q3_K>;
 template [[host_name("kernel_mul_mm_id_q4_K_f16")]]    kernel mul_mm_id kernel_mul_mm_id<half,   half4x4,   simdgroup_half8x8,   block_q4_K,    QK_NL, dequantize_q4_K>;
@@ -7703,6 +7958,8 @@ template [[host_name("kernel_mul_mv_id_q4_1_f32")]]    kernel kernel_mul_mv_id_t
 template [[host_name("kernel_mul_mv_id_q5_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0, N_SG_Q5_0, N_SIMDWIDTH>>>;
 template [[host_name("kernel_mul_mv_id_q5_1_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1, N_SG_Q5_1, N_SIMDWIDTH>>>;
 
+template [[host_name("kernel_mul_mv_id_mxfp4_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_mxfp4_f32_impl<N_R0_MXFP4, N_SG_MXFP4, N_SIMDWIDTH>>>;
+
 template [[host_name("kernel_mul_mv_id_q2_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl   <N_R0_Q2_K,    N_SG_Q2_K,    N_SIMDWIDTH>>>;
 template [[host_name("kernel_mul_mv_id_q3_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl   <N_R0_Q3_K,    N_SG_Q3_K,    N_SIMDWIDTH>>>;
 template [[host_name("kernel_mul_mv_id_q4_K_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl   <N_R0_Q4_K,    N_SG_Q4_K,    N_SIMDWIDTH>>>;
index c9316eb7fd39bc3a8b06badde012f648cbea1ad6..bb8b310b983e270135b5efd0061bc18442c0def3 100644 (file)
@@ -2497,6 +2497,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
         case GGML_OP_CLAMP:
             return op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_SOFT_MAX:
+            // TODO: support attention sinks [TAG_ATTN_SINKS]
+            return op->src[2] == nullptr;
         case GGML_OP_NORM:
         case GGML_OP_RMS_NORM:
             return true;
index 9a7d1b22d7983c82b3918a2442ca6a59621c8508..a57d2a16d6c540e8baca40138525d4813246a57b 100644 (file)
 
 #define UNUSED GGML_UNUSED
 
+static inline int best_index_int8(int n, const int8_t * val, float x) {
+    if (x <= val[0]) return 0;
+    if (x >= val[n-1]) return n-1;
+    int ml = 0, mu = n-1;
+    while (mu-ml > 1) {
+        int mav = (ml+mu)/2;
+        if (x < val[mav]) mu = mav; else ml = mav;
+    }
+    return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
+}
+
 // reference implementation for deterministic creation of model files
 void quantize_row_q4_0_ref(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k) {
     static const int qk = QK4_0;
@@ -246,6 +257,53 @@ void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 * GGML_REST
     }
 }
 
+static inline int best_index_mxfp4(float x, float e) {
+    int best_index = 0;
+    float best_err = fabsf(kvalues_mxfp4[0]*e - x);
+    for (int i = 1; i < 16; i++) {
+        float err = fabsf(kvalues_mxfp4[i]*e - x);
+        if (err < best_err) {
+            best_index = i;
+            best_err = err;
+        }
+    }
+    return best_index;
+}
+
+void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k) {
+    static const int qk = QK_MXFP4;
+
+    assert(k % qk == 0);
+
+    const int nb = k / qk;
+
+    for (int i = 0; i < nb; i++) {
+        float amax = 0.0f; // absolute max
+
+        for (int j = 0; j < qk; j++) {
+            const float v = x[i*qk + j];
+
+            if (amax < fabsf(v)) {
+                amax = fabsf(v);
+            }
+        }
+
+        const uint8_t e = (uint8_t) (floorf(log2f(amax)) - 2 + 127);
+
+        const float d = GGML_E8M0_TO_FP32_HALF(e);
+
+        y[i].e = e;
+
+        for (int j = 0; j < qk/2; ++j) {
+            const uint8_t x0 = best_index_mxfp4(x[i*qk + 0    + j], d);
+            const uint8_t x1 = best_index_mxfp4(x[i*qk + qk/2 + j], d);
+
+            y[i].qs[j]  = x0;
+            y[i].qs[j] |= x1 << 4;
+        }
+    }
+}
+
 void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
     static const int qk = QK4_0;
 
@@ -356,6 +414,26 @@ void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GGML_RESTRI
     }
 }
 
+void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
+    static const int qk = QK_MXFP4;
+
+    assert(k % qk == 0);
+
+    const int nb = k / qk;
+
+    for (int i = 0; i < nb; i++) {
+        const float d = GGML_E8M0_TO_FP32_HALF(x[i].e);
+
+        for (int j = 0; j < qk/2; ++j) {
+            const int8_t x0 = kvalues_mxfp4[x[i].qs[j] & 0x0F];
+            const int8_t x1 = kvalues_mxfp4[x[i].qs[j] >>   4];
+
+            y[i*qk + j + 0   ] = x0*d;
+            y[i*qk + j + qk/2] = x1*d;
+        }
+    }
+}
+
 //
 // 2-6 bit quantization in super-blocks
 //
@@ -2014,6 +2092,12 @@ size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst,
     return nrow * row_size;
 }
 
+size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+    GGML_UNUSED(quant_weights);
+    quantize_row_mxfp4_ref(src, dst, (int64_t)nrow*n_per_row);
+    return nrow * ggml_row_size(GGML_TYPE_MXFP4, n_per_row);
+}
+
 // ====================== Ternary (de)-quantization (BitNet b1.58 and TriLMs)
 
 void quantize_row_tq1_0_ref(const float * GGML_RESTRICT x, block_tq1_0 * GGML_RESTRICT y, int64_t k) {
@@ -4551,17 +4635,6 @@ size_t quantize_iq1_m(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst,
 
 // ============================ 4-bit non-linear quants
 
-static inline int best_index_int8(int n, const int8_t * val, float x) {
-    if (x <= val[0]) return 0;
-    if (x >= val[n-1]) return n-1;
-    int ml = 0, mu = n-1;
-    while (mu-ml > 1) {
-        int mav = (ml+mu)/2;
-        if (x < val[mav]) mu = mav; else ml = mav;
-    }
-    return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
-}
-
 static void quantize_row_iq4_nl_impl(const int super_block_size, const int block_size, const float * GGML_RESTRICT x,
         ggml_fp16_t * dh, uint8_t * q4, uint16_t * scales_h, uint8_t * scales_l,
         float * scales, float * weight, uint8_t * L,
@@ -4961,6 +5034,15 @@ static bool validate_fp16(ggml_fp16_t f, size_t i) {
     return true;
 }
 
+static bool validate_e_e8m0(uint8_t e, size_t i) {
+    if (e == 0xff) {
+        fprintf(stderr, "ggml_validate_row_data: found invalid e value %d at block %zu\n", e, i);
+        return false;
+    }
+
+    return true;
+}
+
 #define VALIDATE_ROW_DATA_D_F16_IMPL(type, data, nb) \
     const type * q = (const type *) (data); \
     for (size_t i = 0; i < (nb); ++i) { \
@@ -4977,6 +5059,14 @@ static bool validate_fp16(ggml_fp16_t f, size_t i) {
         } \
     }
 
+#define VALIDATE_ROW_DATA_E_E8M0_IMPL(type, data, nb) \
+    const type * q = (const type *) (data); \
+    for (size_t i = 0; i < (nb); ++i) { \
+        if (!validate_e_e8m0(q[i].e, i)) { \
+            return false; \
+        } \
+    }
+
 #define VALIDATE_ROW_DATA_DVEC_F16_IMPL(type, data, nb, nr) \
     const type * q = (const type *) (data); \
     for (size_t i = 0; i < (nb); ++i) { \
@@ -5130,6 +5220,10 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
             {
                 VALIDATE_ROW_DATA_D_F16_IMPL(block_q8_0, data, nb);
             } break;
+        case GGML_TYPE_MXFP4:
+            {
+                VALIDATE_ROW_DATA_E_E8M0_IMPL(block_mxfp4, data, nb);
+            } break;
         case GGML_TYPE_Q2_K:
             {
                 VALIDATE_ROW_DATA_DM_F16_IMPL(block_q2_K, data, nb, d, dmin);
index d09173e11161aa954f4b3258f35ab0ef2a3bb5f7..3b688f31c21459bdabe7094f067bd05998d1e37e 100644 (file)
@@ -21,6 +21,8 @@ GGML_API void quantize_row_q5_1_ref(const float * GGML_RESTRICT x, block_q5_1 *
 GGML_API void quantize_row_q8_0_ref(const float * GGML_RESTRICT x, block_q8_0 * GGML_RESTRICT y, int64_t k);
 GGML_API void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 * GGML_RESTRICT y, int64_t k);
 
+GGML_API void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k);
+
 GGML_API void quantize_row_q2_K_ref(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k);
 GGML_API void quantize_row_q3_K_ref(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t k);
 GGML_API void quantize_row_q4_K_ref(const float * GGML_RESTRICT x, block_q4_K * GGML_RESTRICT y, int64_t k);
@@ -45,6 +47,8 @@ GGML_API void dequantize_row_q5_1(const block_q5_1 * GGML_RESTRICT x, float * GG
 GGML_API void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
 //GGML_API void dequantize_row_q8_1(const block_q8_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
 
+GGML_API void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+
 GGML_API void dequantize_row_q2_K(const block_q2_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
 GGML_API void dequantize_row_q3_K(const block_q3_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
 GGML_API void dequantize_row_q4_K(const block_q4_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
@@ -90,6 +94,8 @@ GGML_API size_t quantize_q5_0(const float * GGML_RESTRICT src, void * GGML_RESTR
 GGML_API size_t quantize_q5_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
 GGML_API size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
 
+GGML_API size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+
 GGML_API void iq2xs_init_impl(enum ggml_type type);
 GGML_API void iq2xs_free_impl(enum ggml_type type);
 GGML_API void iq3xs_init_impl(int grid_size);
index f8cdb03c3d4c5b3296a85d4b9f34b8b438f1bab9..6fa27418ce83d3d41c6109bc3cac2492dbb5bfdf 100644 (file)
@@ -4193,15 +4193,9 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
             {
-                struct ggml_tensor * a;
-                struct ggml_tensor * b;
-                if (op->op == GGML_OP_MUL_MAT) {
-                    a = op->src[0];
-                    b = op->src[1];
-                } else {
-                    a = op->src[2];
-                    b = op->src[1];
-                }
+                struct ggml_tensor * a = op->src[0];
+                struct ggml_tensor * b = op->src[1];
+
                 if (a->ne[3] != b->ne[3]) {
                     return false;
                 }
@@ -4216,7 +4210,9 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
                     }
                 }
                 ggml_type src0_type = op->src[0]->type;
-                if (src0_type == GGML_TYPE_BF16) {
+                if (src0_type == GGML_TYPE_BF16 || src0_type == GGML_TYPE_MXFP4) {
+                    // TODO: support MXFP4
+                    // FIXME: keep a list of supported types to avoid breaking the backend when a new type is added
                     return false;
                 }
                 return true;
@@ -4361,6 +4357,10 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
             if (op->src[0]->ne[3] != 1) {
                 return false;
             }
+            // TODO: support attention sinks [TAG_ATTN_SINKS]
+            if (op->src[2]) {
+                return false;
+            }
             // TODO: support broadcast
             // ref: https://github.com/ggml-org/llama.cpp/pull/14435
             return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1);
index 3c1ae0849961ad8f8308f4daa4cd9f32771b8cf3..165933a72946539be9d68edbfaa212fc1e692a85 100644 (file)
@@ -449,6 +449,8 @@ struct vk_device_struct {
     vk_pipeline pipeline_div[2][2][2];
     vk_pipeline pipeline_div_norepeat[2][2][2];
 
+    vk_pipeline pipeline_add_id_f32;
+
     vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
     vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bilinear_ac_f32;
     vk_pipeline pipeline_scale_f32;
@@ -483,6 +485,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_geglu[2];
     vk_pipeline pipeline_reglu[2];
     vk_pipeline pipeline_swiglu[2];
+    vk_pipeline pipeline_swiglu_oai[2];
     vk_pipeline pipeline_geglu_erf[2];
     vk_pipeline pipeline_geglu_quick[2];
 
@@ -705,6 +708,8 @@ struct vk_op_glu_push_constants {
     uint32_t ne00;
     uint32_t ne20;
     uint32_t mode;  // 0: default, 1: swapped, 2: split
+    float alpha; // for swiglu_oai
+    float limit;
 };
 
 struct vk_op_unary_push_constants {
@@ -794,6 +799,15 @@ struct vk_op_binary_push_constants {
     float param1; float param2; int32_t param3;
 };
 
+struct vk_op_add_id_push_constants {
+    uint32_t ne0;
+    uint32_t ne1;
+    uint32_t s01;
+    uint32_t s02;
+    uint32_t s11;
+    uint32_t s21;
+};
+
 struct vk_op_diag_mask_push_constants {
     uint32_t ncols;
     uint32_t rows_per_channel;
@@ -835,6 +849,7 @@ struct vk_op_soft_max_push_constants {
     float m1;
     uint32_t n_head_log2;
     uint32_t nrows_x;
+    uint32_t has_sinks;
 };
 
 struct vk_op_argsort_push_constants {
@@ -1977,6 +1992,7 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
         break;
     case GGML_TYPE_IQ4_NL:
     case GGML_TYPE_IQ4_XS:
+    case GGML_TYPE_MXFP4:
         lut_size = 4*16;
         break;
     default:
@@ -2353,6 +2369,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
         CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_S],   matmul_iq3_s_f16,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
         CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_XS],  matmul_iq4_xs_f16,  mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
         CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL],  matmul_iq4_nl_f16,  mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
+        CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_MXFP4],   matmul_mxfp4_f16,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
 
         CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
 #if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
@@ -2379,6 +2396,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
         CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc,   matmul_id_iq3_s_f16,   , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
         CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc,  matmul_id_iq4_xs_f16,  , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
         CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc,  matmul_id_iq4_nl_f16,  , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f16acc,   matmul_id_mxfp4_f16,   , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
 #undef CREATE_MM
 #undef CREATE_MM2
     } else
@@ -2440,6 +2458,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
             CREATE_MM2(GGML_TYPE_IQ3_S,   pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S],   matmul_iq3_s_f32,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
             CREATE_MM2(GGML_TYPE_IQ4_XS,  pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS],  matmul_iq4_xs_f32,  mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
             CREATE_MM2(GGML_TYPE_IQ4_NL,  pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL],  matmul_iq4_nl_f32,  mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+            CREATE_MM2(GGML_TYPE_MXFP4,   pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4],   matmul_mxfp4_f32,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
         } else {
             CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
             CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
@@ -2461,6 +2480,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
             CREATE_MM(GGML_TYPE_IQ3_S,   pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc,   matmul_iq3_s_f32,   , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
             CREATE_MM(GGML_TYPE_IQ4_XS,  pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc,  matmul_iq4_xs_f32,  , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
             CREATE_MM(GGML_TYPE_IQ4_NL,  pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc,  matmul_iq4_nl_f32,  , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+            CREATE_MM(GGML_TYPE_MXFP4,   pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc,   matmul_mxfp4_f32,   , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
         }
 
         CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
@@ -2493,6 +2513,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
             CREATE_MM(GGML_TYPE_IQ3_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc,   matmul_id_iq3_s_f32,   _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
             CREATE_MM(GGML_TYPE_IQ4_XS,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc,  matmul_id_iq4_xs_f32,  _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
             CREATE_MM(GGML_TYPE_IQ4_NL,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc,  matmul_id_iq4_nl_f32,  _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+            CREATE_MM(GGML_TYPE_MXFP4,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f16acc,   matmul_id_mxfp4_f32,   _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
         } else {
             CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
             CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
@@ -2514,6 +2535,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
             CREATE_MM(GGML_TYPE_IQ3_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc,   matmul_id_iq3_s_f32,   , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
             CREATE_MM(GGML_TYPE_IQ4_XS,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc,  matmul_id_iq4_xs_f32,  , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
             CREATE_MM(GGML_TYPE_IQ4_NL,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc,  matmul_id_iq4_nl_f32,  , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+            CREATE_MM(GGML_TYPE_MXFP4,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f16acc,   matmul_id_mxfp4_f32,   , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
         }
 #undef CREATE_MM2
 #undef CREATE_MM
@@ -2581,6 +2603,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
         CREATE_MM2(GGML_TYPE_IQ3_S,   pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S],   matmul_iq3_s_f32,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
         CREATE_MM2(GGML_TYPE_IQ4_XS,  pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS],  matmul_iq4_xs_f32,  mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
         CREATE_MM2(GGML_TYPE_IQ4_NL,  pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL],  matmul_iq4_nl_f32,  mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+        CREATE_MM2(GGML_TYPE_MXFP4,   pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4],   matmul_mxfp4_f32,   mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 
 #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
         if (device->integer_dot_product) {
@@ -2618,6 +2641,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
         CREATE_MM(GGML_TYPE_IQ3_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc,   matmul_id_iq3_s_f32,   _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
         CREATE_MM(GGML_TYPE_IQ4_XS,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc,  matmul_id_iq4_xs_f32,  _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
         CREATE_MM(GGML_TYPE_IQ4_NL,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc,  matmul_id_iq4_nl_f32,  _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+        CREATE_MM(GGML_TYPE_MXFP4,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f16acc,   matmul_id_mxfp4_f32,   _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
 #undef CREATE_MM2
 #undef CREATE_MMQ
 #undef CREATE_MM
@@ -2672,6 +2696,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
         CREATE_MM(GGML_TYPE_IQ3_S,   pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc,   matmul_iq3_s_f32,   , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
         CREATE_MM(GGML_TYPE_IQ4_XS,  pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc,  matmul_iq4_xs_f32,  , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
         CREATE_MM(GGML_TYPE_IQ4_NL,  pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc,  matmul_iq4_nl_f32,  , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+        CREATE_MM(GGML_TYPE_MXFP4,   pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc,   matmul_mxfp4_f32,   , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 
 #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
         if (device->integer_dot_product) {
@@ -2709,6 +2734,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
         CREATE_MM(GGML_TYPE_IQ3_S,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc,   matmul_id_iq3_s_f32,   , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
         CREATE_MM(GGML_TYPE_IQ4_XS,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc,  matmul_id_iq4_xs_f32,  , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
         CREATE_MM(GGML_TYPE_IQ4_NL,  pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc,  matmul_id_iq4_nl_f32,  , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+        CREATE_MM(GGML_TYPE_MXFP4,   pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc,   matmul_id_mxfp4_f32,   , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
     }
     // reusing CREATE_MM from the fp32 path
     if ((device->coopmat2 || device->coopmat_support)
@@ -2767,6 +2793,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_S][i],   "mul_mat_vec_iq3_s_f32_f32_"+std::to_string(i+1),   mul_mat_vec_iq3_s_f32_f32_len,   mul_mat_vec_iq3_s_f32_f32_data,   "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_XS][i],  "mul_mat_vec_iq4_xs_f32_f32_"+std::to_string(i+1),  mul_mat_vec_iq4_xs_f32_f32_len,  mul_mat_vec_iq4_xs_f32_f32_data,  "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL][i],  "mul_mat_vec_iq4_nl_f32_f32_"+std::to_string(i+1),  mul_mat_vec_iq4_nl_f32_f32_len,  mul_mat_vec_iq4_nl_f32_f32_data,  "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_MXFP4][i],   "mul_mat_vec_mxfp4_f32_f32_"+std::to_string(i+1),   mul_mat_vec_mxfp4_f32_f32_len,   mul_mat_vec_mxfp4_f32_f32_data,   "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
 
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32_"+std::to_string(i+1),  mul_mat_vec_f32_f16_f32_len,  mul_mat_vec_f32_f16_f32_data,  "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32_"+std::to_string(i+1),  mul_mat_vec_f16_f16_f32_len,  mul_mat_vec_f16_f16_f32_data,  "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
@@ -2790,6 +2817,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_S][i],   "mul_mat_vec_iq3_s_f16_f32_"+std::to_string(i+1),   mul_mat_vec_iq3_s_f16_f32_len,   mul_mat_vec_iq3_s_f16_f32_data,   "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_XS][i],  "mul_mat_vec_iq4_xs_f16_f32_"+std::to_string(i+1),  mul_mat_vec_iq4_xs_f16_f32_len,  mul_mat_vec_iq4_xs_f16_f32_data,  "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL][i],  "mul_mat_vec_iq4_nl_f16_f32_"+std::to_string(i+1),  mul_mat_vec_iq4_nl_f16_f32_len,  mul_mat_vec_iq4_nl_f16_f32_data,  "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_MXFP4][i],   "mul_mat_vec_mxfp4_f16_f32_"+std::to_string(i+1),   mul_mat_vec_mxfp4_f16_f32_len,   mul_mat_vec_mxfp4_f16_f32_data,   "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
     }
 
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32",  mul_mat_vec_id_f32_f32_len,  mul_mat_vec_id_f32_f32_data,  "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
@@ -2814,6 +2842,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_S],   "mul_mat_vec_id_iq3_s_f32",   mul_mat_vec_id_iq3_s_f32_len,   mul_mat_vec_id_iq3_s_f32_data,   "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_XS],  "mul_mat_vec_id_iq4_xs_f32",  mul_mat_vec_id_iq4_xs_f32_len,  mul_mat_vec_id_iq4_xs_f32_data,  "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL],  "mul_mat_vec_id_iq4_nl_f32",  mul_mat_vec_id_iq4_nl_f32_len,  mul_mat_vec_id_iq4_nl_f32_data,  "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_MXFP4],   "mul_mat_vec_id_mxfp4_f32",   mul_mat_vec_id_mxfp4_f32_len,   mul_mat_vec_id_mxfp4_f32_data,   "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
 
     // dequant shaders
     ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16",   dequant_f32_len,  dequant_f32_data,  "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
@@ -2836,6 +2865,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ3_S],   "dequant_iq3_s",   dequant_iq3_s_len,   dequant_iq3_s_data,   "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_XS],  "dequant_iq4_xs",  dequant_iq4_xs_len,  dequant_iq4_xs_data,  "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_NL],  "dequant_iq4_nl",  dequant_iq4_nl_len,  dequant_iq4_nl_data,  "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_MXFP4],   "dequant_mxfp4",   dequant_mxfp4_len,   dequant_mxfp4_data,   "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
 
     // get_rows
     ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32",  get_rows_f32_len,  get_rows_f32_data,  "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
@@ -2855,6 +2885,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ3_S],   "get_rows_iq3_s",   get_rows_iq3_s_len,   get_rows_iq3_s_data,   "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_XS],  "get_rows_iq4_xs",  get_rows_iq4_xs_len,  get_rows_iq4_xs_data,  "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_NL],  "get_rows_iq4_nl",  get_rows_iq4_nl_len,  get_rows_iq4_nl_data,  "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_MXFP4],   "get_rows_mxfp4",   get_rows_mxfp4_len,   get_rows_mxfp4_data,   "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32",  get_rows_f32_f32_len,  get_rows_f32_f32_data,  "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32",  get_rows_f16_f32_len,  get_rows_f16_f32_data,  "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
@@ -2873,6 +2904,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ3_S],   "get_rows_iq3_s_f32",   get_rows_iq3_s_f32_len,   get_rows_iq3_s_f32_data,   "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_XS],  "get_rows_iq4_xs_f32",  get_rows_iq4_xs_f32_len,  get_rows_iq4_xs_f32_data,  "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL],  "get_rows_iq4_nl_f32",  get_rows_iq4_nl_f32_len,  get_rows_iq4_nl_f32_data,  "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_MXFP4],   "get_rows_mxfp4_f32",   get_rows_mxfp4_f32_len,   get_rows_mxfp4_f32_data,   "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 4 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
@@ -2976,6 +3008,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_BINARY(div, _norepeat, {1})
 #undef CREATE_BINARY
 
+    ggml_vk_create_pipeline(device, device->pipeline_add_id_f32, "add_id_f32", add_id_f32_len, add_id_f32_data, "main", 4, sizeof(vk_op_add_id_push_constants), {1, 1, 1}, {}, 1);
+
     ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
@@ -3026,6 +3060,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_GLU(geglu)
     CREATE_GLU(reglu)
     CREATE_GLU(swiglu)
+    CREATE_GLU(swiglu_oai)
     CREATE_GLU(geglu_erf)
     CREATE_GLU(geglu_quick)
 #undef CREATE_GLU
@@ -3035,10 +3070,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true);
 
-    ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
     ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
@@ -4244,6 +4279,7 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type
         case GGML_TYPE_IQ3_S:
         case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_MXFP4:
             break;
         default:
             return nullptr;
@@ -4314,6 +4350,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
         case GGML_TYPE_IQ3_S:
         case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_MXFP4:
             break;
         default:
             return nullptr;
@@ -4357,6 +4394,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
         case GGML_TYPE_IQ3_S:
         case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_MXFP4:
             break;
         default:
             return nullptr;
@@ -4411,6 +4449,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
         case GGML_TYPE_IQ3_S:
         case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_MXFP4:
             break;
         default:
             return nullptr;
@@ -4446,6 +4485,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context
         case GGML_TYPE_IQ3_S:
         case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_MXFP4:
             break;
         default:
             return nullptr;
@@ -4631,6 +4671,7 @@ static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context&
     std::cerr << "}, (" << wg0 << "," << wg1 << "," << wg2 << "))");
     GGML_ASSERT(ctx->descriptor_set_idx < ctx->descriptor_sets.size());
     GGML_ASSERT(descriptor_buffer_infos.size() <= MAX_PARAMETER_COUNT);
+    GGML_ASSERT(pipeline->parameter_count == descriptor_buffer_infos.size());
 
     vk::DescriptorSet& descriptor_set = ctx->descriptor_sets[ctx->descriptor_set_idx++];
     vk::WriteDescriptorSet write_descriptor_set{ descriptor_set, 0, 0, pipeline->parameter_count, vk::DescriptorType::eStorageBuffer, nullptr, descriptor_buffer_infos.begin() };
@@ -6847,6 +6888,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             break;
         }
         return nullptr;
+    case GGML_OP_ADD_ID:
+        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && src2->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_add_id_f32;
+        }
+        return nullptr;
     case GGML_OP_CONCAT:
         if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_concat_f32;
@@ -6992,6 +7038,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
                 return ctx->device->pipeline_reglu[dst->type == GGML_TYPE_F16];
             case GGML_GLU_OP_SWIGLU:
                 return ctx->device->pipeline_swiglu[dst->type == GGML_TYPE_F16];
+            case GGML_GLU_OP_SWIGLU_OAI:
+                return ctx->device->pipeline_swiglu_oai[dst->type == GGML_TYPE_F16];
             case GGML_GLU_OP_GEGLU_ERF:
                 return ctx->device->pipeline_geglu_erf[dst->type == GGML_TYPE_F16];
             case GGML_GLU_OP_GEGLU_QUICK:
@@ -7007,6 +7055,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         return nullptr;
     case GGML_OP_SOFT_MAX:
         GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
+        GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32);
 
         if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
             return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_wg512 : ctx->device->pipeline_soft_max_f32;
@@ -7177,6 +7226,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
     case GGML_OP_SUB:
     case GGML_OP_MUL:
     case GGML_OP_DIV:
+    case GGML_OP_ADD_ID:
     case GGML_OP_CONCAT:
     case GGML_OP_UPSCALE:
     case GGML_OP_SQR:
@@ -7523,6 +7573,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
                 elements = { ne, 1, 1 };
             }
         } break;
+    case GGML_OP_ADD_ID:
+        {
+            elements = { (uint32_t)ne01, (uint32_t)ne02, 1 };
+        } break;
     case GGML_OP_SET_ROWS:
         {
             uint32_t ne = ggml_nelements(src0);
@@ -7562,8 +7616,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
         }
     }
 
-    if (op == GGML_OP_SOFT_MAX || op == GGML_OP_GLU) {
-        // Empty src1 is possible in soft_max, but the shader needs a buffer
+    if (op == GGML_OP_GLU) {
+        // Empty src1 is possible in glu, but the shader needs a buffer
         vk_subbuffer subbuf_y;
         if (use_src1) {
             subbuf_y = { d_Y, y_buf_offset, y_sz };
@@ -7573,6 +7627,24 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
 
         ggml_vk_sync_buffers(subctx);
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
+    } else if (op == GGML_OP_SOFT_MAX) {
+        // Empty src1 and src2 is possible in soft_max, but the shader needs a buffer
+        vk_subbuffer subbuf_y;
+        if (use_src1) {
+            subbuf_y = { d_Y, y_buf_offset, y_sz };
+        } else {
+            subbuf_y = { d_X, 0, x_sz };
+        }
+
+        vk_subbuffer subbuf_z;
+        if (use_src2) {
+            subbuf_z = { d_Z, z_buf_offset, z_sz };
+        } else {
+            subbuf_z = { d_X, 0, x_sz };
+        }
+
+        ggml_vk_sync_buffers(subctx);
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
     } else if (op == GGML_OP_ROPE || op == GGML_OP_ROPE_BACK) {
         // Empty src2 is possible in rope, but the shader needs a buffer
         vk_subbuffer subbuf_z;
@@ -7701,6 +7773,21 @@ static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const
     }, dryrun);
 }
 
+static void ggml_vk_add_id(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
+    const uint32_t src0_type_size = ggml_type_size(src0->type);
+    const uint32_t src1_type_size = ggml_type_size(src1->type);
+    const uint32_t src2_type_size = ggml_type_size(src2->type);
+
+    ggml_vk_op_f32<vk_op_add_id_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_ADD_ID, {
+        (uint32_t)dst->ne[0],
+        (uint32_t)dst->ne[1],
+        (uint32_t)src0->nb[1] / src0_type_size,
+        (uint32_t)src0->nb[2] / src0_type_size,
+        (uint32_t)src1->nb[1] / src1_type_size,
+        (uint32_t)src2->nb[1] / src2_type_size,
+    }, dryrun);
+}
+
 static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, int version, bool dryrun = false) {
     GGML_ASSERT(version == 6 || version == 7);
     int num_srcs = version == 6 ? 6 : 7;
@@ -8119,8 +8206,12 @@ static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, con
 }
 
 static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+    const float * op_params_f = (const float *)dst->op_params;
+
     const bool swapped = (bool)dst->op_params[1];
     const bool split = src1 != nullptr;
+    const float alpha = op_params_f[2];
+    const float limit = op_params_f[3];
 
     GGML_ASSERT(ggml_is_contiguous(src0));
 
@@ -8134,7 +8225,15 @@ static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const
 
     const uint32_t mode = split ? 2 : (swapped ? 1 : 0);
 
-    ggml_vk_op_f32<vk_op_glu_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GLU, { (uint32_t)ggml_nelements(dst), (uint32_t)src0->ne[0], (uint32_t)dst->ne[0], mode }, dryrun);
+    ggml_vk_op_f32<vk_op_glu_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GLU,
+        {
+            (uint32_t)ggml_nelements(dst),
+            (uint32_t)src0->ne[0],
+            (uint32_t)dst->ne[0],
+            mode,
+            alpha,
+            limit
+        }, dryrun);
 }
 
 static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
@@ -8142,7 +8241,7 @@ static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& sub
     ggml_vk_op_f32<vk_op_diag_mask_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] }, dryrun);
 }
 
-static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
     float * op_params = (float *)dst->op_params;
 
     float scale = op_params[0];
@@ -8164,7 +8263,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
-    ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX, {
+    ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_SOFT_MAX, {
         ncols,
         src1 != nullptr ? nrows_y : (uint32_t)0,
         (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],
@@ -8174,6 +8273,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
         m0, m1,
         n_head_log2,
         nrows_x,
+        src2 != nullptr
     }, dryrun);
 }
 
@@ -9413,6 +9513,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         case GGML_GLU_OP_GEGLU:
         case GGML_GLU_OP_REGLU:
         case GGML_GLU_OP_SWIGLU:
+        case GGML_GLU_OP_SWIGLU_OAI:
         case GGML_GLU_OP_GEGLU_ERF:
         case GGML_GLU_OP_GEGLU_QUICK:
             break;
@@ -9424,6 +9525,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_REPEAT_BACK:
     case GGML_OP_GET_ROWS:
     case GGML_OP_ADD:
+    case GGML_OP_ADD_ID:
     case GGML_OP_ACC:
     case GGML_OP_SUB:
     case GGML_OP_MUL:
@@ -9578,6 +9680,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     case GGML_OP_DIV:
         ggml_vk_div(ctx, compute_ctx, src0, src1, node, dryrun);
 
+        break;
+    case GGML_OP_ADD_ID:
+        ggml_vk_add_id(ctx, compute_ctx, src0, src1, src2, node, dryrun);
+
         break;
     case GGML_OP_CONCAT:
         ggml_vk_concat(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -9675,6 +9781,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
         case GGML_GLU_OP_GEGLU:
         case GGML_GLU_OP_REGLU:
         case GGML_GLU_OP_SWIGLU:
+        case GGML_GLU_OP_SWIGLU_OAI:
         case GGML_GLU_OP_GEGLU_ERF:
         case GGML_GLU_OP_GEGLU_QUICK:
             ggml_vk_glu(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -9688,7 +9795,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
 
         break;
     case GGML_OP_SOFT_MAX:
-        ggml_vk_soft_max(ctx, compute_ctx, src0, src1, node, dryrun);
+        ggml_vk_soft_max(ctx, compute_ctx, src0, src1, src2, node, dryrun);
 
         break;
     case GGML_OP_SOFT_MAX_BACK:
@@ -9834,6 +9941,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
     case GGML_OP_SUB:
     case GGML_OP_MUL:
     case GGML_OP_DIV:
+    case GGML_OP_ADD_ID:
     case GGML_OP_CONCAT:
     case GGML_OP_UPSCALE:
     case GGML_OP_SCALE:
@@ -9903,6 +10011,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
         case GGML_GLU_OP_GEGLU:
         case GGML_GLU_OP_REGLU:
         case GGML_GLU_OP_SWIGLU:
+        case GGML_GLU_OP_SWIGLU_OAI:
         case GGML_GLU_OP_GEGLU_ERF:
         case GGML_GLU_OP_GEGLU_QUICK:
             buf = tensor->buffer;
@@ -10752,6 +10861,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                 case GGML_GLU_OP_GEGLU:
                 case GGML_GLU_OP_REGLU:
                 case GGML_GLU_OP_SWIGLU:
+                case GGML_GLU_OP_SWIGLU_OAI:
                 case GGML_GLU_OP_GEGLU_ERF:
                 case GGML_GLU_OP_GEGLU_QUICK:
                     return ggml_is_contiguous(op->src[0]) &&
@@ -10797,6 +10907,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                     case GGML_TYPE_IQ3_S:
                     case GGML_TYPE_IQ4_XS:
                     case GGML_TYPE_IQ4_NL:
+                    case GGML_TYPE_MXFP4:
                         break;
                     default:
                         return false;
@@ -10834,6 +10945,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                 if (head_sizes == FA_HEAD_SIZE_UNSUPPORTED) {
                     return false;
                 }
+                // TODO: support attention sinks [TAG_ATTN_SINKS]
+                if (op->src[4]) {
+                    return false;
+                }
                 if (op->src[0]->type != GGML_TYPE_F32) {
                     return false;
                 }
@@ -10906,6 +11021,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                     case GGML_TYPE_IQ3_S:
                     case GGML_TYPE_IQ4_XS:
                     case GGML_TYPE_IQ4_NL:
+                    case GGML_TYPE_MXFP4:
                         return true;
                     default:
                         return false;
@@ -11004,6 +11120,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
                    (op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16) &&
                    (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
+        case GGML_OP_ADD_ID:
+            return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->src[2]->type == GGML_TYPE_I32 &&
+                   op->type == GGML_TYPE_F32;
         case GGML_OP_SILU_BACK:
         case GGML_OP_RMS_NORM_BACK:
         case GGML_OP_SQR:
diff --git a/src/ggml-vulkan/vulkan-shaders/add_id.comp b/src/ggml-vulkan/vulkan-shaders/add_id.comp
new file mode 100644 (file)
index 0000000..3ae8f01
--- /dev/null
@@ -0,0 +1,42 @@
+#version 450
+
+#extension GL_EXT_control_flow_attributes : require
+
+#include "types.comp"
+
+layout (push_constant) uniform parameter
+{
+    uint ne0;
+    uint ne1;
+    uint s01;
+    uint s02;
+    uint s11;
+    uint s21;
+} p;
+
+#define BLOCK_SIZE 512
+
+layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) readonly buffer Y {B_TYPE data_b[];};
+layout (binding = 2) readonly buffer Z {int32_t data_c[];};
+layout (binding = 3) writeonly buffer D {D_TYPE data_d[];};
+
+void main() {
+    const uint i1 = gl_WorkGroupID.x;
+    const uint i2 = gl_WorkGroupID.y;
+
+    const uint i11 = data_c[i1 + i2 * p.s21];
+
+    const uint s1 = p.ne0;
+    const uint s2 = p.ne0 * p.ne1;
+
+    const uint d0 = i1 * s1 + i2 * s2;
+    const uint a0 = i1 * p.s01 + i2 * p.s02;
+    const uint b0 = i11 * p.s11;
+
+    for (uint i0 = gl_LocalInvocationID.x; i0 < p.ne0; i0 += BLOCK_SIZE) {
+        data_d[d0 + i0] = data_a[a0 + i0] + data_b[b0 + i0];
+    }
+}
index dbc7daa3328f6e322f2bfbcd312d28bb7bce991a..978d430030760cfa085c786d96ed0c3410499b26 100644 (file)
@@ -4,8 +4,8 @@
 #include "generic_unary_head.comp"
 #include "dequant_funcs.comp"
 
-#if defined(DATA_A_IQ4_NL)
-// 16 invocations needed for init_iq4nl_shmem
+#if defined(DATA_A_IQ4_NL) || defined(DATA_A_MXFP4)
+// 16 invocations needed for init_iq_shmem
 layout(local_size_x = 16, local_size_y = 1, local_size_z = 1) in;
 #else
 layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
index 0d9739d40609af24b9dbc6d4b8124dae815f02b4..d3127fbd9865cd5ec37920d034a54a43b44663a7 100644 (file)
@@ -434,6 +434,18 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
 }
 #endif
 
+#if defined(DATA_A_MXFP4)
+vec2 dequantize(uint ib, uint iqs, uint a_offset) {
+    const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
+    return vec2(kvalues_mxfp4[vui & 0xF], kvalues_mxfp4[vui >> 4]);
+}
+vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
+    vec2 v0 = dequantize(ib, iqs, a_offset);
+    vec2 v1 = dequantize(ib, iqs + 1, a_offset);
+    return vec4(v0.x, v0.y, v1.x, v1.y);
+}
+#endif
+
 #if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)
 vec2 get_dm(uint ib, uint a_offset) {
     return vec2(0, 0);
@@ -455,6 +467,12 @@ vec2 get_dm(uint ib, uint a_offset) {
 }
 #endif
 
+#if defined(DATA_A_MXFP4)
+vec2 get_dm(uint ib, uint a_offset) {
+    return vec2(e8m0_to_fp32(data_a[a_offset + ib].e), 0);
+}
+#endif
+
 #if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
 vec2 get_dm(uint ib, uint a_offset) {
     return vec2(float(data_a[a_offset + ib].d), float(data_a[a_offset + ib].m));
index 9cb7da2daab5da60f63dc01580db0689fd42bbe2..706540fd8514cc71f1c4dbe7f593232c2e7134bc 100644 (file)
@@ -654,6 +654,25 @@ float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoor
 }
 #endif
 
+#if defined(DATA_A_MXFP4)
+layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufMXFP4 {
+   block_mxfp4 block;
+};
+
+float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
+{
+    const float d = e8m0_to_fp32(bl.block.e);
+    const uint idx = coordInBlock[1];
+    const uint iqs = idx & 0xF;
+    const uint shift = (idx & 0x10) >> 2;
+    uint32_t qs = bl.block.qs[iqs];
+    qs >>= shift;
+    qs &= 0xF;
+    float16_t ret = float16_t(kvalues_mxfp4[qs] * d);
+    return ret;
+}
+#endif
+
 #if defined(DATA_A_Q4_0)
 #define dequantFuncA dequantFuncQ4_0
 #elif defined(DATA_A_Q4_1)
@@ -696,4 +715,6 @@ float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoor
 #define dequantFuncA dequantFuncIQ4_XS
 #elif defined(DATA_A_IQ4_NL)
 #define dequantFuncA dequantFuncIQ4_NL
+#elif defined(DATA_A_MXFP4)
+#define dequantFuncA dequantFuncMXFP4
 #endif
diff --git a/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp b/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp
new file mode 100644 (file)
index 0000000..ee496e9
--- /dev/null
@@ -0,0 +1,32 @@
+#version 450
+
+#include "dequant_head.comp"
+
+layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {block_mxfp4 data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
+
+void main() {
+    const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
+
+    init_iq_shmem(gl_WorkGroupSize);
+
+    const uint tid = gl_LocalInvocationID.x % 64;
+    const uint il  = tid/32;
+    const uint ir  = tid%32;
+    const uint ib = 32*i + ir;
+    if (ib >= p.nel / 32) {
+        return;
+    }
+
+    const uint q_idx = 8*il;
+    const uint b_idx = 1024*i + 32*ir + q_idx;
+
+    const float d = e8m0_to_fp32(data_a[ib].e);
+
+    [[unroll]] for (uint l = 0; l < 8; ++l) {
+        data_b[b_idx + l +  0] = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] & 0xF]);
+        data_b[b_idx + l + 16] = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] >>  4]);
+    }
+}
index 004a61fc1625480e92ddbe68b76caf33f1604e53..51d70869d953cb3b2575591caa8176b5237ae188 100644 (file)
@@ -14,4 +14,6 @@ layout (push_constant) uniform parameter
     uint ne00;
     uint ne20;
     uint mode;
+    float alpha;
+    float limit;
 } p;
index f481549911b92bb02ce4727ffce565c71345bf9e..8c5114a79d23cb8529c23d0ea2ec4547b6ed5bb1 100644 (file)
@@ -747,6 +747,21 @@ void main() {
             buf_a[buf_idx + 1 ] = FLOAT_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 8, 4)]) * d;
             buf_a[buf_idx + 16] = FLOAT_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 4, 4)]) * d;
             buf_a[buf_idx + 17] = FLOAT_TYPE(kvalues_iq4nl[vui >> 12]) * d;
+#elif defined(DATA_A_MXFP4)
+            const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
+            const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a;
+
+            const uint ib = idx / 8;
+            const uint iqs = (idx & 0x07) * 2;
+
+            const float d = e8m0_to_fp32(data_a[ib].e);
+            const uint vui = uint(data_a[ib].qs[iqs]);
+            const uint vui2 = uint(data_a[ib].qs[iqs+1]);
+
+            buf_a[buf_idx     ] = FLOAT_TYPE(kvalues_mxfp4[vui & 0xF] * d);
+            buf_a[buf_idx + 16] = FLOAT_TYPE(kvalues_mxfp4[vui >>  4] * d);
+            buf_a[buf_idx +  1] = FLOAT_TYPE(kvalues_mxfp4[vui2 & 0xF] * d);
+            buf_a[buf_idx + 17] = FLOAT_TYPE(kvalues_mxfp4[vui2 >>  4] * d);
 #endif
         }
         [[unroll]] for (uint l = 0; l < BN; l += loadstride_b) {
index 63b15471bd3aaa00db343e5bd7925dc5986bcfa9..34e8db97704ee27078a0cbb2e520214ca6421e75 100644 (file)
@@ -92,6 +92,12 @@ FLOAT_TYPE get_d(uint ib) {
 }
 #endif
 
+#if defined(DATA_A_MXFP4)
+FLOAT_TYPE get_d(uint ib) {
+    return FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e));
+}
+#endif
+
 #if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
 FLOAT_TYPE_VEC2 get_dm(uint ib) {
     return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
index 5bcd3b1e3ddc67989f95b8971041ddef2647cced..5f20a1ee7d5ac742e23d028def7b672417836c29 100644 (file)
@@ -20,6 +20,7 @@ layout (push_constant) uniform parameter
     float m1;
     uint n_head_log2;
     uint nrows_x;
+    uint has_sinks;
 } p;
 
 #include "types.comp"
@@ -29,7 +30,8 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
 layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
 layout (binding = 1) readonly buffer Y {B_TYPE data_b[];};
-layout (binding = 2) buffer D {D_TYPE data_d[];};
+layout (binding = 2) readonly buffer Z {float data_c[];};
+layout (binding = 3) buffer D {D_TYPE data_d[];};
 
 shared FLOAT_TYPE vals[BLOCK_SIZE];
 
@@ -60,13 +62,13 @@ void soft_max(uint num_iters) {
         const uint h = (rowx / p.ne01) % p.ne02; // head index
 
         const float base = h < p.n_head_log2 ? p.m0 : p.m1;
-        const uint   exp  = h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1;
+        const uint   exp = h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1;
 
         slope = pow(base, exp);
     }
 
     // Find max
-    FLOAT_TYPE max_val = uintBitsToFloat(0xFF800000);
+    FLOAT_TYPE max_val = p.has_sinks == 0 ? uintBitsToFloat(0xFF800000) : data_c[i02];
 
     // Cache values while we compute the max, so we don't need to read them
     // again when we're ready to compute exp(x-max).
@@ -148,6 +150,10 @@ void soft_max(uint num_iters) {
     }
     sum = vals[0];
 
+    if (p.has_sinks != 0) {
+        sum += FLOAT_TYPE(exp(FLOAT_TYPE(data_c[i02]) - max_val));
+    }
+
     FLOAT_TYPE rcpdivisor = 1.0/sum;
 
     [[unroll]] for (uint col0 = 0, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) {
diff --git a/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp b/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp
new file mode 100644 (file)
index 0000000..970750e
--- /dev/null
@@ -0,0 +1,14 @@
+#version 450
+
+#include "glu_head.comp"
+
+float op(float a, float b) {
+    float xi = min(a, p.limit);
+    float gi = max(min(b, p.limit), -p.limit);
+
+    float out_glu = xi / (1.0f + exp(-xi * p.alpha));
+    out_glu = out_glu * (1.0f + gi);
+    return out_glu;
+}
+
+#include "glu_main.comp"
index 3bde717832b45115d4e538a5e2d44c086e1d0e2a..a36c33e26764bb13a643e3f56a8d9049ceeaf2b5 100644 (file)
@@ -1337,6 +1337,29 @@ struct block_iq4_nl_packed16
 #define A_TYPE_PACKED16 block_iq4_nl_packed16
 #endif
 
+#define QUANT_K_MXFP4 32
+#define QUANT_R_MXFP4 2
+
+struct block_mxfp4
+{
+    uint8_t e;
+    uint8_t qs[QUANT_K_MXFP4/2];
+};
+
+//struct block_mxfp4_packed16
+//{
+//    uint8_t e;
+//    uint16_t qs[QUANT_K_MXFP4/2/2];
+//};
+
+#if defined(DATA_A_MXFP4)
+#define QUANT_K QUANT_K_MXFP4
+#define QUANT_R QUANT_R_MXFP4
+#define QUANT_AUXF 1
+#define A_TYPE block_mxfp4
+//#define A_TYPE_PACKED16 block_mxfp4_packed16
+#endif
+
 #if defined(DATA_A_IQ4_NL) || defined(DATA_A_IQ4_XS)
 const int8_t kvalues_iq4nl_const[16] = {
     int8_t(-127), int8_t(-104), int8_t(-83), int8_t(-65), int8_t(-49), int8_t(-35), int8_t(-22), int8_t(-10),
@@ -1356,6 +1379,25 @@ void init_iq_shmem(uvec3 wgsize)
 }
 #endif
 
+#if defined(DATA_A_MXFP4)
+const FLOAT_TYPE kvalues_mxfp4_const[16] = {
+    FLOAT_TYPE(0.0f), FLOAT_TYPE(0.5f), FLOAT_TYPE(1.0f), FLOAT_TYPE(1.5f), FLOAT_TYPE(2.0f), FLOAT_TYPE(3.0f), FLOAT_TYPE(4.0f), FLOAT_TYPE(6.0f),
+    FLOAT_TYPE(-0.0f), FLOAT_TYPE(-0.5f), FLOAT_TYPE(-1.0f), FLOAT_TYPE(-1.5f), FLOAT_TYPE(-2.0f), FLOAT_TYPE(-3.0f), FLOAT_TYPE(-4.0f), FLOAT_TYPE(-6.0f)
+};
+
+shared FLOAT_TYPE kvalues_mxfp4[16];
+
+#define NEEDS_INIT_IQ_SHMEM
+void init_iq_shmem(uvec3 wgsize)
+{
+    // copy the table into shared memory and sync
+    for (uint i = gl_LocalInvocationIndex.x; i < kvalues_mxfp4.length(); i += wgsize.x) {
+        kvalues_mxfp4[i] = kvalues_mxfp4_const[i];
+    }
+    barrier();
+}
+#endif
+
 // returns the bfloat value in the low 16b.
 // See ggml_compute_fp32_to_bf16
 uint32_t fp32_to_bf16(float f)
@@ -1370,4 +1412,17 @@ float bf16_to_fp32(uint32_t u)
     return uintBitsToFloat(u << 16);
 }
 
+float e8m0_to_fp32(uint8_t x) {
+    uint32_t bits;
+
+    if (x == 0) {
+        bits = 0x00400000;
+    } else {
+        bits = x;
+        bits = bits << 23;
+    }
+
+    return uintBitsToFloat(bits);
+}
+
 #endif // !defined(GGML_TYPES_COMP)
index c6aa3ea4c79e03412d07875d94b984a96d0904b5..4cd94c51e3f21ae70aae78da831af2a7c014f795 100644 (file)
@@ -64,6 +64,7 @@ const std::vector<std::string> type_names = {
     "iq3_s",
     "iq4_xs",
     "iq4_nl",
+    "mxfp4",
     "bf16",
 };
 
@@ -118,7 +119,7 @@ void execute_command(const std::string& command, std::string& stdout_str, std::s
     CloseHandle(pi.hProcess);
     CloseHandle(pi.hThread);
 #else
-int stdout_pipe[2];
+    int stdout_pipe[2];
     int stderr_pipe[2];
 
     if (pipe(stdout_pipe) != 0 || pipe(stderr_pipe) != 0) {
@@ -362,7 +363,7 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
         std::string load_vec_quant = "2";
         if ((tname == "q4_0") || (tname == "q4_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s"))
             load_vec_quant = "8";
-        else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl"))
+        else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl") || (tname == "mxfp4"))
             load_vec_quant = "4";
 
         if (tname == "bf16") {
@@ -602,6 +603,8 @@ void process_shaders() {
         string_to_spv("reglu_f32" + suffix,      "reglu.comp",       {{"A_TYPE", "float"},       {"D_TYPE", "float"},       {"RTE16", rte ? "1" : "0"}});
         string_to_spv("swiglu_f16" + suffix,     "swiglu.comp",      {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"},   {"RTE16", rte ? "1" : "0"}});
         string_to_spv("swiglu_f32" + suffix,     "swiglu.comp",      {{"A_TYPE", "float"},       {"D_TYPE", "float"},       {"RTE16", rte ? "1" : "0"}});
+        string_to_spv("swiglu_oai_f16" + suffix, "swiglu_oai.comp",  {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"},   {"RTE16", rte ? "1" : "0"}});
+        string_to_spv("swiglu_oai_f32" + suffix, "swiglu_oai.comp",  {{"A_TYPE", "float"},       {"D_TYPE", "float"},       {"RTE16", rte ? "1" : "0"}});
         string_to_spv("geglu_erf_f16" + suffix,  "geglu_erf.comp",   {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"},   {"RTE16", rte ? "1" : "0"}});
         string_to_spv("geglu_erf_f32" + suffix,  "geglu_erf.comp",   {{"A_TYPE", "float"},       {"D_TYPE", "float"},       {"RTE16", rte ? "1" : "0"}});
         string_to_spv("geglu_quick_f16" + suffix,"geglu_quick.comp", {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"},   {"RTE16", rte ? "1" : "0"}});
@@ -671,6 +674,8 @@ void process_shaders() {
 
     string_to_spv("roll_f32", "roll.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
 
+    string_to_spv("add_id_f32", "add_id.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
+
     for (auto &c : compiles) {
         c.wait();
     }
index 124cf3e8b6025e9f5f0a116f25ee30815ebef0b1..55a76f8248c09a85b6f5c82ae3376f47750220d8 100644 (file)
@@ -582,9 +582,6 @@ FILE * ggml_fopen(const char * fname, const char * mode) {
 #endif
 
 }
-static void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * GGML_RESTRICT x, size_t bx, const float * GGML_RESTRICT y, size_t by, int nrc);
-static void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * GGML_RESTRICT x, size_t bx, ggml_fp16_t * GGML_RESTRICT y, size_t by, int nrc);
-static void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t * GGML_RESTRICT x, size_t bx, ggml_bf16_t * GGML_RESTRICT y, size_t by, int nrc);
 
 static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = {
     [GGML_TYPE_I8] = {
@@ -690,6 +687,14 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = {
         .is_quantized             = true,
         .from_float_ref           = (ggml_from_float_t) quantize_row_q8_1_ref,
     },
+    [GGML_TYPE_MXFP4] = {
+        .type_name                = "mxfp4",
+        .blck_size                = QK_MXFP4,
+        .type_size                = sizeof(block_mxfp4),
+        .is_quantized             = true,
+        .to_float                 = (ggml_to_float_t) dequantize_row_mxfp4,
+        .from_float_ref           = (ggml_from_float_t)quantize_row_mxfp4_ref,
+    },
     [GGML_TYPE_Q2_K] = {
         .type_name                = "q2_K",
         .blck_size                = QK_K,
@@ -917,6 +922,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
 
     "DUP",
     "ADD",
+    "ADD_ID",
     "ADD1",
     "ACC",
     "SUB",
@@ -1010,13 +1016,14 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "GLU",
 };
 
-static_assert(GGML_OP_COUNT == 86, "GGML_OP_COUNT != 86");
+static_assert(GGML_OP_COUNT == 87, "GGML_OP_COUNT != 87");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
 
     "x",
     "x+y",
+    "x[i]+y",
     "x+y",
     "view(x,nb,offset)+=y->x",
     "x-y",
@@ -1110,7 +1117,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "glu(x)",
 };
 
-static_assert(GGML_OP_COUNT == 86, "GGML_OP_COUNT != 86");
+static_assert(GGML_OP_COUNT == 87, "GGML_OP_COUNT != 87");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -1140,11 +1147,12 @@ static const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = {
     "REGLU",
     "GEGLU",
     "SWIGLU",
+    "SWIGLU_OAI",
     "GEGLU_ERF",
     "GEGLU_QUICK",
 };
 
-static_assert(GGML_GLU_OP_COUNT == 5, "GGML_GLU_OP_COUNT != 5");
+static_assert(GGML_GLU_OP_COUNT == 6, "GGML_GLU_OP_COUNT != 6");
 
 
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
@@ -1312,6 +1320,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
         case GGML_FTYPE_MOSTLY_Q5_0:          wtype = GGML_TYPE_Q5_0;  break;
         case GGML_FTYPE_MOSTLY_Q5_1:          wtype = GGML_TYPE_Q5_1;  break;
         case GGML_FTYPE_MOSTLY_Q8_0:          wtype = GGML_TYPE_Q8_0;  break;
+        case GGML_FTYPE_MOSTLY_MXFP4:         wtype = GGML_TYPE_MXFP4; break;
         case GGML_FTYPE_MOSTLY_Q2_K:          wtype = GGML_TYPE_Q2_K;  break;
         case GGML_FTYPE_MOSTLY_Q3_K:          wtype = GGML_TYPE_Q3_K;  break;
         case GGML_FTYPE_MOSTLY_Q4_K:          wtype = GGML_TYPE_Q4_K;  break;
@@ -1962,6 +1971,27 @@ struct ggml_tensor * ggml_add_cast(
     return ggml_add_cast_impl(ctx, a, b, type);
 }
 
+struct ggml_tensor * ggml_add_id(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            struct ggml_tensor  * ids) {
+
+    GGML_ASSERT(a->ne[0] == b->ne[0]);
+    GGML_ASSERT(a->ne[1] == ids->ne[0]);
+    GGML_ASSERT(a->ne[2] == ids->ne[1]);
+    GGML_ASSERT(ids->type == GGML_TYPE_I32);
+
+    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
+
+    result->op     = GGML_OP_ADD_ID;
+    result->src[0] = a;
+    result->src[1] = b;
+    result->src[2] = ids;
+
+    return result;
+}
+
 // ggml_add1
 
 static struct ggml_tensor * ggml_add1_impl(
@@ -2812,6 +2842,19 @@ struct ggml_tensor * ggml_geglu_quick_split(
     return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_GEGLU_QUICK, false);
 }
 
+struct ggml_tensor * ggml_swiglu_oai(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        float                 alpha,
+        float                 limit) {
+    struct ggml_tensor * result = ggml_glu_impl(ctx, a, b, GGML_GLU_OP_SWIGLU_OAI, false);
+    ggml_set_op_params_f32(result, 2, alpha);
+    ggml_set_op_params_f32(result, 3, limit);
+
+    return result;
+}
+
 // ggml_norm
 
 static struct ggml_tensor * ggml_norm_impl(
@@ -3779,6 +3822,22 @@ struct ggml_tensor * ggml_soft_max_ext(
     return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false);
 }
 
+void ggml_soft_max_add_sinks(
+        struct ggml_tensor * a,
+        struct ggml_tensor * sinks) {
+    if (!sinks) {
+        a->src[2] = NULL;
+        return;
+    }
+
+    GGML_ASSERT(a->op == GGML_OP_SOFT_MAX);
+    GGML_ASSERT(a->src[2] == NULL);
+    GGML_ASSERT(a->src[0]->ne[2] == sinks->ne[0]);
+    GGML_ASSERT(sinks->type == GGML_TYPE_F32);
+
+    a->src[2] = sinks;
+}
+
 // ggml_soft_max_ext_back
 
 static struct ggml_tensor * ggml_soft_max_ext_back_impl(
@@ -4812,6 +4871,22 @@ enum ggml_prec ggml_flash_attn_ext_get_prec(
     return (enum ggml_prec) prec_i32;
 }
 
+void ggml_flash_attn_ext_add_sinks(
+        struct ggml_tensor * a,
+        struct ggml_tensor * sinks) {
+    if (!sinks) {
+        a->src[4] = NULL;
+        return;
+    }
+
+    GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT);
+    GGML_ASSERT(a->src[4] == NULL);
+    GGML_ASSERT(a->src[0]->ne[2] == sinks->ne[0]);
+    GGML_ASSERT(sinks->type == GGML_TYPE_F32);
+
+    a->src[4] = sinks;
+}
+
 // ggml_flash_attn_back
 
 struct ggml_tensor * ggml_flash_attn_back(
@@ -6872,6 +6947,7 @@ size_t ggml_quantize_chunk(
         case GGML_TYPE_Q5_0:    result = quantize_q5_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_Q5_1:    result = quantize_q5_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_Q8_0:    result = quantize_q8_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+        case GGML_TYPE_MXFP4:   result = quantize_mxfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_Q2_K:    result = quantize_q2_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_Q3_K:    result = quantize_q3_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_Q4_K:    result = quantize_q4_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
index ea65f1a2ee4dd2d2c625c90b8b045b2ee0cf51d7..d29779cd12b229eddef9504b495c0e41028b4653 100644 (file)
@@ -296,6 +296,7 @@ static std::string var_to_str(ggml_scale_mode mode) {
 #define VARS_TO_STR10(a, b, c, d, e, f, g, h, i, j) VAR_TO_STR(a) + "," + VARS_TO_STR9(b, c, d, e, f, g, h, i, j)
 #define VARS_TO_STR11(a, b, c, d, e, f, g, h, i, j, k) VAR_TO_STR(a) + "," + VARS_TO_STR10(b, c, d, e, f, g, h, i, j, k)
 #define VARS_TO_STR12(a, b, c, d, e, f, g, h, i, j, k, l) VAR_TO_STR(a) + "," + VARS_TO_STR11(b, c, d, e, f, g, h, i, j, k, l)
+#define VARS_TO_STR13(a, b, c, d, e, f, g, h, i, j, k, l, m) VAR_TO_STR(a) + "," + VARS_TO_STR12(b, c, d, e, f, g, h, i, j, k, l, m)
 
 #ifdef GGML_USE_SYCL
 static bool inline _isinf(float f) {
@@ -1886,6 +1887,66 @@ struct test_glu_split : public test_case {
     }
 };
 
+struct test_swiglu_oai : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne_a;
+    int v; // view (1 : non-contiguous a)
+    float alpha;
+    float limit;
+
+    std::string vars() override {
+        return VARS_TO_STR5(type, ne_a, v, alpha, limit);
+    }
+
+    test_swiglu_oai(ggml_type type = GGML_TYPE_F32,
+                    std::array<int64_t, 4> ne_a = {128, 2, 2, 2},
+                    int v = 0,
+                    float alpha = 1.702f,
+                    float limit = 7.0f)
+        : type(type), ne_a(ne_a), v(v), alpha(alpha), limit(limit) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a;
+        ggml_tensor * b;
+        if (v & 1) {
+            auto ne = ne_a; ne[0] *= 3;
+            a = ggml_new_tensor(ctx, type, 4, ne.data());
+            ggml_set_param(a);
+            ggml_set_name(a, "a");
+
+            a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0);
+            ggml_set_name(a, "view_of_a");
+
+            b = ggml_new_tensor(ctx, type, 4, ne.data());
+            ggml_set_param(b);
+            ggml_set_name(b, "b");
+
+            b = ggml_view_4d(ctx, b, ne_a[0], ne_a[1], ne_a[2], ne_a[3], b->nb[1], b->nb[2], b->nb[3], 0);
+            ggml_set_name(a, "view_of_b");
+        } else {
+            a = ggml_new_tensor(ctx, type, 4, ne_a.data());
+            ggml_set_param(a);
+            ggml_set_name(a, "a");
+
+            b = ggml_new_tensor(ctx, type, 4, ne_a.data());
+            ggml_set_param(b);
+            ggml_set_name(b, "b");
+        }
+
+        ggml_tensor * out = ggml_swiglu_oai(ctx, a, b, alpha, limit);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            // test extended range of values to check for NaNs in GELU
+            init_tensor_uniform(t, -150.f, 150.f);
+        }
+    }
+};
+
 // GGML_OP_GET_ROWS
 struct test_get_rows : public test_case {
     const ggml_type type;
@@ -2483,6 +2544,68 @@ struct test_bin_bcast : public test_case {
     }
 };
 
+// GGML_OP_ADD_ID
+struct test_add_id : public test_case {
+    const ggml_type type_a;
+    const ggml_type type_b;
+    const int64_t n_embd;
+    const int64_t n_experts;
+    const int64_t n_experts_used;
+    const int64_t n_token;
+
+    std::string vars() override {
+        return VARS_TO_STR6(type_a, type_b, n_embd, n_experts, n_experts_used, n_token);
+    }
+
+    size_t op_size(ggml_tensor * t) override {
+        return ggml_nbytes(t) + ggml_nbytes(t->src[0]) + ggml_nbytes(t->src[2]);
+    }
+
+    test_add_id(ggml_type type_a = GGML_TYPE_F32,
+            ggml_type type_b = GGML_TYPE_F32,
+            int64_t n_embd = 128,
+            int64_t n_experts = 16,
+            int64_t n_experts_used = 8,
+            int64_t n_token = 10)
+        : type_a(type_a), type_b(type_b), n_embd(n_embd),
+          n_experts(n_experts), n_experts_used(n_experts_used), n_token(n_token) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor_3d(ctx, type_a, n_embd, n_experts_used, n_token);
+        ggml_tensor * b = ggml_new_tensor_2d(ctx, type_b, n_embd, n_experts);
+        ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_experts, n_token);
+        if (n_experts_used != n_experts) {
+            ids = ggml_view_2d(ctx, ids, n_experts_used, n_token, ids->nb[1], 0);
+            ggml_set_name(ids, "view_of_ids");
+        }
+
+        ggml_tensor * out = ggml_add_id(ctx, a, b, ids);
+        ggml_set_name(out, "out");
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            if (t->type == GGML_TYPE_I32) {
+                if (ggml_is_view_op(t->op)) { continue; }
+                std::random_device rd;
+                std::default_random_engine rng(rd());
+                // ids
+                for (int64_t r = 0; r < ggml_nrows(t); r++) {
+                    std::vector<int32_t> data(t->ne[0]);
+                    for (int i = 0; i < t->ne[0]; i++) {
+                        data[i] = i % n_experts;
+                    }
+                    std::shuffle(data.begin(), data.end(), rng);
+                    ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(int32_t));
+                }
+            } else {
+                init_tensor_uniform(t);
+            }
+        }
+    }
+};
+
 // GGML_OP_ADD1
 struct test_add1 : public test_case {
     const ggml_type type;
@@ -3122,11 +3245,11 @@ struct test_mul_mat_id : public test_case {
     }
 
     void initialize_tensors(ggml_context * ctx) override {
-        std::random_device rd;
-        std::default_random_engine rng(rd());
         for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
             if (t->type == GGML_TYPE_I32) {
                 if (ggml_is_view_op(t->op)) { continue; }
+                std::random_device rd;
+                std::default_random_engine rng(rd());
                 // ids
                 for (int64_t r = 0; r < ggml_nrows(t); r++) {
                     std::vector<int32_t> data(t->ne[0]);
@@ -3447,13 +3570,14 @@ struct test_soft_max : public test_case {
     const ggml_type type;
     const std::array<int64_t, 4> ne;
     const bool mask;
+    const bool sinks;
     const ggml_type m_prec;
     const std::array<int64_t, 2> nr23; // broadcast only dims 2 and 3
     const float scale;
     const float max_bias;
 
     std::string vars() override {
-        return VARS_TO_STR7(type, ne, mask, m_prec, nr23, scale, max_bias);
+        return VARS_TO_STR8(type, ne, mask, sinks, m_prec, nr23, scale, max_bias);
     }
 
     // the 1024 test with bias occasionally fails:
@@ -3465,11 +3589,12 @@ struct test_soft_max : public test_case {
     test_soft_max(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {10, 5, 4, 3},
             bool mask = false,
+            bool sinks = false,
             ggml_type m_prec = GGML_TYPE_F32,
             std::array<int64_t, 2> nr23 = {1, 1},
             float scale = 1.0f,
             float max_bias = 0.0f)
-        : type(type), ne(ne), mask(mask), m_prec(m_prec), nr23(nr23), scale(scale), max_bias(max_bias) {}
+        : type(type), ne(ne), mask(mask), sinks(sinks), m_prec(m_prec), nr23(nr23), scale(scale), max_bias(max_bias) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2]*nr23[0], ne[3]*nr23[1]);
@@ -3482,7 +3607,14 @@ struct test_soft_max : public test_case {
             ggml_set_name(mask, "mask");
         }
 
+        ggml_tensor * sinks = nullptr;
+        if (this->sinks) {
+            sinks = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ne[2]*nr23[0]);
+            ggml_set_name(sinks, "sinks");
+        }
+
         ggml_tensor * out = ggml_soft_max_ext(ctx, a, mask, scale, max_bias);
+        ggml_soft_max_add_sinks(out, sinks);
         ggml_set_name(out, "out");
 
         return out;
@@ -4433,6 +4565,7 @@ struct test_flash_attn_ext : public test_case {
     const int64_t nb; // batch size
 
     const bool mask; // use mask
+    const bool sinks; // use sinks
 
     const float max_bias; // ALiBi
     const float logit_softcap; // Gemma 2
@@ -4442,7 +4575,7 @@ struct test_flash_attn_ext : public test_case {
     std::array<int32_t, 4> permute;
 
     std::string vars() override {
-        return VARS_TO_STR12(hsk, hsv, nh, nr23, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, permute);
+        return VARS_TO_STR13(hsk, hsv, nh, nr23, kv, nb, mask, sinks, max_bias, logit_softcap, prec, type_KV, permute);
     }
 
     double max_nmse_err() override {
@@ -4457,9 +4590,9 @@ struct test_flash_attn_ext : public test_case {
     }
 
     test_flash_attn_ext(int64_t hsk = 128, int64_t hsv = 128, int64_t nh = 32, std::array<int64_t, 2> nr23 = {1, 1}, int64_t kv = 96, int64_t nb = 8,
-                        bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_prec prec = GGML_PREC_F32,
+                        bool mask = true, bool sinks = false, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_prec prec = GGML_PREC_F32,
                         ggml_type type_KV = GGML_TYPE_F16, std::array<int32_t, 4> permute = {0, 1, 2, 3})
-        : hsk(hsk), hsv(hsv), nh(nh), nr23(nr23), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), prec(prec), type_KV(type_KV), permute(permute) {}
+        : hsk(hsk), hsv(hsv), nh(nh), nr23(nr23), kv(kv), nb(nb), mask(mask), sinks(sinks), max_bias(max_bias), logit_softcap(logit_softcap), prec(prec), type_KV(type_KV), permute(permute) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         const int64_t hsk_padded = GGML_PAD(hsk, ggml_blck_size(type_KV));
@@ -4499,13 +4632,31 @@ struct test_flash_attn_ext : public test_case {
             ggml_set_name(m, "m");
         }
 
+        ggml_tensor * s = nullptr;
+        if (sinks) {
+            s = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, q->ne[2]);
+            ggml_set_name(s, "s");
+        }
+
         ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hsk), max_bias, logit_softcap);
-        ggml_flash_attn_ext_set_prec(out, prec);
+        ggml_flash_attn_ext_add_sinks(out, s);
+        ggml_flash_attn_ext_set_prec (out, prec);
         ggml_set_name(out, "out");
 
         return out;
     }
 
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            if (strcmp(t->name, "s") == 0) {
+                // make the sink values more noticable in order to trigger a test failure when the implementation is wrong
+                init_tensor_uniform(t, -10.0f, 10.0f);
+            } else {
+                init_tensor_uniform(t);
+            }
+        }
+    }
+
     bool grad_precise() override {
         return true;
     }
@@ -5038,6 +5189,7 @@ static const ggml_type all_types[] = {
     GGML_TYPE_Q4_0, GGML_TYPE_Q4_1,
     GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
     GGML_TYPE_Q8_0,
+    GGML_TYPE_MXFP4,
     GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
     GGML_TYPE_Q4_K, GGML_TYPE_Q5_K,
     GGML_TYPE_Q6_K,
@@ -5053,6 +5205,7 @@ static const ggml_type base_types[] = {
     GGML_TYPE_Q4_0,
     GGML_TYPE_Q4_1, // for I8MM tests
     GGML_TYPE_Q4_K,
+    GGML_TYPE_MXFP4, // TODO: or "other"
     GGML_TYPE_IQ2_XXS
 };
 
@@ -5089,6 +5242,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
         for (int v : {0, 1}) {
             for (int op = 0; op < GGML_GLU_OP_COUNT; op++) {
+                if (op == GGML_GLU_OP_SWIGLU_OAI) {
+                    // SWIGLU_OAI is handled separately
+                    continue;
+                }
+
                 for (bool swapped : {false, true}) {
                     test_cases.emplace_back(new test_glu((ggml_glu_op) op, type, { 128, 2, 2, 2 }, v, swapped));
                     test_cases.emplace_back(new test_glu((ggml_glu_op) op, type, { 5, 7, 11, 13 }, v, swapped));
@@ -5100,6 +5258,14 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         }
     }
 
+    for (int v : {0, 1}) {
+        for (float alpha : {.5f, 1.702f}) {
+            for (float limit : {2.0f, 7.0f}) {
+                test_cases.emplace_back(new test_swiglu_oai(GGML_TYPE_F32, { 128, 2, 2, 2 }, v, alpha, limit));
+            }
+        }
+    }
+
     test_cases.emplace_back(new test_get_rows(GGML_TYPE_F32, 1, 8, 2, 1, false));
     for (ggml_type type : all_types) {
         for (int b : {1, 7}) {
@@ -5668,6 +5834,21 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         }
     }
 
+    // add_id
+    for (ggml_type type_a : {GGML_TYPE_F32}) {
+        for (ggml_type type_b : {GGML_TYPE_F32}) {
+            for (int n_mats : {4, 8}) {
+                for (int n_used : {1, 2, 4}) {
+                    for (int n_embd : {32, 129}) {
+                        for (int n_token : {1, 32, 129}) {
+                            test_cases.emplace_back(new test_add_id(type_a, type_b, n_embd, n_mats, n_used, n_token));
+                        }
+                    }
+                }
+            }
+        }
+    }
+
     for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
         test_cases.emplace_back(new test_sqr(type));
         test_cases.emplace_back(new test_sqrt(type));
@@ -5697,38 +5878,40 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     }
 #endif
     for (bool mask : {false, true}) {
-        for (float max_bias : {0.0f, 8.0f}) {
-            if (!mask && max_bias > 0.0f) continue;
-            for (float scale : {1.0f, 0.1f}) {
-                for (int64_t ne0 : {16, 1024}) {
-                    for (int64_t ne1 : {16, 1024}) {
-                        if (mask) {
-                            for (ggml_type m_prec : {GGML_TYPE_F32, GGML_TYPE_F16}) {
-                                test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0,   ne1,   1, 1}, mask, m_prec, {1, 1}, scale, max_bias));
-                                test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, m_prec, {1, 1}, scale, max_bias));
-
-                                if (ne0 <= 32 && ne1 <= 32) {
-                                    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0,   ne1,   1, 3}, mask, m_prec, {3, 1}, scale, max_bias));
-                                    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, m_prec, {2, 3}, scale, max_bias));
+        for (bool sinks : {false, true}) {
+            for (float max_bias : {0.0f, 8.0f}) {
+                if (!mask && max_bias > 0.0f) continue;
+                for (float scale : {1.0f, 0.1f}) {
+                    for (int64_t ne0 : {16, 1024}) {
+                        for (int64_t ne1 : {16, 1024}) {
+                            if (mask) {
+                                for (ggml_type m_prec : {GGML_TYPE_F32, GGML_TYPE_F16}) {
+                                    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0,   ne1,   1, 1}, mask, sinks, m_prec, {1, 1}, scale, max_bias));
+                                    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, sinks, m_prec, {1, 1}, scale, max_bias));
+
+                                    if (ne0 <= 32 && ne1 <= 32) {
+                                        test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0,   ne1,   1, 3}, mask, sinks, m_prec, {3, 1}, scale, max_bias));
+                                        test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, sinks, m_prec, {2, 3}, scale, max_bias));
+                                    }
                                 }
+                            } else {
+                                /* The precision of mask here doesn't matter as boolean mask is false */
+                                test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0,   ne1,   1, 1}, mask, sinks, GGML_TYPE_F32, {1, 1}, scale, max_bias));
+                                test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, sinks, GGML_TYPE_F32, {1, 1}, scale, max_bias));
                             }
-                        } else {
-                            /* The precision of mask here doesn't matter as boolean mask is false */
-                            test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0,   ne1,   1, 1}, mask, GGML_TYPE_F32, {1, 1}, scale, max_bias));
-                            test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, GGML_TYPE_F32, {1, 1}, scale, max_bias));
                         }
                     }
                 }
             }
         }
     }
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true,  GGML_TYPE_F32, {1, 1}, 0.1f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true,  GGML_TYPE_F16, {1, 1}, 0.1f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, GGML_TYPE_F32, {1, 1}, 0.1f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true,  GGML_TYPE_F32, {1, 1}, 0.1f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true,  GGML_TYPE_F16, {1, 1}, 0.1f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true,  GGML_TYPE_F32, {1, 1}, 0.1f, 8.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true,  GGML_TYPE_F16, {1, 1}, 0.1f, 8.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true,  true,  GGML_TYPE_F32, {1, 1}, 0.1f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true,  false, GGML_TYPE_F16, {1, 1}, 0.1f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, true,  GGML_TYPE_F32, {1, 1}, 0.1f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true,  true,  GGML_TYPE_F32, {1, 1}, 0.1f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true,  false, GGML_TYPE_F16, {1, 1}, 0.1f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true,  true,  GGML_TYPE_F32, {1, 1}, 0.1f, 8.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true,  true,  GGML_TYPE_F16, {1, 1}, 0.1f, 8.0f));
 
     for (float max_bias : {0.0f, 8.0f}) {
         for (float scale : {1.0f, 0.1f}) {
@@ -5832,27 +6015,29 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
             if (hsk == 576 && hsv != 512) continue; // DeepSeek MLA
 
             for (bool mask : { true, false } ) {
-                for (float max_bias : { 0.0f, 8.0f }) {
-                    if (!mask && max_bias > 0.0f) continue;
-                    for (float logit_softcap : {0.0f, 10.0f}) {
-                        if (hsk != 128 && logit_softcap != 0.0f) continue;
-                        for (int nh : { 4, }) {
-                            for (int nr3 : { 1, 3, }) {
-                                if (hsk > 64 && nr3 > 1) continue; // skip broadcast for large head sizes
-                                for (int nr2 : { 1, 4, 16 }) {
-                                    if (nr2 == 16 && hsk != 128) continue;
-                                    for (int kv : { 512, 1024, }) {
-                                        if (nr2 != 1 && kv != 512) continue;
-                                        for (int nb : { 1, 3, 32, 35, }) {
-                                            for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {
-                                                if (hsk != 128 && prec == GGML_PREC_DEFAULT) continue;
-                                                for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
-                                                    test_cases.emplace_back(new test_flash_attn_ext(
-                                                                hsk, hsv, nh, {nr2, nr3}, kv, nb, mask, max_bias, logit_softcap, prec, type_KV));
-                                                    // run fewer test cases permuted
-                                                    if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
+                for (bool sinks : { true, false } ) {
+                    for (float max_bias : { 0.0f, 8.0f }) {
+                        if (!mask && max_bias > 0.0f) continue;
+                        for (float logit_softcap : {0.0f, 10.0f}) {
+                            if (hsk != 128 && logit_softcap != 0.0f) continue;
+                            for (int nh : { 4, }) {
+                                for (int nr3 : { 1, 3, }) {
+                                    if (hsk > 64 && nr3 > 1) continue; // skip broadcast for large head sizes
+                                    for (int nr2 : { 1, 4, 16 }) {
+                                        if (nr2 == 16 && hsk != 128) continue;
+                                        for (int kv : { 512, 1024, }) {
+                                            if (nr2 != 1 && kv != 512) continue;
+                                            for (int nb : { 1, 3, 32, 35, }) {
+                                                for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {
+                                                    if (hsk != 128 && prec == GGML_PREC_DEFAULT) continue;
+                                                    for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
                                                         test_cases.emplace_back(new test_flash_attn_ext(
-                                                                    hsk, hsv, nh, {nr2, nr3}, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, {0, 2, 1, 3}));
+                                                                    hsk, hsv, nh, {nr2, nr3}, kv, nb, mask, sinks, max_bias, logit_softcap, prec, type_KV));
+                                                        // run fewer test cases permuted
+                                                        if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
+                                                            test_cases.emplace_back(new test_flash_attn_ext(
+                                                                        hsk, hsv, nh, {nr2, nr3}, kv, nb, mask, sinks, max_bias, logit_softcap, prec, type_KV, {0, 2, 1, 3}));
+                                                        }
                                                     }
                                                 }
                                             }
@@ -5937,14 +6122,14 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
     test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {8192, 512, 2, 1}, {0, 2, 1, 3}));
     test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {3072, 512, 2, 1}, {0, 2, 1, 3}));
 
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {12888, 256, 5, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 4096, 5, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {1024, 1024, 10, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 1024, 10, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {256, 256, 20, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {64, 64, 20, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 64, 20, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {12888, 256, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 4096, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {1024, 1024, 10, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 1024, 10, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {256, 256, 20, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {64, 64, 20, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 64, 20, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
 
     test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 10, 1, 1}));
     test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1}));
@@ -5976,7 +6161,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
     for (int kv : { 4096, 8192, 16384, }) {
         for (int hs : { 64, 128, }) {
             for (int nr : { 1, 4, }) {
-                test_cases.emplace_back(new test_flash_attn_ext(hs, hs, 8, {nr, 1}, kv, 1, true, 0, 0, GGML_PREC_F32, GGML_TYPE_F16));
+                test_cases.emplace_back(new test_flash_attn_ext(hs, hs, 8, {nr, 1}, kv, 1, true, false, 0, 0, GGML_PREC_F32, GGML_TYPE_F16));
             }
         }
     }
@@ -5988,6 +6173,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
 
     test_cases.emplace_back(new test_mean(GGML_TYPE_F32, {256, 256, 3, 1}));
 
+
+    for (int n_token : {1, 512}) {
+        test_cases.emplace_back(new test_add_id(GGML_TYPE_F32, GGML_TYPE_F32, 2880, 128, 4, n_token));
+        test_cases.emplace_back(new test_add_id(GGML_TYPE_F32, GGML_TYPE_F32, 2880, 32, 4, n_token));
+    }
+
     return test_cases;
 }