GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
+#define GGML_TENSOR_TERNARY_OP_LOCALS \
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
+ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
+ GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
+ GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) \
+ GGML_TENSOR_LOCALS(int64_t, ne2, src2, ne) \
+ GGML_TENSOR_LOCALS(size_t, nb2, src2, nb) \
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
+
#define GGML_TENSOR_BINARY_OP_LOCALS01 \
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
// GGML_TYPE_IQ4_NL_4_4 = 36,
// GGML_TYPE_IQ4_NL_4_8 = 37,
// GGML_TYPE_IQ4_NL_8_8 = 38,
- GGML_TYPE_COUNT = 39,
+ GGML_TYPE_MXFP4 = 39, // MXFP4 (1 block)
+ GGML_TYPE_COUNT = 40,
};
// precision
GGML_FTYPE_MOSTLY_IQ4_XS = 22, // except 1d tensors
GGML_FTYPE_MOSTLY_IQ1_M = 23, // except 1d tensors
GGML_FTYPE_MOSTLY_BF16 = 24, // except 1d tensors
+ GGML_FTYPE_MOSTLY_MXFP4 = 25, // except 1d tensors
};
// available tensor operations:
GGML_OP_DUP,
GGML_OP_ADD,
+ GGML_OP_ADD_ID,
GGML_OP_ADD1,
GGML_OP_ACC,
GGML_OP_SUB,
GGML_GLU_OP_REGLU,
GGML_GLU_OP_GEGLU,
GGML_GLU_OP_SWIGLU,
+ GGML_GLU_OP_SWIGLU_OAI,
GGML_GLU_OP_GEGLU_ERF,
GGML_GLU_OP_GEGLU_QUICK,
struct ggml_tensor * b,
enum ggml_type type);
+ // dst[i0, i1, i2] = a[i0, i1, i2] + b[i0, ids[i1, i2]]
+ GGML_API struct ggml_tensor * ggml_add_id(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ struct ggml_tensor * ids);
+
GGML_API struct ggml_tensor * ggml_add1(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * a,
struct ggml_tensor * b);
+ GGML_API struct ggml_tensor * ggml_swiglu_oai(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ float alpha,
+ float limit);
+
// normalize along rows
GGML_API struct ggml_tensor * ggml_norm(
struct ggml_context * ctx,
float scale,
float max_bias);
+ GGML_API void ggml_soft_max_add_sinks(
+ struct ggml_tensor * a,
+ struct ggml_tensor * sinks);
+
GGML_API struct ggml_tensor * ggml_soft_max_ext_back(
struct ggml_context * ctx,
struct ggml_tensor * a,
GGML_API enum ggml_prec ggml_flash_attn_ext_get_prec(
const struct ggml_tensor * a);
+ GGML_API void ggml_flash_attn_ext_add_sinks(
+ struct ggml_tensor * a,
+ struct ggml_tensor * sinks);
+
// TODO: needs to be adapted to ggml_flash_attn_ext
GGML_API struct ggml_tensor * ggml_flash_attn_back(
struct ggml_context * ctx,
case GGML_OP_DIAG_MASK_ZERO:
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_ADD:
+ case GGML_OP_ADD_ID:
case GGML_OP_ADD1:
case GGML_OP_SUB:
case GGML_OP_MUL:
memcpy(&bias, (float*)op->op_params + 1, sizeof(float));
return bias == 0.0f; // TODO: support bias != 0.0f
case GGML_OP_SOFT_MAX:
+ // TODO: support attention sinks [TAG_ATTN_SINKS]
+ if (op->src[2]) {
+ return false;
+ }
// TODO: support broadcast
// ref: https://github.com/ggml-org/llama.cpp/pull/14435
return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1);
if(op->type != GGML_TYPE_F16 && op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_BF16){
return false;
}
+ // TODO: support attention sinks [TAG_ATTN_SINKS]
+ if (op->src[4]) {
+ return false;
+ }
if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
// different head sizes of K and V are not supported yet
return false;
#define QI4_1 (QK4_1 / (4 * QR4_1))
#define QR4_1 2
+#define QI_MXFP4 (QK_MXFP4 / (4 * QR_MXFP4))
+#define QR_MXFP4 2
+
#define QI5_0 (QK5_0 / (4 * QR5_0))
#define QR5_0 2
} block_q4_1;
static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_half) + QK4_1 / 2, "wrong q4_1 block size/padding");
+#define QK_MXFP4 32
+typedef struct {
+ uint8_t e; // E8M0
+ uint8_t qs[QK_MXFP4/2];
+} block_mxfp4;
+static_assert(sizeof(block_mxfp4) == sizeof(uint8_t) + QK_MXFP4/2, "wrong mxfp4 block size/padding");
+
#define QK5_0 32
typedef struct {
ggml_half d; // delta
0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101,
GGML_TABLE_END()
+// TODO: fix name to kvalues_iq4_nl
GGML_TABLE_BEGIN(int8_t, kvalues_iq4nl, 16)
-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113,
GGML_TABLE_END()
+// e2m1 values (doubled)
+// ref: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
+GGML_TABLE_BEGIN(int8_t, kvalues_mxfp4, 16)
+ 0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12,
+GGML_TABLE_END()
+
#define NGRID_IQ1S 2048
#define IQ1S_DELTA 0.125f
#define IQ1M_DELTA 0.125f
#define ggml_vec_dot_q5_0_q8_0_generic ggml_vec_dot_q5_0_q8_0
#define ggml_vec_dot_q5_1_q8_1_generic ggml_vec_dot_q5_1_q8_1
#define ggml_vec_dot_q8_0_q8_0_generic ggml_vec_dot_q8_0_q8_0
+#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
#define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K
#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
#define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
+#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
// repack.cpp
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
#define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
+#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
// repack.cpp
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
#define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
#define ggml_vec_dot_iq4_nl_q8_0_generic ggml_vec_dot_iq4_nl_q8_0
#define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K
+#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
// repack.cpp
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
#define ggml_vec_dot_iq3_s_q8_K_generic ggml_vec_dot_iq3_s_q8_K
#define ggml_vec_dot_iq1_s_q8_K_generic ggml_vec_dot_iq1_s_q8_K
#define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
+#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
// repack.cpp
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
#define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
#define ggml_vec_dot_iq4_nl_q8_0_generic ggml_vec_dot_iq4_nl_q8_0
#define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K
+#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
// repack.cpp
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
*s = sumf;
}
+void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+ assert(nrc == 1);
+ UNUSED(nrc);
+ UNUSED(bx);
+ UNUSED(by);
+ UNUSED(bs);
+ assert(n % QK_MXFP4 == 0);
+ static_assert(QK_MXFP4 == QK8_0, "QK_MXFP4 and QK8_0 must be the same");
+
+ const block_mxfp4 * GGML_RESTRICT x = vx;
+ const block_q8_0 * GGML_RESTRICT y = vy;
+
+ const int nb = n / QK_MXFP4;
+
+ int ib = 0;
+ float sumf = 0;
+
+#if defined __ARM_NEON
+ const int8x16_t values = vld1q_s8(kvalues_mxfp4);
+ const uint8x16_t m4b = vdupq_n_u8(0x0f);
+ uint8x16x2_t q4bits;
+ int8x16x4_t q4b;
+ int8x16x4_t q8b;
+ int32x4_t prod_1;
+ int32x4_t prod_2;
+
+ for (; ib + 1 < nb; ib += 2) {
+ q4bits.val[0] = vld1q_u8(x[ib + 0].qs);
+ q4bits.val[1] = vld1q_u8(x[ib + 1].qs);
+ q8b.val[0] = vld1q_s8(y[ib + 0].qs);
+ q8b.val[1] = vld1q_s8(y[ib + 0].qs + 16);
+ q8b.val[2] = vld1q_s8(y[ib + 1].qs);
+ q8b.val[3] = vld1q_s8(y[ib + 1].qs + 16);
+
+ q4b.val[0] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[0], m4b));
+ q4b.val[1] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4));
+ q4b.val[2] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[1], m4b));
+ q4b.val[3] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[1], 4));
+
+ prod_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[0], q8b.val[0]), q4b.val[1], q8b.val[1]);
+ prod_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]);
+
+ sumf +=
+ GGML_E8M0_TO_FP32_HALF(x[ib + 0].e) * GGML_CPU_FP16_TO_FP32(y[ib + 0].d) * vaddvq_s32(prod_1) +
+ GGML_E8M0_TO_FP32_HALF(x[ib + 1].e) * GGML_CPU_FP16_TO_FP32(y[ib + 1].d) * vaddvq_s32(prod_2);
+ }
+
+#endif
+ for (; ib < nb; ++ib) {
+ const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_E8M0_TO_FP32_HALF(x[ib].e);
+ int sumi1 = 0;
+ int sumi2 = 0;
+ for (int j = 0; j < QK_MXFP4/2; ++j) {
+ sumi1 += y[ib].qs[j + 0] * kvalues_mxfp4[x[ib].qs[j] & 0xf];
+ sumi2 += y[ib].qs[j + QK_MXFP4/2] * kvalues_mxfp4[x[ib].qs[j] >> 4];
+ }
+ sumf += d * (sumi1 + sumi2);
+ }
+ *s = sumf;
+}
+
void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
const int qk = QK8_0;
const int nb = n / qk;
}
#if defined(__AVX2__) || defined(__AVX512F__)
+static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
+ const __m256i ax = _mm256_sign_epi8(x, x);
+ const __m256i sy = _mm256_sign_epi8(y, x);
+ return _mm256_maddubs_epi16(ax, sy);
+}
+
// spread 32 bits to 32 bytes { 0x00, 0xFF }
static inline __m256i bytes_from_bits_32(const uint8_t * x) {
uint32_t x32;
return _mm256_set_m128(_mm_set1_ps(GGML_CPU_FP16_TO_FP32(x1) * GGML_CPU_FP16_TO_FP32(y1)),
_mm_set1_ps(GGML_CPU_FP16_TO_FP32(x0) * GGML_CPU_FP16_TO_FP32(y0)));
}
+
+static inline __m256 quad_mx_delta_float(const int8_t x0, const float y0, const int8_t x1, const float y1) {
+ return _mm256_set_m128(_mm_set1_ps(GGML_E8M0_TO_FP32_HALF(x1) * GGML_CPU_FP16_TO_FP32(y1)),
+ _mm_set1_ps(GGML_E8M0_TO_FP32_HALF(x0) * GGML_CPU_FP16_TO_FP32(y0)));
+}
#endif
#elif defined(__SSSE3__)
// horizontally add 4x4 floats
#endif
}
+void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+ assert(nrc == 1);
+ UNUSED(nrc);
+ UNUSED(bx);
+ UNUSED(by);
+ UNUSED(bs);
+ assert(n % QK_MXFP4 == 0);
+ static_assert(QK_MXFP4 == QK8_0, "QK_MXFP4 and QK8_0 must be the same");
+
+ const block_mxfp4 * GGML_RESTRICT x = vx;
+ const block_q8_0 * GGML_RESTRICT y = vy;
+
+ const int nb = n / QK_MXFP4;
+
+ int ib = 0;
+ float sumf = 0;
+
+#if defined __AVX2__
+
+ const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_mxfp4);
+ const __m128i m4b = _mm_set1_epi8(0x0f);
+ const __m256i mone = _mm256_set1_epi16(1);
+
+ __m256 accum1 = _mm256_setzero_ps();
+ __m256 accum2 = _mm256_setzero_ps();
+ for (; ib + 1 < nb; ib += 2) {
+ const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)x[ib + 0].qs);
+ const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)x[ib + 1].qs);
+ const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)y[ib + 0].qs);
+ const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)y[ib + 1].qs);
+ const __m256i q4b_1 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)),
+ _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)));
+ const __m256i q4b_2 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)),
+ _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)));
+ const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
+ const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
+ const __m256i p_1 = _mm256_madd_epi16(p16_1, mone);
+ const __m256i p_2 = _mm256_madd_epi16(p16_2, mone);
+ accum1 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 0].d)*GGML_E8M0_TO_FP32_HALF(x[ib + 0].e)),
+ _mm256_cvtepi32_ps(p_1), accum1);
+ accum2 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 1].d)*GGML_E8M0_TO_FP32_HALF(x[ib + 1].e)),
+ _mm256_cvtepi32_ps(p_2), accum2);
+ }
+
+ sumf = hsum_float_8(_mm256_add_ps(accum1, accum2));
+
+#elif defined __AVX__
+ const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_mxfp4);
+ const __m128i m4b = _mm_set1_epi8(0x0f);
+
+ __m256 accum = _mm256_setzero_ps();
+ for (; ib + 1 < nb; ib += 2) {
+ const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)x[ib + 0].qs);
+ const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs);
+ const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs);
+ const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs + 1);
+ const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs);
+ const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1);
+
+ const __m128i q4b_1_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b));
+ const __m128i q4b_1_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b));
+ const __m128i q4b_2_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b));
+ const __m128i q4b_2_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b));
+
+ const __m256 p = mul_sum_i8_quad_float(q4b_1_0, q4b_1_1, q4b_2_0, q4b_2_1, q8b_1_0, q8b_1_1, q8b_2_0, q8b_2_1);
+ const __m256 deltas = quad_mx_delta_float(x[ib].e, y[ib].d, x[ib + 1].e, y[ib + 1].d);
+ accum = _mm256_add_ps(_mm256_mul_ps(deltas, p), accum);
+ }
+
+ sumf = hsum_float_8(accum);
+
+#endif
+ for (; ib < nb; ++ib) {
+ const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_E8M0_TO_FP32_HALF(x[ib].e);
+ int sumi1 = 0;
+ int sumi2 = 0;
+ for (int j = 0; j < QK_MXFP4/2; ++j) {
+ sumi1 += y[ib].qs[j + 0] * kvalues_mxfp4[x[ib].qs[j] & 0xf];
+ sumi2 += y[ib].qs[j + QK_MXFP4/2] * kvalues_mxfp4[x[ib].qs[j] >> 4];
+ }
+ sumf += d * (sumi1 + sumi2);
+ }
+ *s = sumf;
+}
+
void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
const int qk = QK8_0;
const int nb = n / qk;
#endif
}
-#if defined(__AVX2__)
-static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
- const __m256i ax = _mm256_sign_epi8(x, x);
- const __m256i sy = _mm256_sign_epi8(y, x);
- return _mm256_maddubs_epi16(ax, sy);
-}
-#endif
-
void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
assert(n % QK_K == 0);
assert(nrc == 1);
.vec_dot_type = GGML_TYPE_Q8_1,
.nrows = 1,
},
+ [GGML_TYPE_MXFP4] = {
+ .from_float = quantize_row_mxfp4,
+ .vec_dot = ggml_vec_dot_mxfp4_q8_0,
+ .vec_dot_type = GGML_TYPE_Q8_0,
+ .nrows = 1,
+ },
[GGML_TYPE_Q2_K] = {
.from_float = quantize_row_q2_K,
.vec_dot = ggml_vec_dot_q2_K_q8_K,
{
ggml_compute_forward_add(params, tensor);
} break;
+ case GGML_OP_ADD_ID:
+ {
+ ggml_compute_forward_add_id(params, tensor);
+ } break;
case GGML_OP_ADD1:
{
ggml_compute_forward_add1(params, tensor);
} break;
case GGML_OP_FLASH_ATTN_EXT:
{
- ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
+ ggml_compute_forward_flash_attn_ext(params, tensor);
} break;
case GGML_OP_FLASH_ATTN_BACK:
{
case GGML_OP_DUP:
case GGML_OP_CONT:
case GGML_OP_ADD:
+ case GGML_OP_ADD_ID:
case GGML_OP_ADD1:
case GGML_OP_ACC:
{
case GGML_GLU_OP_REGLU:
case GGML_GLU_OP_GEGLU:
case GGML_GLU_OP_SWIGLU:
+ case GGML_GLU_OP_SWIGLU_OAI:
case GGML_GLU_OP_GEGLU_ERF:
case GGML_GLU_OP_GEGLU_QUICK:
{
}
} break;
case GGML_OP_ADD:
+ case GGML_OP_ADD_ID:
case GGML_OP_ADD1:
{
if (ggml_is_quantized(node->src[0]->type)) {
#include "vec.h"
#include <float.h>
+#include <algorithm>
// ggml_compute_forward_dup
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
+ case GGML_TYPE_MXFP4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
}
}
+// ggml_compute_forward_add_id
+
+static void ggml_compute_forward_add_id_f32(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+ const ggml_tensor * src2 = dst->src[2];
+
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT(src2->type == GGML_TYPE_I32);
+
+ GGML_ASSERT(src0->nb[0] == sizeof(float));
+ GGML_ASSERT(src1->nb[0] == sizeof(float));
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nr = ggml_nrows(src0);
+
+ GGML_TENSOR_TERNARY_OP_LOCALS
+
+ GGML_ASSERT( nb0 == sizeof(float));
+ GGML_ASSERT(nb10 == sizeof(float));
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int ir = ir0; ir < ir1; ++ir) {
+ // src0 indices
+ const int i3 = ir/(ne2*ne1);
+ const int i2 = (ir - i3*ne2*ne1)/ne1;
+ const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+ // src1 indices
+ const int i11 = *(int32_t *) ((char *) src2->data + i1*nb20 + i2*nb21);
+
+ GGML_ASSERT(i11 >= 0 && i11 < ne11);
+
+ ggml_vec_add_f32(ne0,
+ (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ),
+ (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
+ (float *) ((char *) src1->data + i11*nb11));
+ }
+}
+
+void ggml_compute_forward_add_id(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_add_id_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ABORT("unsupported type for ggml_compute_forward_add_id: %s", ggml_type_name(src0->type));
+ }
+ }
+}
+
// ggml_compute_forward_add1
static void ggml_compute_forward_add1_f32(
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
+ case GGML_TYPE_MXFP4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
+ case GGML_TYPE_MXFP4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
}
}
+// ggml_compute_forward_swiglu_oai
+
+static void ggml_compute_forward_swiglu_oai_f32(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+ char * src0_d = (char *) src0->data;
+ char * src1_d = (char *) (src1 ? src1->data : src0->data);
+ const size_t src0_o = src0->nb[1];
+ const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+ if (src1) {
+ GGML_ASSERT(ggml_is_contiguous_1(src1));
+ GGML_ASSERT(src0->type == src1->type);
+ }
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+ const int nr = ggml_nrows(src0);
+
+ GGML_ASSERT(dst->ne[0] == nc);
+ GGML_ASSERT(ggml_nrows(dst) == nr);
+
+ const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+ const float alpha = ggml_get_op_params_f32(dst, 2);
+ const float limit = ggml_get_op_params_f32(dst, 3);
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int i1 = ir0; i1 < ir1; i1++) {
+ float * src0_p = (float *) (src0_d + i1*src0_o);
+ float * src1_p = (float *) (src1_d + i1*src1_o);
+ float * dst_p = (float *) ((char *) dst->data + i1*(dst->nb[1]));
+
+ if (!src1) {
+ src0_p += swapped ? nc : 0;
+ src1_p += swapped ? 0 : nc;
+ }
+
+ for (int k = 0; k < nc; k++) {
+ const float x = std::min(src0_p[k], limit);
+ const float y = std::clamp(src1_p[k], -limit, limit);
+ const float out_glu = x / (1.f + expf(alpha * (-x)));
+ dst_p[k] = out_glu * (y + 1.f);
+ }
+
+#ifndef NDEBUG
+ for (int k = 0; k < nc; k++) {
+ const float x = dst_p[k];
+ GGML_UNUSED(x);
+ assert(!isnan(x));
+ assert(!isinf(x));
+ }
+#endif
+ }
+}
+
+static void ggml_compute_forward_swiglu_oai(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_swiglu_oai_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ABORT("fatal error");
+ }
+ }
+}
+
// ggml_compute_forward_geglu_erf
static void ggml_compute_forward_geglu_erf_f32(
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
+ case GGML_TYPE_MXFP4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
+ case GGML_TYPE_MXFP4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
+ case GGML_TYPE_MXFP4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
+ const ggml_tensor * src2 = dst->src[2];
assert(ggml_is_contiguous(dst));
assert(ggml_are_same_shape(src0, dst));
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
+ // sinks
+ const float * sk = src2 ? (float *)((char *) src2->data) : nullptr;
+
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
float max = -INFINITY;
ggml_vec_max_f32(ne00, &max, wp);
+ // if we have sinks, make a correction as if they were included in the softmax
+ if (sk) {
+ max = MAX(max, sk[i02]);
+ }
+
ggml_float sum = ggml_vec_soft_max_f32(ne00, dp, wp, max);
assert(sum > 0.0);
+ if (sk) {
+ sum += (ggml_float) expf(sk[i02] - max);
+ }
+
sum = 1.0/sum;
ggml_vec_scale_f32(ne00, dp, sum);
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
+ case GGML_TYPE_MXFP4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
static void ggml_compute_forward_flash_attn_ext_f16(
const ggml_compute_params * params,
- const ggml_tensor * q,
- const ggml_tensor * k,
- const ggml_tensor * v,
- const ggml_tensor * mask,
ggml_tensor * dst) {
+ const ggml_tensor * q = dst->src[0];
+ const ggml_tensor * k = dst->src[1];
+ const ggml_tensor * v = dst->src[2];
+ const ggml_tensor * mask = dst->src[3];
+ const ggml_tensor * sinks = dst->src[4];
+
GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
}
}
+ // sinks
+ if (sinks) {
+ const float s = ((float *)((char *) sinks->data))[h];
+
+ float ms = 1.0f;
+ float vs = 1.0f;
+
+ if (s > M) {
+ ms = expf(M - s);
+ ggml_vec_scale_f32(DV, VKQ32, ms);
+ } else {
+ vs = expf(s - M);
+ }
+
+ S = S*ms + vs;
+ }
+
// V /= S
const float S_inv = 1.0f/S;
ggml_vec_scale_f32(DV, VKQ32, S_inv);
void ggml_compute_forward_flash_attn_ext(
const ggml_compute_params * params,
- const ggml_tensor * q,
- const ggml_tensor * k,
- const ggml_tensor * v,
- const ggml_tensor * mask,
ggml_tensor * dst) {
switch (dst->op_params[3]) {
case GGML_PREC_DEFAULT:
case GGML_PREC_F32:
{
// uses F32 accumulators
- ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst);
+ ggml_compute_forward_flash_attn_ext_f16(params, dst);
} break;
default:
{
{
ggml_compute_forward_swiglu(params, dst);
} break;
+ case GGML_GLU_OP_SWIGLU_OAI:
+ {
+ ggml_compute_forward_swiglu_oai(params, dst);
+ } break;
case GGML_GLU_OP_GEGLU_ERF:
{
ggml_compute_forward_geglu_erf(params, dst);
void ggml_compute_forward_dup(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_add(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_add_id(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_add1(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_acc(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_sum(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_timestep_embedding(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_argsort(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_leaky_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
-void ggml_compute_forward_flash_attn_ext(
- const struct ggml_compute_params * params,
- const struct ggml_tensor * q,
- const struct ggml_tensor * k,
- const struct ggml_tensor * v,
- const struct ggml_tensor * mask,
- struct ggml_tensor * dst);
+void ggml_compute_forward_flash_attn_ext(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_flash_attn_back(
const struct ggml_compute_params * params,
const bool masked,
quantize_row_q8_1_ref(x, y, k);
}
+void quantize_row_mxfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {
+ quantize_row_mxfp4_ref(x, y, k);
+}
+
//
// 2-6 bit quantization in super-blocks
//
*s = sumf;
}
+void ggml_vec_dot_mxfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+ assert(nrc == 1);
+ UNUSED(nrc);
+ UNUSED(bx);
+ UNUSED(by);
+ UNUSED(bs);
+ assert(n % QK_MXFP4 == 0);
+ static_assert(QK_MXFP4 == QK8_0, "QK_MXFP4 and QK8_0 must be the same");
+
+ const block_mxfp4 * GGML_RESTRICT x = vx;
+ const block_q8_0 * GGML_RESTRICT y = vy;
+
+ const int nb = n / QK_MXFP4;
+
+ int ib = 0;
+ float sumf = 0;
+
+ for (; ib < nb; ++ib) {
+ const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_E8M0_TO_FP32_HALF(x[ib].e);
+
+ int sumi1 = 0;
+ int sumi2 = 0;
+ for (int j = 0; j < QK_MXFP4/2; ++j) {
+ sumi1 += y[ib].qs[j + 0] * kvalues_mxfp4[x[ib].qs[j] & 0xf];
+ sumi2 += y[ib].qs[j + QK_MXFP4/2] * kvalues_mxfp4[x[ib].qs[j] >> 4];
+ }
+ sumf += d * (sumi1 + sumi2);
+ }
+ *s = sumf;
+}
+
void ggml_vec_dot_q5_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
const int qk = QK8_0;
const int nb = n / qk;
void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_mxfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+
void quantize_row_q2_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q3_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q4_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q5_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q5_1_q8_1_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q8_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
+void ggml_vec_dot_mxfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
void ggml_vec_dot_tq1_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_tq2_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
void ggml_vec_dot_q2_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q3_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q4_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const ggml_fp16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
inline static void ggml_vec_set_bf16(const int n, ggml_bf16_t * x, const ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
-inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; }
+
+inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) {
+ int i = 0;
+#if defined(__AVX2__)
+ for (; i + 7 < n; i += 8) {
+ __m256 vx = _mm256_loadu_ps(x + i);
+ __m256 vy = _mm256_loadu_ps(y + i);
+ __m256 vz = _mm256_add_ps(vx, vy);
+ _mm256_storeu_ps(z + i, vz);
+ }
+#endif
+ for (; i < n; ++i) {
+ z[i] = x[i] + y[i];
+ }
+}
+
inline static void ggml_vec_add_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) {
for (int i = 0; i < n; ++i) {
z[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(x[i]) + GGML_CPU_FP16_TO_FP32(y[i]));
inline static void ggml_vec_swiglu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
for (int i = 0; i < n; ++i) {
- float v = GGML_CPU_FP16_TO_FP32(x[i]);
- float w = GGML_CPU_FP16_TO_FP32(g[i]);
- y[i] = GGML_CPU_FP32_TO_FP16((v/(1.0f + expf(-v))) * w);
+ float xi = GGML_CPU_FP16_TO_FP32(x[i]);
+ float gi = GGML_CPU_FP16_TO_FP32(g[i]);
+ y[i] = GGML_CPU_FP32_TO_FP16((xi/(1.0f + expf(-xi))) * gi);
}
}
--- /dev/null
+#include "add-id.cuh"
+
+static __global__ void add_id_kernel(
+ const float * src0, const float * src1, const int32_t * src2, float * dst,
+ int64_t ne0, int64_t ne1,
+ size_t nb01, size_t nb02,
+ size_t nb11,
+ size_t nb21
+ ) {
+
+ const int64_t i1 = blockIdx.x;
+ const int64_t i2 = blockIdx.y;
+
+ const int i11 = *(int32_t *) ((char *) src2 + i1*sizeof(int32_t) + i2*nb21);
+
+ const size_t nb1 = ne0 * sizeof(float);
+ const size_t nb2 = ne1 * nb1;
+
+ float * dst_row = (float *)((char *)dst + i1*nb1 + i2*nb2);
+ const float * src0_row = (const float *)((char *)src0 + i1*nb01 + i2*nb02);
+ const float * src1_row = (const float *)((char *)src1 + i11*nb11);
+
+ for (int64_t i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) {
+ dst_row[i0] = src0_row[i0] + src1_row[i0];
+ }
+}
+
+void ggml_cuda_op_add_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+ const ggml_tensor * src2 = dst->src[2];
+
+ GGML_TENSOR_TERNARY_OP_LOCALS
+
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT(src2->type == GGML_TYPE_I32);
+
+ GGML_ASSERT(nb00 == sizeof(float));
+ GGML_ASSERT(nb10 == sizeof(float));
+ GGML_ASSERT(nb20 == sizeof(int32_t));
+
+ const float * src0_d = (const float *)src0->data;
+ const float * src1_d = (const float *)src1->data;
+ const int32_t * src2_d = (const int32_t *)src2->data;
+ float * dst_d = (float *)dst->data;
+
+ int threads = std::min((int)ne00, 768); // cols
+ dim3 blocks(ne01, ne02); // n_experts_used, n_tokens
+ add_id_kernel<<<blocks, threads, 0, ctx.stream()>>>(
+ src0_d, src1_d, src2_d, dst_d,
+ ne0, ne1,
+ nb01, nb02,
+ nb11,
+ nb21
+ );
+}
--- /dev/null
+#include "common.cuh"
+
+void ggml_cuda_op_add_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
#pragma once
#include "ggml.h"
+#include "ggml-impl.h"
#include "ggml-cuda.h"
#include <cstdint>
#endif // defined(GGML_USE_HIP)
}
+static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
+#if CUDART_VERSION >= 12080
+ const nv_bfloat16 e = __nv_cvt_e8m0_to_bf16raw(x);
+ return (float) e;
+#else
+ uint32_t bits;
+ if (x == 0) {
+ bits = 0x00400000;
+ } else {
+ bits = (uint32_t) x << 23;
+ }
+
+ float result;
+ memcpy(&result, &bits, sizeof(float));
+ return result;
+#endif // CUDART_VERSION >= 12050
+}
+
typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);
static __device__ __forceinline__ float get_alibi_slope(
static constexpr int qi = QI8_0;
};
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_MXFP4> {
+ static constexpr int qk = QK_MXFP4;
+ static constexpr int qr = QR_MXFP4;
+ static constexpr int qi = QI_MXFP4;
+};
+
template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q2_K> {
static constexpr int qk = QK_K;
}
}
+template<typename dst_t>
+static __global__ void dequantize_block_mxfp4(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+ const int64_t i = blockIdx.x;
+ const block_mxfp4 * x = (const block_mxfp4 *) vx + i*(QK_K/QK_MXFP4);
+
+ const int64_t tid = threadIdx.x;
+ const int64_t il = tid/8; // 0...3
+ const int64_t ib = tid%8; // 0...7
+ dst_t * y = yy + i*QK_K + 32*ib + 4*il;
+ const uint8_t * q4 = x[ib].qs + 4*il;
+ const float d = ggml_cuda_e8m0_to_fp32(x[ib].e);
+ for (int j = 0; j < 4; ++j) {
+ y[j+ 0] = d * kvalues_mxfp4[q4[j] & 0xf]*0.5f;
+ y[j+16] = d * kvalues_mxfp4[q4[j] >> 4]*0.5f;
+ }
+}
+
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static void dequantize_block_cuda(const void * vx, dst_t * y,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y);
}
+template<typename dst_t>
+static void dequantize_row_mxfp4_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+ const int nb = (k + QK_K - 1) / QK_K;
+ dequantize_block_mxfp4<<<nb, 32, 0, stream>>>(vx, y);
+}
+
template <typename src_t, typename dst_t>
static __global__ void convert_unary(
const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, const int64_t ne02,
return dequantize_row_iq4_xs_cuda;
case GGML_TYPE_IQ3_S:
return dequantize_row_iq3_s_cuda;
+ case GGML_TYPE_MXFP4:
+ return dequantize_row_mxfp4_cuda;
case GGML_TYPE_F32:
return convert_unary_cont_cuda<float>;
case GGML_TYPE_BF16:
return dequantize_row_iq4_xs_cuda;
case GGML_TYPE_IQ3_S:
return dequantize_row_iq3_s_cuda;
+ case GGML_TYPE_MXFP4:
+ return dequantize_row_mxfp4_cuda;
case GGML_TYPE_F16:
return convert_unary_cont_cuda<half>;
case GGML_TYPE_BF16:
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
+ const char * __restrict__ sinks,
const int * __restrict__ KV_max,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
GGML_ASSERT(V || is_mla);
- const ggml_tensor * mask = dst->src[3];
+ const ggml_tensor * mask = dst->src[3];
+ const ggml_tensor * sinks = dst->src[4];
ggml_tensor * KQV = dst;
K_data,
V_data,
mask ? ((const char *) mask->data) : nullptr,
+ sinks ? ((const char *) sinks->data) : nullptr,
KV_max.ptr,
!stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
scale, max_bias, m0, m1, n_head_log2, logit_softcap,
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
+ const char * __restrict__ sinks,
const int * __restrict__ KV_max,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
// kb0 == k start index when in the output tile.
int kb0_start = kbc % iter_k;
int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc);
+
while (kbc < kbc_stop && kb0_stop == iter_k) {
const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
#else
- GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
+ GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks);
GGML_UNUSED(dst); GGML_UNUSED(dst_meta);
GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
+ const char * __restrict__ sinks,
const int * __restrict__ KV_max,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
}
}
#else
- GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
+ GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks);
GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
+ const char * __restrict__ sinks,
const int * __restrict__ KV_max,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
return;
#endif // FP16_MMA_AVAILABLE
if (use_logit_softcap && !(D == 128 || D == 256)) {
- GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
+ GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks);
GGML_UNUSED(dst); GGML_UNUSED(dst_meta);
GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
+ const char * __restrict__ sinks,
const int * __restrict__ KV_max,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
K += nb13*sequence + nb12*(head / gqa_ratio);
V += nb23*sequence + nb22*(head / gqa_ratio);
- const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
+ const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
+ const float * sinksf = (const float *) (sinks);
const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
const half slopeh = __float2half(slopef);
half2 * KQ2 = (half2 *) KQ;
half kqmax[ncols];
+ half kqsum[ncols];
#pragma unroll
for (int j = 0; j < ncols; ++j) {
kqmax[j] = -HALF_MAX_HALF;
+ kqsum[j] = 0.0f;
}
- half kqsum[ncols] = {0.0f};
__shared__ half kqmax_shared[ncols][WARP_SIZE];
__shared__ half kqsum_shared[ncols][WARP_SIZE];
__syncthreads();
}
+ if (sinksf && blockIdx.y == 0) {
+ const half sink = __float2half(sinksf[head]);
+
+#pragma unroll
+ for (int j = 0; j < ncols; ++j) {
+ if (threadIdx.x == 0) {
+ kqmax_shared[j][threadIdx.y] = fmaxf(kqmax[j], sink);
+ }
+ }
+
+ __syncthreads();
+
+#pragma unroll
+ for (int j = 0; j < ncols; ++j) {
+ half kqmax_new_j = kqmax_shared[j][threadIdx.x];
+ kqmax_new_j = warp_reduce_max(kqmax_new_j);
+
+ const half KQ_max_scale = hexp(kqmax[j] - kqmax_new_j);
+ kqmax[j] = kqmax_new_j;
+
+ const half val = hexp(sink - kqmax[j]);
+ kqsum[j] = kqsum[j]*KQ_max_scale;
+
+ if (tid == 0) {
+ kqsum[j] += val;
+ }
+
+ VKQ[j] *= __half2half2(KQ_max_scale);
+ }
+
+ __syncthreads();
+ }
+
#pragma unroll
for (int j = 0; j < ncols; ++j) {
kqsum[j] = warp_reduce_sum((float)kqsum[j]);
dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
}
#else
- GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
+ GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks);
GGML_UNUSED(dst); GGML_UNUSED(dst_meta);
GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
+ const char * __restrict__ sinks,
const int * __restrict__ KV_max,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
K += nb13*sequence + nb12*(head / gqa_ratio);
V += nb23*sequence + nb22*(head / gqa_ratio);
- const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
+ const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
+ const float * sinksf = (const float *) (sinks);
const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
}
float kqmax[ncols];
+ float kqsum[ncols];
#pragma unroll
for (int j = 0; j < ncols; ++j) {
kqmax[j] = -FLT_MAX/2.0f;
+ kqsum[j] = 0.0f;
}
- float kqsum[ncols] = {0.0f};
__shared__ float kqmax_shared[ncols][WARP_SIZE];
__shared__ float kqsum_shared[ncols][WARP_SIZE];
__syncthreads();
}
+ if (sinksf && blockIdx.y == 0) {
+ const float sink = sinksf[head];
+
+#pragma unroll
+ for (int j = 0; j < ncols; ++j) {
+ if (threadIdx.x == 0) {
+ kqmax_shared[j][threadIdx.y] = fmaxf(kqmax[j], sink);
+ }
+ }
+
+ __syncthreads();
+
+#pragma unroll
+ for (int j = 0; j < ncols; ++j) {
+ float kqmax_new_j = kqmax_shared[j][threadIdx.x];
+ kqmax_new_j = warp_reduce_max(kqmax_new_j);
+
+ const float KQ_max_scale = expf(kqmax[j] - kqmax_new_j);
+ kqmax[j] = kqmax_new_j;
+
+ const float val = expf(sink - kqmax[j]);
+ kqsum[j] = kqsum[j]*KQ_max_scale;
+
+ if (tid == 0) {
+ kqsum[j] += val;
+ }
+
+ VKQ[j] *= KQ_max_scale;
+ }
+
+ __syncthreads();
+ }
+
#pragma unroll
for (int j = 0; j < ncols; ++j) {
kqsum[j] = warp_reduce_sum(kqsum[j]);
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
+ const char * __restrict__ sinks,
const int * __restrict__ KV_max,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
dst_meta[j_dst_unrolled] = dst_meta_val;
}
#else
- GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
+ GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks);
GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
}
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
- const ggml_tensor * KQV = dst;
- const ggml_tensor * Q = dst->src[0];
- const ggml_tensor * K = dst->src[1];
- const ggml_tensor * V = dst->src[2];
- const ggml_tensor * mask = dst->src[3];
+ const ggml_tensor * KQV = dst;
+ const ggml_tensor * Q = dst->src[0];
+ const ggml_tensor * K = dst->src[1];
+ const ggml_tensor * V = dst->src[2];
+ const ggml_tensor * mask = dst->src[3];
+ const ggml_tensor * sinks = dst->src[4];
ggml_cuda_set_device(ctx.device);
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
+ // TODO: currently only vec implementation for sinks is supported [TAG_ATTN_SINKS]
+ if (sinks) {
+ if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
+ ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
+ } else {
+ ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
+ }
+ return;
+ }
+
#if defined(GGML_HIP_ROCWMMA_FATTN)
if (GGML_CUDA_CC_IS_AMD(cc) && fp16_mma_available(cc)) {
ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
#include "ggml-cuda/common.cuh"
#include "ggml-cuda/acc.cuh"
+#include "ggml-cuda/add-id.cuh"
#include "ggml-cuda/arange.cuh"
#include "ggml-cuda/argmax.cuh"
#include "ggml-cuda/argsort.cuh"
case GGML_OP_ADD1: // TODO: more efficient implementation
ggml_cuda_op_add(ctx, dst);
break;
+ case GGML_OP_ADD_ID:
+ ggml_cuda_op_add_id(ctx, dst);
+ break;
case GGML_OP_SUB:
ggml_cuda_op_sub(ctx, dst);
break;
case GGML_GLU_OP_SWIGLU:
ggml_cuda_op_swiglu(ctx, dst);
break;
+ case GGML_GLU_OP_SWIGLU_OAI:
+ ggml_cuda_op_swiglu_oai(ctx, dst);
+ break;
case GGML_GLU_OP_GEGLU_ERF:
ggml_cuda_op_geglu_erf(ctx, dst);
break;
const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected";
const std::string gemma3n_per_layer_proj_src1_name = "per_layer_proj";
+ const std::string ffn_moe_gate_bias_prefix = "ffn_moe_gate_biased";
+ const std::string ffn_moe_up_bias_prefix = "ffn_moe_up_biased";
+ const std::string ffn_moe_down_bias_prefix = "ffn_moe_down_biased";
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];
#endif
}
- if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1 && (node->src[0] ? node->src[0]->name != gemma3n_per_layer_proj_src0_name : true) && (node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true)) {
+ if (node->op == GGML_OP_ADD &&
+ node->src[1] && node->src[1]->ne[1] > 1 &&
+ (node->src[0] ? node->src[0]->name != gemma3n_per_layer_proj_src0_name : true) &&
+ (node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true) &&
+ strncmp(node->name, ffn_moe_gate_bias_prefix.c_str(), ffn_moe_gate_bias_prefix.size()) != 0 &&
+ strncmp(node->name, ffn_moe_up_bias_prefix.c_str(), ffn_moe_up_bias_prefix.size()) != 0 &&
+ strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0) {
// disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation
// by means of matching node names. See
// https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and
case GGML_GLU_OP_REGLU:
case GGML_GLU_OP_GEGLU:
case GGML_GLU_OP_SWIGLU:
+ case GGML_GLU_OP_SWIGLU_OAI:
case GGML_GLU_OP_GEGLU_ERF:
case GGML_GLU_OP_GEGLU_QUICK:
return ggml_is_contiguous_1(op->src[0]);
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
+ case GGML_TYPE_MXFP4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
case GGML_OP_PERMUTE:
case GGML_OP_TRANSPOSE:
case GGML_OP_ADD:
+ case GGML_OP_ADD_ID:
case GGML_OP_ADD1:
case GGML_OP_SUB:
case GGML_OP_MUL:
const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2];
return op->src[1]->ne[0] == 576 && op->src[2]->ne[0] == 512 && op->src[3] && gqa_ratio % 16 == 0;
}
+ // TODO: more general-purpose attention sink support [TAG_ATTN_SINKS]
+ if (op->src[4] && op->src[0]->ne[0] != 64 && op->src[0]->ne[0] != 128) { // currently only sinks for head_size 64 and 128 are supported
+ return false;
+ }
if (op->src[0]->ne[0] == 192) {
return false;
}
#include "im2col.cuh"
-#define MIN(a, b) (a) < (b) ? (a) : (b)
-
#define MAX_GRIDDIM_Z 65535
template <typename T>
dst[offset_dst] = x[offset_src + iih * IW + iiw];
}
}
+
+ GGML_UNUSED(IC);
+ GGML_UNUSED(KH);
}
// im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
case GGML_TYPE_Q8_0:
mul_mat_q_case<GGML_TYPE_Q8_0>(ctx, args, stream);
break;
+ case GGML_TYPE_MXFP4:
+ mul_mat_q_case<GGML_TYPE_MXFP4>(ctx, args, stream);
+ break;
case GGML_TYPE_Q2_K:
mul_mat_q_case<GGML_TYPE_Q2_K>(ctx, args, stream);
break;
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
+ case GGML_TYPE_MXFP4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
return MMQ_Q8_1_DS_LAYOUT_DS4;
case GGML_TYPE_Q8_0:
return MMQ_Q8_1_DS_LAYOUT_D4;
+ case GGML_TYPE_MXFP4:
+ return MMQ_Q8_1_DS_LAYOUT_D4;
case GGML_TYPE_Q2_K:
return MMQ_Q8_1_DS_LAYOUT_D2S6;
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q5_0: return MMQ_DP4A_TXS_Q8_0;
case GGML_TYPE_Q5_1: return MMQ_DP4A_TXS_Q8_1;
case GGML_TYPE_Q8_0: return MMQ_DP4A_TXS_Q8_0;
+ case GGML_TYPE_MXFP4: return MMQ_DP4A_TXS_Q8_1;
case GGML_TYPE_Q2_K: return MMQ_DP4A_TXS_Q2_K;
case GGML_TYPE_Q3_K: return MMQ_DP4A_TXS_Q3_K;
case GGML_TYPE_Q4_K: return MMQ_DP4A_TXS_Q4_K;
case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0;
case GGML_TYPE_Q5_1: return MMQ_MMA_TILE_X_K_Q8_1;
case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0;
+ case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_Q8_1;
case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K;
case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K;
case GGML_TYPE_Q4_K: return MMQ_MMA_TILE_X_K_Q8_1;
}
}
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_mxfp4(
+ const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_MXFP4, mmq_y);
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + txs.qs);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+
+ constexpr int threads_per_row = MMQ_ITER_K / (4 * QR_MXFP4);
+ constexpr int nrows = warp_size / threads_per_row;
+ const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
+ const int kbx = txi / QI_MXFP4;
+ const int kqsx = txi % QI_MXFP4;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+ int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbx;
+
+ const int aux_q4 = get_int_b1(bxi->qs, kqsx);
+ const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4);
+ const int k0 = kbx * (2 * QI_MXFP4) + kqsx;
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + 0] = v.x;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + QI_MXFP4] = v.y;
+#else
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x;
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI_MXFP4] = v.y;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+ }
+
+ constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI_MXFP4;
+ constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
+ const int kbxd = threadIdx.x % blocks_per_tile_x_row;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
+ int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbxd;
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f;
+#else
+ x_df[i*(MMQ_TILE_NE_K/QI_MXFP4) + i/QI_MXFP4 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+ }
+}
+
template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbx;
const int aux_q4 = get_int_b2(bxi->qs, kqsx);
- const int2 v = get_int_from_table_16(aux_q4);
+ const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
const int k0 = kbx * (2 * QI4_NL) + kqsx;
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride;
const int aux_q4 = get_int_b4(bxi->qs, kqsx);
- const int2 v = get_int_from_table_16(aux_q4);
+ const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
const int k0 = 8 * (kqsx / 4) + kqsx % 4;
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
};
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_MXFP4> {
+ static constexpr int vdr = VDR_MXFP4_Q8_1_MMQ;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4<mmq_y, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
+};
+
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q2_K> {
static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ;
extern DECL_MMQ_CASE(GGML_TYPE_Q5_0);
extern DECL_MMQ_CASE(GGML_TYPE_Q5_1);
extern DECL_MMQ_CASE(GGML_TYPE_Q8_0);
+extern DECL_MMQ_CASE(GGML_TYPE_MXFP4);
extern DECL_MMQ_CASE(GGML_TYPE_Q2_K);
extern DECL_MMQ_CASE(GGML_TYPE_Q3_K);
extern DECL_MMQ_CASE(GGML_TYPE_Q4_K);
case GGML_TYPE_Q5_0: return vec_dot_q5_0_q8_1;
case GGML_TYPE_Q5_1: return vec_dot_q5_1_q8_1;
case GGML_TYPE_Q8_0: return vec_dot_q8_0_q8_1;
+ case GGML_TYPE_MXFP4: return vec_dot_mxfp4_q8_1;
case GGML_TYPE_Q2_K: return vec_dot_q2_K_q8_1;
case GGML_TYPE_Q3_K: return vec_dot_q3_K_q8_1;
case GGML_TYPE_Q4_K: return vec_dot_q4_K_q8_1;
case GGML_TYPE_Q5_0: return VDR_Q5_0_Q8_1_MMVQ;
case GGML_TYPE_Q5_1: return VDR_Q5_1_Q8_1_MMVQ;
case GGML_TYPE_Q8_0: return VDR_Q8_0_Q8_1_MMVQ;
+ case GGML_TYPE_MXFP4: return VDR_MXFP4_Q8_1_MMVQ;
case GGML_TYPE_Q2_K: return VDR_Q2_K_Q8_1_MMVQ;
case GGML_TYPE_Q3_K: return VDR_Q3_K_Q8_1_MMVQ;
case GGML_TYPE_Q4_K: return VDR_Q4_K_Q8_1_MMVQ;
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
stream);
break;
+ case GGML_TYPE_MXFP4:
+ mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_MXFP4>
+ (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+ stream);
+ break;
case GGML_TYPE_Q2_K:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q2_K>
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
#endif // __clang__
template <bool use_shared, int ncols_template, int block_size_template, typename T>
static __global__ void soft_max_f32(
- const float * x, const T * mask, float * dst, const soft_max_params p) {
+ const float * x, const T * mask, const float * sinks, float * dst, const soft_max_params p) {
const int ncols = ncols_template == 0 ? p.ncols : ncols_template;
const int tid = threadIdx.x;
// shared memory buffer to cache values between iterations:
float * vals = use_shared ? buf_iw + WARP_SIZE : dst;
- float max_val = -INFINITY;
+ float max_val = sinks ? sinks[i02] : -INFINITY;
#pragma unroll
for (int col0 = 0; col0 < ncols; col0 += block_size) {
tmp = warp_reduce_sum(tmp);
}
+ if (sinks) {
+ tmp += expf(sinks[i02] - max_val);
+ }
+
const float inv_sum = 1.0f / tmp;
#pragma unroll
}
template<int... Ns, typename T>
-static void launch_soft_max_kernels(const float * x, const T * mask, float * dst,
+static void launch_soft_max_kernels(const float * x, const T * mask, const float * sinks, float * dst,
const soft_max_params & p, cudaStream_t stream, dim3 block_dims, dim3 block_nums, size_t nbytes_shared)
{
const int id = ggml_cuda_get_device();
if (p.ncols == ncols) {
CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, ncols, block, T>), smpbo);
soft_max_f32<true, ncols, block><<<block_nums, block_dims, nbytes_shared, stream>>>
- (x, mask, dst, p);
+ (x, mask, sinks, dst, p);
return true;
}
return false;
//default case
CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, 0, 0, T>), smpbo);
- soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, dst, p);
+ soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, sinks, dst, p);
}
template<typename T>
-static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const soft_max_params & params, cudaStream_t stream) {
+static void soft_max_f32_cuda(const float * x, const T * mask, const float * sinks, float * dst, const soft_max_params & params, cudaStream_t stream) {
int nth = WARP_SIZE;
const int64_t ncols_x = params.ncols;
if (nbytes_shared <= smpbo) {
- launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, dst, params, stream, block_dims, block_nums, nbytes_shared);
+ launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, sinks, dst, params, stream, block_dims, block_nums, nbytes_shared);
} else {
const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
- soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, params);
+ soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, sinks, dst, params);
}
}
void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
+ const ggml_tensor * src2 = dst->src[2];
const float * src0_d = (const float *) src0->data;
const void * src1_d = src1 ? (const void *) src1->data : nullptr;
+ const void * src2_d = src2 ? (const void *) src2->data : nullptr;
float * dst_d = (float *) dst->data;
cudaStream_t stream = ctx.stream();
params.m1 = m1;
if (use_f16) {
- soft_max_f32_cuda(src0_d, (const half *) src1_d, dst_d, params, stream);
+ soft_max_f32_cuda(src0_d, (const half *) src1_d, (const float *) src2_d, dst_d, params, stream);
} else {
- soft_max_f32_cuda(src0_d, (const float *) src1_d, dst_d, params, stream);
+ soft_max_f32_cuda(src0_d, (const float *) src1_d, (const float *) src2_d, dst_d, params, stream);
}
}
--- /dev/null
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_MXFP4);
ggml_cuda_op_unary_gated<op_gelu_quick>(ctx, dst);
}
+// swiglu_oai
+
+template <typename T>
+static __global__ void swiglu_oai_kernel(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, float alpha, float limit) {
+ const int64_t i = int64_t(blockDim.x)*blockIdx.x + threadIdx.x;
+
+ if (i >= k) {
+ return;
+ }
+
+ // perform base op and multiply with gate (either offset in same tensor or a separate one)
+ const int64_t j0 = (i / n) * o0 + (i % n);
+ const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
+
+ float xi = x[j0];
+ float gi = g[j1];
+ xi = fminf(xi, limit);
+ gi = fmaxf(fminf(gi, limit), -limit);
+
+ float out_glu = xi / (1.0f + expf(-xi * alpha));
+ out_glu = out_glu * (1.0f + gi);
+
+ dst[i] = out_glu;
+}
+
+template <typename T>
+static void swiglu_oai_cuda(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, const float alpha, const float limit, cudaStream_t stream) {
+ const int64_t num_blocks = (k + CUDA_GLU_BLOCK_SIZE - 1) / CUDA_GLU_BLOCK_SIZE;
+ swiglu_oai_kernel<<<num_blocks, CUDA_GLU_BLOCK_SIZE, 0, stream>>>(x, g, dst, k, n, o0, o1, alpha, limit);
+}
+
+void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+ void * src0_d = src0->data;
+ void * src1_d = src1 ? src1->data : src0->data;
+ const int64_t src0_o = src0->nb[1];
+ const int64_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+ void * dst_d = dst->data;
+ const int64_t nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
+ GGML_ASSERT(src0->nb[0] == ggml_element_size(src0));
+ GGML_ASSERT(ggml_is_contiguous(dst));
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+ GGML_ASSERT(src0->type == dst->type);
+ GGML_ASSERT(dst->ne[0] == nc);
+ GGML_ASSERT(ggml_nrows(dst) == ggml_nrows(src0));
+
+ if (src1) {
+ GGML_ASSERT(ggml_is_contiguous_1(src1));
+ GGML_ASSERT(src1->nb[0] == ggml_element_size(src1));
+ GGML_ASSERT(src1->ne[0] == nc);
+ GGML_ASSERT(src0->type == src1->type);
+ }
+
+ //const int32_t swapped = ((const int32_t *) dst->op_params)[1];
+ const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+ const float alpha = ggml_get_op_params_f32(dst, 2);
+ const float limit = ggml_get_op_params_f32(dst, 3);
+
+ float * src0_p = (float *) src0_d;
+ float * src1_p = (float *) src1_d;
+
+ if (!src1) {
+ src0_p += swapped ? nc : 0;
+ src1_p += swapped ? 0 : nc;
+ }
+
+ swiglu_oai_cuda(src0_p, src1_p, (float *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(float), src1_o / sizeof(float), alpha, limit, stream);
+}
+
/* silu_back */
static __device__ __forceinline__ float op_silu_back(float grad, float x) {
void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
#pragma once
#include "common.cuh"
+
#include <cstdint>
+static __device__ __forceinline__ int get_int_b1(const void * x, const int & i32) {
+ const uint8_t * x8 = (const uint8_t *) x;
+
+ int x32 = x8[4*i32 + 0] << 0;
+ x32 |= x8[4*i32 + 1] << 8;
+ x32 |= x8[4*i32 + 2] << 16;
+ x32 |= x8[4*i32 + 3] << 24;
+
+ return x32;
+}
+
static __device__ __forceinline__ int get_int_b2(const void * x, const int & i32) {
const uint16_t * x16 = (const uint16_t *) x; // assume at least 2 byte alignment
return ((const int *) x)[i32]; // assume at least 4 byte alignment
}
+static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, const int8_t * table) {
+ const int q0_32 = (q4 >> 0) & 0x0F0F0F0F;
+ const int8_t * q0_8 = (const int8_t *) &q0_32;
+ const char4 val0_8 = make_char4(
+ table[q0_8[0]], table[q0_8[1]], table[q0_8[2]], table[q0_8[3]]);
+
+ const int q1_32 = (q4 >> 4) & 0x0F0F0F0F;
+ const int8_t * q1_8 = (const int8_t *) &q1_32;
+ const char4 val1_8 = make_char4(
+ table[q1_8[0]], table[q1_8[1]], table[q1_8[2]], table[q1_8[3]]);
+
+ return make_int2(*((const int *) &val0_8), *((const int *) &val1_8));
+}
+
// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
return d8_1*sumf;
}
+#define VDR_MXFP4_Q8_1_MMVQ 2
+#define VDR_MXFP4_Q8_1_MMQ 4
+
+static __device__ __forceinline__ float vec_dot_mxfp4_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+ const block_mxfp4 * bq4 = (const block_mxfp4 *) vbq + kbx;
+
+ const int * q8 = (const int *) bq8_1->qs + iqs;
+
+ int sumi = 0;
+#pragma unroll
+ for (int l = 0; l < VDR_MXFP4_Q8_1_MMVQ; ++l) {
+ const int aux_q4 = get_int_b1(bq4->qs, iqs + l);
+ const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4);
+
+ sumi = ggml_cuda_dp4a(v.x, q8[l + 0], sumi);
+ sumi = ggml_cuda_dp4a(v.y, q8[l + 4], sumi);
+ }
+
+ const float d = ggml_cuda_e8m0_to_fp32(bq4->e) * 0.5f * __low2float(bq8_1->ds);
+ return d * sumi;
+}
+
#define VDR_Q2_K_Q8_1_MMVQ 1
#define VDR_Q2_K_Q8_1_MMQ 4
return d * ((sumi[0] + sumf[0]) * sc0 + (sumi[1] + sumf[1]) * sc1);
}
-static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4) {
- const int q0_32 = (q4 >> 0) & 0x0F0F0F0F;
- const int8_t * q0_8 = (const int8_t *) &q0_32;
- const char4 val0_8 = make_char4(
- kvalues_iq4nl[q0_8[0]], kvalues_iq4nl[q0_8[1]], kvalues_iq4nl[q0_8[2]], kvalues_iq4nl[q0_8[3]]);
-
- const int q1_32 = (q4 >> 4) & 0x0F0F0F0F;
- const int8_t * q1_8 = (const int8_t *) &q1_32;
- const char4 val1_8 = make_char4(
- kvalues_iq4nl[q1_8[0]], kvalues_iq4nl[q1_8[1]], kvalues_iq4nl[q1_8[2]], kvalues_iq4nl[q1_8[3]]);
-
- return make_int2(*((const int *) &val0_8), *((const int *) &val1_8));
-}
-
#define VDR_IQ4_NL_Q8_1_MMVQ 2
#define VDR_IQ4_NL_Q8_1_MMQ 4
#pragma unroll
for (int l = 0; l < VDR_Q4_0_Q8_1_MMVQ; ++l) {
const int aux_q4 = get_int_b2(bq4->qs, iqs + l);
- const int2 v = get_int_from_table_16(aux_q4);
+ const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
sumi = ggml_cuda_dp4a(v.x, q8[l + 0], sumi);
sumi = ggml_cuda_dp4a(v.y, q8[l + 4], sumi);
#pragma unroll
for (int j = 0; j < 4; ++j) {
const int aux_q4 = get_int_b4(bq4->qs, iqs + j);
- const int2 v = get_int_from_table_16(aux_q4);
+ const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
const int u0 = get_int_b4(bq8_1[iqs/4].qs, j + 0);
const int u1 = get_int_b4(bq8_1[iqs/4].qs, j + 4);
#include <cuda_bf16.h>
#include <cuda_fp16.h>
+#if CUDART_VERSION >= 12050
+#include <cuda_fp8.h>
+#endif // CUDART_VERSION >= 12050
+
#if CUDART_VERSION < 11020
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
#define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH
#define GGML_FP16_TO_FP32(x) GGML_COMPUTE_FP16_TO_FP32(x)
#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
+static inline float ggml_e8m0_to_fp32(uint8_t x) {
+ uint32_t bits; // Stores the raw bit representation of the float
+
+ // Handle special case for minimum exponent (denormalized float)
+ if (x == 0) {
+ // Bit pattern for 2^(-127):
+ // - Sign bit: 0 (positive)
+ // - Exponent: 0 (denormalized number)
+ // - Mantissa: 0x400000 (0.5 in fractional form)
+ // Value = 0.5 * 2^(-126) = 2^(-127)
+ bits = 0x00400000;
+ }
+ // note: disabled as we don't need to handle NaNs
+ //// Handle special case for NaN (all bits set)
+ //else if (x == 0xFF) {
+ // // Standard quiet NaN pattern:
+ // // - Sign bit: 0
+ // // - Exponent: all 1s (0xFF)
+ // // - Mantissa: 0x400000 (quiet NaN flag)
+ // bits = 0x7FC00000;
+ //}
+ // Normalized values (most common case)
+ else {
+ // Construct normalized float by shifting exponent into position:
+ // - Exponent field: 8 bits (positions 30-23)
+ // - Mantissa: 0 (implicit leading 1)
+ // Value = 2^(x - 127)
+ bits = (uint32_t) x << 23;
+ }
+
+ float result; // Final float value
+ // Safely reinterpret bit pattern as float without type-punning issues
+ memcpy(&result, &bits, sizeof(float));
+ return result;
+}
+
+// Equal to ggml_e8m0_to_fp32/2
+// Useful with MXFP4 quantization since the E0M2 values are doubled
+static inline float ggml_e8m0_to_fp32_half(uint8_t x) {
+ uint32_t bits;
+
+ // For x < 2: use precomputed denormal patterns
+ if (x < 2) {
+ // 0x00200000 = 2^(-128), 0x00400000 = 2^(-127)
+ bits = 0x00200000 << x;
+ }
+ // For x >= 2: normalized exponent adjustment
+ else {
+ // 0.5 * 2^(x-127) = 2^(x-128) = normalized with exponent (x-1)
+ bits = (uint32_t)(x - 1) << 23;
+ }
+ // Note: NaNs are not handled here
+
+ float result;
+ memcpy(&result, &bits, sizeof(float));
+ return result;
+}
+
+#define GGML_E8M0_TO_FP32(x) ggml_e8m0_to_fp32(x)
+#define GGML_E8M0_TO_FP32_HALF(x) ggml_e8m0_to_fp32_half(x)
+
/**
* Converts brain16 to float32.
*
#define N_R0_Q8_0 4
#define N_SG_Q8_0 2
+#define N_R0_MXFP4 2
+#define N_SG_MXFP4 2
+
#define N_R0_Q2_K 4
#define N_SG_Q2_K 2
uint64_t o1[8];
} ggml_metal_kargs_bin;
+typedef struct {
+ int64_t ne0;
+ int64_t ne1;
+ size_t nb01;
+ size_t nb02;
+ size_t nb11;
+ size_t nb21;
+} ggml_metal_kargs_add_id;
+
typedef struct {
int32_t ne00;
int32_t ne01;
uint64_t nb1;
int32_t i00;
int32_t i10;
+ float alpha;
+ float limit;
} ggml_metal_kargs_glu;
typedef struct {
GGML_METAL_KERNEL_TYPE_MUL_ROW_C4,
GGML_METAL_KERNEL_TYPE_DIV,
GGML_METAL_KERNEL_TYPE_DIV_ROW_C4,
+ GGML_METAL_KERNEL_TYPE_ADD_ID,
GGML_METAL_KERNEL_TYPE_REPEAT_F32,
GGML_METAL_KERNEL_TYPE_REPEAT_F16,
GGML_METAL_KERNEL_TYPE_REPEAT_I32,
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0,
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1,
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_MXFP4,
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K,
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K,
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K,
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_2,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_3,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_4,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_5,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16,
GGML_METAL_KERNEL_TYPE_REGLU,
GGML_METAL_KERNEL_TYPE_GEGLU,
GGML_METAL_KERNEL_TYPE_SWIGLU,
+ GGML_METAL_KERNEL_TYPE_SWIGLU_OAI,
GGML_METAL_KERNEL_TYPE_GEGLU_ERF,
GGML_METAL_KERNEL_TYPE_GEGLU_QUICK,
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW_C4, mul_row_c4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW_C4, div_row_c4, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ID, add_id, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_MXFP4, get_rows_mxfp4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32, mul_mv_mxfp4_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, mul_mv_ext_f16_f32_r1_2, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, mul_mv_ext_f16_f32_r1_3, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, mul_mv_ext_f16_f32_r1_4, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3, mul_mv_ext_q8_0_f32_r1_3, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4, mul_mv_ext_q8_0_f32_r1_4, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5, mul_mv_ext_q8_0_f32_r1_5, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_2, mul_mv_ext_mxfp4_f32_r1_2, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_3, mul_mv_ext_mxfp4_f32_r1_3, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_4, mul_mv_ext_mxfp4_f32_r1_4, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_5, mul_mv_ext_mxfp4_f32_r1_5, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2, mul_mv_ext_q4_K_f32_r1_2, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3, mul_mv_ext_q4_K_f32_r1_3, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4, mul_mv_ext_q4_K_f32_r1_4, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32, mul_mv_id_mxfp4_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, has_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32, mul_mm_mxfp4_f32, has_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32, mul_mm_mxfp4_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16, mul_mm_id_q5_0_f16, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16, mul_mm_id_q5_1_f16, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16, mul_mm_id_q8_0_f16, has_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F16, mul_mm_id_mxfp4_f16, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16, mul_mm_id_q2_K_f16, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16, mul_mm_id_q3_K_f16, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16, mul_mm_id_q4_K_f16, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REGLU, reglu, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU, geglu, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU, swiglu, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU_OAI, swiglu_oai, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_ERF, geglu_erf, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_QUICK, geglu_quick, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
case GGML_GLU_OP_REGLU:
case GGML_GLU_OP_GEGLU:
case GGML_GLU_OP_SWIGLU:
+ case GGML_GLU_OP_SWIGLU_OAI:
case GGML_GLU_OP_GEGLU_ERF:
case GGML_GLU_OP_GEGLU_QUICK:
return ggml_is_contiguous_1(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_SUB:
case GGML_OP_MUL:
case GGML_OP_DIV:
+ case GGML_OP_ADD_ID:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_ACC:
case GGML_OP_REPEAT:
const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
+ const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT;
const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT;
size_t offs_src0 = 0;
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
}
} break;
+ case GGML_OP_ADD_ID:
+ {
+ GGML_ASSERT(src0t == GGML_TYPE_F32);
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
+ GGML_ASSERT(src2t == GGML_TYPE_I32);
+ GGML_ASSERT(dstt == GGML_TYPE_F32);
+
+ GGML_ASSERT(ggml_is_contiguous_rows(src0));
+
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ID].pipeline;
+
+ ggml_metal_kargs_add_id args = {
+ /*.ne0 =*/ ne0,
+ /*.ne1 =*/ ne1,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb11 =*/ nb11,
+ /*.nb21 =*/ nb21,
+
+ };
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:4];
+
+ const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
case GGML_OP_REPEAT:
{
id<MTLComputePipelineState> pipeline;
case GGML_GLU_OP_SWIGLU:
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline;
break;
+ case GGML_GLU_OP_SWIGLU_OAI:
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU_OAI].pipeline;
+ break;
case GGML_GLU_OP_GEGLU_ERF:
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU_ERF].pipeline;
break;
GGML_ABORT("fatal error");
}
- const int32_t swp = ((const int32_t *) dst->op_params)[1];
+ const int32_t swp = ggml_get_op_params_i32(dst, 1);
+ const float alpha = ggml_get_op_params_f32(dst, 2);
+ const float limit = ggml_get_op_params_f32(dst, 3);
const int32_t i00 = swp ? ne0 : 0;
const int32_t i10 = swp ? 0 : ne0;
/*.nb1 =*/ nb1,
/*.i00 =*/ src1 ? 0 : i00,
/*.i10 =*/ src1 ? 0 : i10,
+ /*.alpha=*/ alpha,
+ /*.limit=*/ limit
};
[encoder setComputePipelineState:pipeline];
} else {
[encoder setBuffer:h_src0 offset:offs_src0 atIndex:1];
}
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBytes:&args length:sizeof(args) atIndex:3];
+ if (id_src2) {
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
+ } else {
+ [encoder setBuffer:h_src0 offset:offs_src0 atIndex:2];
+ }
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
+ [encoder setBytes:&args length:sizeof(args) atIndex:4];
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
src0t == GGML_TYPE_Q5_0 ||
src0t == GGML_TYPE_Q5_1 ||
src0t == GGML_TYPE_Q8_0 ||
+ src0t == GGML_TYPE_MXFP4 ||
src0t == GGML_TYPE_IQ4_NL ||
false) && (ne11 >= 2 && ne11 <= 8)
) ||
case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5].pipeline; break;
default: GGML_ABORT("not implemented");
} break;
+ case GGML_TYPE_MXFP4:
+ switch (r1ptg) {
+ case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_2].pipeline; break;
+ case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_3].pipeline; break;
+ case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_4].pipeline; break;
+ case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_5].pipeline; break;
+ default: GGML_ABORT("not implemented");
+ } break;
case GGML_TYPE_Q4_K:
switch (r1ptg) {
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2].pipeline; break;
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break;
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break;
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break;
+ case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32 ].pipeline; break;
case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break;
case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break;
case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break;
nr0 = N_R0_Q8_0;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline;
} break;
+ case GGML_TYPE_MXFP4:
+ {
+ nsg = N_SG_MXFP4;
+ nr0 = N_R0_MXFP4;
+ smem = 32*sizeof(float);
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32].pipeline;
+ } break;
case GGML_TYPE_Q2_K:
{
nsg = N_SG_Q2_K;
case GGML_OP_MUL_MAT_ID:
{
// src2 = ids
- const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t);
-
GGML_ASSERT(src2t == GGML_TYPE_I32);
GGML_ASSERT(!ggml_is_transposed(src0));
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16 ].pipeline; break;
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16 ].pipeline; break;
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16 ].pipeline; break;
+ case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F16 ].pipeline; break;
case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16 ].pipeline; break;
case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16 ].pipeline; break;
case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16 ].pipeline; break;
nr0 = N_R0_Q8_0;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline;
} break;
+ case GGML_TYPE_MXFP4:
+ {
+ nsg = N_SG_MXFP4;
+ nr0 = N_R0_MXFP4;
+ smem = 32*sizeof(float);
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32].pipeline;
+ } break;
case GGML_TYPE_Q2_K:
{
nsg = N_SG_Q2_K;
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline; break;
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1 ].pipeline; break;
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0 ].pipeline; break;
+ case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_MXFP4 ].pipeline; break;
case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K ].pipeline; break;
case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K ].pipeline; break;
case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K ].pipeline; break;
GGML_ASSERT(ne11 == ne21);
GGML_ASSERT(ne12 == ne22);
- struct ggml_tensor * src3 = node->src[3];
+ struct ggml_tensor * src3 = node->src[3]; // mask
+ struct ggml_tensor * src4 = node->src[4]; // sinks
size_t offs_src3 = 0;
+ size_t offs_src4 = 0;
id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
+ id<MTLBuffer> id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil;
GGML_ASSERT(!src3 || src3->type == GGML_TYPE_F16);
GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) &&
const uint64_t nb32 = src3 ? src3->nb[2] : 0; GGML_UNUSED(nb32);
const uint64_t nb33 = src3 ? src3->nb[3] : 0; GGML_UNUSED(nb33);
- const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
-
float scale;
float max_bias;
float logit_softcap;
} else {
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:4];
}
- [encoder setBuffer:id_dst offset:offs_dst atIndex:5];
+ if (id_src4) {
+ [encoder setBuffer:id_src4 offset:offs_src4 atIndex:5];
+ } else {
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:5];
+ }
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:6];
if (!use_vec_kernel) {
// half8x8 kernel
-127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
};
+constexpr constant static float kvalues_mxfp4_f[16] = {
+ 0, .5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f, -0, -.5f, -1.f, -1.5f, -2.f, -3.f, -4.f, -6.f
+};
+
static inline int best_index_int8(int n, constant float * val, float x) {
if (x <= val[0]) return 0;
if (x >= val[n-1]) return n-1;
return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
}
+static inline float e8m0_to_fp32(uint8_t x) {
+ uint32_t bits;
+
+ if (x == 0) {
+ bits = 0x00400000;
+ } else {
+ bits = (uint32_t) x << 23;
+ }
+
+ return as_type<float>(bits);
+}
+
// NOTE: this is not dequantizing - we are simply fitting the template
template <typename type4x4>
void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
}
}
+void quantize_q8_0(device const float * src, device block_q8_0 & dst) {
+#pragma METAL fp math_mode(safe)
+ float amax = 0.0f; // absolute max
+
+ for (int j = 0; j < QK8_0; j++) {
+ const float v = src[j];
+ amax = MAX(amax, fabs(v));
+ }
+
+ const float d = amax / ((1 << 7) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+
+ dst.d = d;
+
+ for (int j = 0; j < QK8_0; ++j) {
+ const float x0 = src[j]*id;
+
+ dst.qs[j] = round(x0);
+ }
+}
+
void quantize_iq4_nl(device const float * src, device block_iq4_nl & dst) {
#pragma METAL fp math_mode(safe)
float amax = 0.0f; // absolute max
}
}
-void quantize_q8_0(device const float * src, device block_q8_0 & dst) {
-#pragma METAL fp math_mode(safe)
- float amax = 0.0f; // absolute max
+template <typename type4x4>
+void dequantize_mxfp4(device const block_mxfp4 * xb, short il, thread type4x4 & reg) {
+ device const uint8_t * q2 = (device const uint8_t *)xb->qs;
- for (int j = 0; j < QK8_0; j++) {
- const float v = src[j];
- amax = MAX(amax, fabs(v));
+ const float d = e8m0_to_fp32(xb->e);
+ const uint8_t shr = il >= 1 ? 4 : 0;
+
+ for (int i = 0; i < 4; ++i) {
+ reg[i][0] = d * kvalues_mxfp4_f[(q2[4*i + 0] >> shr) & 0x0F];
+ reg[i][1] = d * kvalues_mxfp4_f[(q2[4*i + 1] >> shr) & 0x0F];
+ reg[i][2] = d * kvalues_mxfp4_f[(q2[4*i + 2] >> shr) & 0x0F];
+ reg[i][3] = d * kvalues_mxfp4_f[(q2[4*i + 3] >> shr) & 0x0F];
}
+}
- const float d = amax / ((1 << 7) - 1);
- const float id = d ? 1.0f/d : 0.0f;
+template <typename type4>
+void dequantize_mxfp4_t4(device const block_mxfp4 * xb, short il, thread type4 & reg) {
+ device const uint8_t * q2 = (device const uint8_t *)xb->qs;
- dst.d = d;
+ const float d = e8m0_to_fp32(xb->e);
+ const short il4 = il%4;
- for (int j = 0; j < QK8_0; ++j) {
- const float x0 = src[j]*id;
+ const uint8_t shr = il >= 4 ? 4 : 0;
- dst.qs[j] = round(x0);
- }
+ reg[0] = d * kvalues_mxfp4_f[(q2[4*il4 + 0] >> shr) & 0x0F];
+ reg[1] = d * kvalues_mxfp4_f[(q2[4*il4 + 1] >> shr) & 0x0F];
+ reg[2] = d * kvalues_mxfp4_f[(q2[4*il4 + 2] >> shr) & 0x0F];
+ reg[3] = d * kvalues_mxfp4_f[(q2[4*il4 + 3] >> shr) & 0x0F];
}
template <typename type4x4>
}
}
+kernel void kernel_add_id(
+ constant ggml_metal_kargs_add_id & args,
+ device const char * src0,
+ device const char * src1,
+ device const char * src2,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort3 tpitg[[thread_position_in_threadgroup]],
+ ushort3 ntg[[threads_per_threadgroup]]) {
+ const int i1 = tgpig.x;
+ const int i2 = tgpig.y;
+
+ const int i11 = *((device const int32_t *) (src2 + i1*sizeof(int32_t) + i2*args.nb21));
+
+ const size_t nb1 = args.ne0 * sizeof(float);
+ const size_t nb2 = args.ne1 * nb1;
+
+ device float * dst_row = (device float *)((device char *)dst + i1*nb1 + i2*nb2);
+ device const float * src0_row = (device const float *)((device char *)src0 + i1*args.nb01 + i2*args.nb02);
+ device const float * src1_row = (device const float *)((device char *)src1 + i11*args.nb11);
+
+ for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+ dst_row[i0] = src0_row[i0] + src1_row[i0];
+ }
+}
+
template<typename T>
kernel void kernel_repeat(
constant ggml_metal_kargs_repeat & args,
}
}
+kernel void kernel_swiglu_oai(
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ constant ggml_metal_kargs_glu & args,
+ uint tgpig[[threadgroup_position_in_grid]],
+ uint tpitg[[thread_position_in_threadgroup]],
+ uint ntg[[threads_per_threadgroup]]) {
+ device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
+ device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
+ device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
+
+ for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
+ float x0 = src0_row[i0];
+ float x1 = src1_row[i0];
+
+ x0 = min(x0, args.limit);
+ x1 = max(min(x1, args.limit), -args.limit);
+
+ float out_glu = x0 / (1.0f + exp(-x0 * args.alpha));
+ out_glu = out_glu * (1.0f + x1);
+
+ dst_row[i0] = out_glu;
+ }
+}
+
kernel void kernel_geglu_erf(
device const char * src0,
device const char * src1,
kernel void kernel_soft_max(
device const char * src0,
device const char * src1,
+ device const char * src2,
device char * dst,
constant ggml_metal_kargs_soft_max & args,
threadgroup float * buf [[threadgroup(0)]],
device const float * psrc0 = (device const float *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
device const T * pmask = src1 != src0 ? (device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
+ device const float * psrc2 = src2 != src0 ? (device const float *) (src2) : nullptr;
device float * pdst = (device float *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3);
float slope = 1.0f;
}
// parallel max
- float lmax = -INFINITY;
+ float lmax = psrc2 ? psrc2[i02] : -INFINITY;
for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
lmax = MAX(lmax, psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f));
sum = simd_sum(sum);
}
+ if (psrc2) {
+ sum += exp(psrc2[i02] - max_val);
+ }
+
const float inv_sum = 1.0f/sum;
for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
kernel void kernel_soft_max_4(
device const char * src0,
device const char * src1,
+ device const char * src2,
device char * dst,
constant ggml_metal_kargs_soft_max & args,
threadgroup float * buf [[threadgroup(0)]],
device const float4 * psrc4 = (device const float4 *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
device const T * pmask = src1 != src0 ? (device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
+ device const float * psrc2 = src2 != src0 ? (device const float * ) (src2) : nullptr;
device float4 * pdst4 = (device float4 *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3);
float slope = 1.0f;
}
// parallel max
- float4 lmax4 = -INFINITY;
+ float4 lmax4 = psrc2 ? psrc2[i02] : -INFINITY;
for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
lmax4 = fmax(lmax4, psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
sum = simd_sum(sum);
}
+ if (psrc2) {
+ sum += exp(psrc2[i02] - max_val);
+ }
+
const float inv_sum = 1.0f/sum;
for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q8_0, 32, dequantize_q8_0_t4>;
template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q8_0, 32, dequantize_q8_0_t4>;
+template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_mxfp4, 32, dequantize_mxfp4_t4>;
+template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_mxfp4, 32, dequantize_mxfp4_t4>;
+template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_mxfp4, 32, dequantize_mxfp4_t4>;
+template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_mxfp4, 32, dequantize_mxfp4_t4>;
+
template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
device const char * k,
device const char * v,
device const char * mask,
+ device const char * sinks,
device char * dst,
threadgroup half * shmem_f16 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
}
}
+ if (sinks != q && sgitg == 0) {
+ for (ushort j = 0; j < Q; ++j) {
+ const float m = M[j];
+ const float s = tiisg == 0 ? ((device const float *) sinks)[iq2] : -FLT_MAX/2;
+
+ M[j] = simd_max(max(M[j], s));
+
+ const float ms = exp(m - M[j]);
+ const float vs = exp(s - M[j]);
+
+ S[j] = S[j]*ms + simd_sum(vs);
+
+ if (tiisg == j) {
+ ss[j*TS + 2*C + j] = ms;
+ }
+ }
+
+ // O = diag(ms)*O
+ {
+ s8x8_t ms;
+ simdgroup_load(ms, ss + 2*C, TS, 0, false);
+
+ #pragma unroll(DV8)
+ for (short i = 0; i < DV8; ++i) {
+ simdgroup_multiply(lo[i], ms, lo[i]);
+ }
+ }
+ }
+
// these are needed for reducing the results from the simdgroups (reuse the ss buffer)
for (short j = tiisg; j < Q; j += NW) {
ss[j*TS + 0] = S[j];
device const char * k,
device const char * v,
device const char * mask,
+ device const char * sinks,
device char * dst,
threadgroup half * shmem_f16 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
}
}
+ if (sinks != q && sgitg == 0) {
+ const float m = M;
+ const float s = tiisg == 0 ? ((device const float *) sinks)[iq2] : -FLT_MAX/2;
+
+ M = simd_max(max(M, s));
+
+ const float ms = exp(m - M);
+ const float vs = exp(s - M);
+
+ S = S*ms + simd_sum(vs);
+
+#pragma unroll(DV4/NL)
+ for (short ii = 0; ii < DV4; ii += NL) {
+ lo[ii/NL] *= ms;
+ }
+ }
+
// these are needed for reducing the results from the simdgroups (reuse the ss buffer)
if (tiisg == 0) {
ss[0] = (s_t) S;
kernel_mul_mv_iq4_xs_f32_impl<N_R0_IQ4_XS, N_SG_IQ4_XS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
+template<int nr0, int nsg, int nw, typename args_t>
+void kernel_mul_mv_mxfp4_f32_impl(
+ args_t args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ threadgroup char * shmem,
+ uint3 tgpig,
+ ushort tiisg,
+ ushort sgitg) {
+
+ threadgroup float * shmem_f32 = (threadgroup float *) shmem;
+ const int nb = args.ne00/QK_MXFP4;
+
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
+ const int im = tgpig.z;
+
+ const int first_row = (r0 * nsg + sgitg) * nr0;
+
+ const uint i12 = im%args.ne12;
+ const uint i13 = im/args.ne12;
+
+ const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+ const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
+
+ device const block_mxfp4 * x = (device const block_mxfp4 *) (src0 + offset0);
+ device const float * y = (device const float *) (src1 + offset1);
+
+ const short ix = tiisg/2; // 0...15
+ const short it = tiisg%2; // 0 or 1
+
+ shmem_f32[tiisg] = kvalues_mxfp4_f[tiisg%16];
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ float4 yl[4];
+ float sumf[nr0]={0.f};
+
+ device const float * yb = y + ix * QK_MXFP4 + it * 8;
+
+ for (int ib = ix; ib < nb; ib += 16) {
+ device const float4 * y4 = (device const float4 *)yb;
+ yl[0] = y4[0];
+ yl[1] = y4[4];
+ yl[2] = y4[1];
+ yl[3] = y4[5];
+
+#pragma unroll(nr0)
+ for (short row = 0; row < nr0; row++) {
+ device const block_mxfp4 & xb = x[row*nb + ib];
+ device const uint8_t * q2 = (device const uint8_t *)(xb.qs + 8*it);
+
+ float4 acc1 = yl[0]*float4(shmem_f32[q2[0] & 0x0F], shmem_f32[q2[1] & 0x0F], shmem_f32[q2[2] & 0x0F], shmem_f32[q2[3] & 0x0F]);
+ float4 acc2 = yl[1]*float4(shmem_f32[q2[0] >> 4 ], shmem_f32[q2[1] >> 4 ], shmem_f32[q2[2] >> 4 ], shmem_f32[q2[3] >> 4 ]);
+ float4 acc3 = yl[2]*float4(shmem_f32[q2[4] & 0x0F], shmem_f32[q2[5] & 0x0F], shmem_f32[q2[6] & 0x0F], shmem_f32[q2[7] & 0x0F]);
+ float4 acc4 = yl[3]*float4(shmem_f32[q2[4] >> 4 ], shmem_f32[q2[5] >> 4 ], shmem_f32[q2[6] >> 4 ], shmem_f32[q2[7] >> 4 ]);
+
+ acc1 = (acc1 + acc3) + (acc2 + acc4);
+
+ sumf[row] += e8m0_to_fp32(xb.e) * ((acc1[0] + acc1[1]) + (acc1[2] + acc1[3]));
+ }
+
+ yb += 16 * QK_MXFP4;
+ }
+
+ device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
+ for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+ float sum_all = simd_sum(sumf[row]);
+ if (tiisg == 0) {
+ dst_f32[first_row + row] = sum_all;
+ }
+ }
+}
+
+[[host_name("kernel_mul_mv_mxfp4_f32")]]
+kernel void kernel_mul_mv_mxfp4_f32(
+ constant ggml_metal_kargs_mul_mv & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ threadgroup char * shmem [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ kernel_mul_mv_mxfp4_f32_impl<N_R0_MXFP4, N_SG_MXFP4, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+}
+
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
kernel void kernel_get_rows_q(
constant ggml_metal_kargs_get_rows & args,
template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_0, 2, dequantize_q5_0>;
template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_1, 2, dequantize_q5_1>;
template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q8_0, 2, dequantize_q8_0>;
+template [[host_name("kernel_get_rows_mxfp4")]] kernel get_rows_q_t kernel_get_rows_q<block_mxfp4, 2, dequantize_mxfp4>;
template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q2_K, QK_NL, dequantize_q2_K>;
template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q3_K, QK_NL, dequantize_q3_K>;
template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_K, QK_NL, dequantize_q4_K>;
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1>;
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0>;
+template [[host_name("kernel_mul_mm_mxfp4_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4>;
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K>;
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K>;
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K>;
template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
template [[host_name("kernel_mul_mm_id_q5_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1>;
template [[host_name("kernel_mul_mm_id_q8_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0>;
+template [[host_name("kernel_mul_mm_id_mxfp4_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4>;
template [[host_name("kernel_mul_mm_id_q2_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K>;
template [[host_name("kernel_mul_mm_id_q3_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K>;
template [[host_name("kernel_mul_mm_id_q4_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K>;
template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0, N_SG_Q5_0, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1, N_SG_Q5_1, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_mxfp4_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_mxfp4_f32_impl<N_R0_MXFP4, N_SG_MXFP4, N_SIMDWIDTH>>>;
+
template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl <N_R0_Q2_K, N_SG_Q2_K, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl <N_R0_Q3_K, N_SG_Q3_K, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl <N_R0_Q4_K, N_SG_Q4_K, N_SIMDWIDTH>>>;
case GGML_OP_CLAMP:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_SOFT_MAX:
+ // TODO: support attention sinks [TAG_ATTN_SINKS]
+ return op->src[2] == nullptr;
case GGML_OP_NORM:
case GGML_OP_RMS_NORM:
return true;
#define UNUSED GGML_UNUSED
+static inline int best_index_int8(int n, const int8_t * val, float x) {
+ if (x <= val[0]) return 0;
+ if (x >= val[n-1]) return n-1;
+ int ml = 0, mu = n-1;
+ while (mu-ml > 1) {
+ int mav = (ml+mu)/2;
+ if (x < val[mav]) mu = mav; else ml = mav;
+ }
+ return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
+}
+
// reference implementation for deterministic creation of model files
void quantize_row_q4_0_ref(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k) {
static const int qk = QK4_0;
}
}
+static inline int best_index_mxfp4(float x, float e) {
+ int best_index = 0;
+ float best_err = fabsf(kvalues_mxfp4[0]*e - x);
+ for (int i = 1; i < 16; i++) {
+ float err = fabsf(kvalues_mxfp4[i]*e - x);
+ if (err < best_err) {
+ best_index = i;
+ best_err = err;
+ }
+ }
+ return best_index;
+}
+
+void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k) {
+ static const int qk = QK_MXFP4;
+
+ assert(k % qk == 0);
+
+ const int nb = k / qk;
+
+ for (int i = 0; i < nb; i++) {
+ float amax = 0.0f; // absolute max
+
+ for (int j = 0; j < qk; j++) {
+ const float v = x[i*qk + j];
+
+ if (amax < fabsf(v)) {
+ amax = fabsf(v);
+ }
+ }
+
+ const uint8_t e = (uint8_t) (floorf(log2f(amax)) - 2 + 127);
+
+ const float d = GGML_E8M0_TO_FP32_HALF(e);
+
+ y[i].e = e;
+
+ for (int j = 0; j < qk/2; ++j) {
+ const uint8_t x0 = best_index_mxfp4(x[i*qk + 0 + j], d);
+ const uint8_t x1 = best_index_mxfp4(x[i*qk + qk/2 + j], d);
+
+ y[i].qs[j] = x0;
+ y[i].qs[j] |= x1 << 4;
+ }
+ }
+}
+
void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
static const int qk = QK4_0;
}
}
+void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
+ static const int qk = QK_MXFP4;
+
+ assert(k % qk == 0);
+
+ const int nb = k / qk;
+
+ for (int i = 0; i < nb; i++) {
+ const float d = GGML_E8M0_TO_FP32_HALF(x[i].e);
+
+ for (int j = 0; j < qk/2; ++j) {
+ const int8_t x0 = kvalues_mxfp4[x[i].qs[j] & 0x0F];
+ const int8_t x1 = kvalues_mxfp4[x[i].qs[j] >> 4];
+
+ y[i*qk + j + 0 ] = x0*d;
+ y[i*qk + j + qk/2] = x1*d;
+ }
+ }
+}
+
//
// 2-6 bit quantization in super-blocks
//
return nrow * row_size;
}
+size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+ GGML_UNUSED(quant_weights);
+ quantize_row_mxfp4_ref(src, dst, (int64_t)nrow*n_per_row);
+ return nrow * ggml_row_size(GGML_TYPE_MXFP4, n_per_row);
+}
+
// ====================== Ternary (de)-quantization (BitNet b1.58 and TriLMs)
void quantize_row_tq1_0_ref(const float * GGML_RESTRICT x, block_tq1_0 * GGML_RESTRICT y, int64_t k) {
// ============================ 4-bit non-linear quants
-static inline int best_index_int8(int n, const int8_t * val, float x) {
- if (x <= val[0]) return 0;
- if (x >= val[n-1]) return n-1;
- int ml = 0, mu = n-1;
- while (mu-ml > 1) {
- int mav = (ml+mu)/2;
- if (x < val[mav]) mu = mav; else ml = mav;
- }
- return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
-}
-
static void quantize_row_iq4_nl_impl(const int super_block_size, const int block_size, const float * GGML_RESTRICT x,
ggml_fp16_t * dh, uint8_t * q4, uint16_t * scales_h, uint8_t * scales_l,
float * scales, float * weight, uint8_t * L,
return true;
}
+static bool validate_e_e8m0(uint8_t e, size_t i) {
+ if (e == 0xff) {
+ fprintf(stderr, "ggml_validate_row_data: found invalid e value %d at block %zu\n", e, i);
+ return false;
+ }
+
+ return true;
+}
+
#define VALIDATE_ROW_DATA_D_F16_IMPL(type, data, nb) \
const type * q = (const type *) (data); \
for (size_t i = 0; i < (nb); ++i) { \
} \
}
+#define VALIDATE_ROW_DATA_E_E8M0_IMPL(type, data, nb) \
+ const type * q = (const type *) (data); \
+ for (size_t i = 0; i < (nb); ++i) { \
+ if (!validate_e_e8m0(q[i].e, i)) { \
+ return false; \
+ } \
+ }
+
#define VALIDATE_ROW_DATA_DVEC_F16_IMPL(type, data, nb, nr) \
const type * q = (const type *) (data); \
for (size_t i = 0; i < (nb); ++i) { \
{
VALIDATE_ROW_DATA_D_F16_IMPL(block_q8_0, data, nb);
} break;
+ case GGML_TYPE_MXFP4:
+ {
+ VALIDATE_ROW_DATA_E_E8M0_IMPL(block_mxfp4, data, nb);
+ } break;
case GGML_TYPE_Q2_K:
{
VALIDATE_ROW_DATA_DM_F16_IMPL(block_q2_K, data, nb, d, dmin);
GGML_API void quantize_row_q8_0_ref(const float * GGML_RESTRICT x, block_q8_0 * GGML_RESTRICT y, int64_t k);
GGML_API void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 * GGML_RESTRICT y, int64_t k);
+GGML_API void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k);
+
GGML_API void quantize_row_q2_K_ref(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k);
GGML_API void quantize_row_q3_K_ref(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t k);
GGML_API void quantize_row_q4_K_ref(const float * GGML_RESTRICT x, block_q4_K * GGML_RESTRICT y, int64_t k);
GGML_API void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
//GGML_API void dequantize_row_q8_1(const block_q8_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+GGML_API void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+
GGML_API void dequantize_row_q2_K(const block_q2_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
GGML_API void dequantize_row_q3_K(const block_q3_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
GGML_API void dequantize_row_q4_K(const block_q4_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
GGML_API size_t quantize_q5_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
GGML_API size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+GGML_API size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+
GGML_API void iq2xs_init_impl(enum ggml_type type);
GGML_API void iq2xs_free_impl(enum ggml_type type);
GGML_API void iq3xs_init_impl(int grid_size);
case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID:
{
- struct ggml_tensor * a;
- struct ggml_tensor * b;
- if (op->op == GGML_OP_MUL_MAT) {
- a = op->src[0];
- b = op->src[1];
- } else {
- a = op->src[2];
- b = op->src[1];
- }
+ struct ggml_tensor * a = op->src[0];
+ struct ggml_tensor * b = op->src[1];
+
if (a->ne[3] != b->ne[3]) {
return false;
}
}
}
ggml_type src0_type = op->src[0]->type;
- if (src0_type == GGML_TYPE_BF16) {
+ if (src0_type == GGML_TYPE_BF16 || src0_type == GGML_TYPE_MXFP4) {
+ // TODO: support MXFP4
+ // FIXME: keep a list of supported types to avoid breaking the backend when a new type is added
return false;
}
return true;
if (op->src[0]->ne[3] != 1) {
return false;
}
+ // TODO: support attention sinks [TAG_ATTN_SINKS]
+ if (op->src[2]) {
+ return false;
+ }
// TODO: support broadcast
// ref: https://github.com/ggml-org/llama.cpp/pull/14435
return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1);
vk_pipeline pipeline_div[2][2][2];
vk_pipeline pipeline_div_norepeat[2][2][2];
+ vk_pipeline pipeline_add_id_f32;
+
vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bilinear_ac_f32;
vk_pipeline pipeline_scale_f32;
vk_pipeline pipeline_geglu[2];
vk_pipeline pipeline_reglu[2];
vk_pipeline pipeline_swiglu[2];
+ vk_pipeline pipeline_swiglu_oai[2];
vk_pipeline pipeline_geglu_erf[2];
vk_pipeline pipeline_geglu_quick[2];
uint32_t ne00;
uint32_t ne20;
uint32_t mode; // 0: default, 1: swapped, 2: split
+ float alpha; // for swiglu_oai
+ float limit;
};
struct vk_op_unary_push_constants {
float param1; float param2; int32_t param3;
};
+struct vk_op_add_id_push_constants {
+ uint32_t ne0;
+ uint32_t ne1;
+ uint32_t s01;
+ uint32_t s02;
+ uint32_t s11;
+ uint32_t s21;
+};
+
struct vk_op_diag_mask_push_constants {
uint32_t ncols;
uint32_t rows_per_channel;
float m1;
uint32_t n_head_log2;
uint32_t nrows_x;
+ uint32_t has_sinks;
};
struct vk_op_argsort_push_constants {
break;
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
+ case GGML_TYPE_MXFP4:
lut_size = 4*16;
break;
default:
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_S], matmul_iq3_s_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_MXFP4], matmul_mxfp4_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f16acc, matmul_id_mxfp4_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
#undef CREATE_MM
#undef CREATE_MM2
} else
CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S], matmul_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
} else {
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
}
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f16acc, matmul_id_mxfp4_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
} else {
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f16acc, matmul_id_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
}
#undef CREATE_MM2
#undef CREATE_MM
CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S], matmul_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (device->integer_dot_product) {
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f16acc, matmul_id_mxfp4_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
#undef CREATE_MM2
#undef CREATE_MMQ
#undef CREATE_MM
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (device->integer_dot_product) {
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
}
// reusing CREATE_MM from the fp32 path
if ((device->coopmat2 || device->coopmat_support)
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f32_f32_len, mul_mat_vec_iq3_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f32_f32_len, mul_mat_vec_iq4_xs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f32_f32_"+std::to_string(i+1), mul_mat_vec_mxfp4_f32_f32_len, mul_mat_vec_mxfp4_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32_"+std::to_string(i+1), mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32_"+std::to_string(i+1), mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f16_f32_len, mul_mat_vec_iq3_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f16_f32_len, mul_mat_vec_iq4_xs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f16_f32_"+std::to_string(i+1), mul_mat_vec_mxfp4_f16_f32_len, mul_mat_vec_mxfp4_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
}
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_S], "mul_mat_vec_id_iq3_s_f32", mul_mat_vec_id_iq3_s_f32_len, mul_mat_vec_id_iq3_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32", mul_mat_vec_id_iq4_xs_f32_len, mul_mat_vec_id_iq4_xs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_f32", mul_mat_vec_id_mxfp4_f32_len, mul_mat_vec_id_mxfp4_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
// dequant shaders
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ3_S], "dequant_iq3_s", dequant_iq3_s_len, dequant_iq3_s_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_XS], "dequant_iq4_xs", dequant_iq4_xs_len, dequant_iq4_xs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_NL], "dequant_iq4_nl", dequant_iq4_nl_len, dequant_iq4_nl_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_MXFP4], "dequant_mxfp4", dequant_mxfp4_len, dequant_mxfp4_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
// get_rows
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32", get_rows_f32_len, get_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ3_S], "get_rows_iq3_s", get_rows_iq3_s_len, get_rows_iq3_s_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs", get_rows_iq4_xs_len, get_rows_iq4_xs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl", get_rows_iq4_nl_len, get_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_MXFP4], "get_rows_mxfp4", get_rows_mxfp4_len, get_rows_mxfp4_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ3_S], "get_rows_iq3_s_f32", get_rows_iq3_s_f32_len, get_rows_iq3_s_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs_f32", get_rows_iq4_xs_f32_len, get_rows_iq4_xs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_MXFP4], "get_rows_mxfp4_f32", get_rows_mxfp4_f32_len, get_rows_mxfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 4 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
CREATE_BINARY(div, _norepeat, {1})
#undef CREATE_BINARY
+ ggml_vk_create_pipeline(device, device->pipeline_add_id_f32, "add_id_f32", add_id_f32_len, add_id_f32_data, "main", 4, sizeof(vk_op_add_id_push_constants), {1, 1, 1}, {}, 1);
+
ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
CREATE_GLU(geglu)
CREATE_GLU(reglu)
CREATE_GLU(swiglu)
+ CREATE_GLU(swiglu_oai)
CREATE_GLU(geglu_erf)
CREATE_GLU(geglu_quick)
#undef CREATE_GLU
ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
- ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
- ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
- ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_NL:
+ case GGML_TYPE_MXFP4:
break;
default:
return nullptr;
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_NL:
+ case GGML_TYPE_MXFP4:
break;
default:
return nullptr;
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_NL:
+ case GGML_TYPE_MXFP4:
break;
default:
return nullptr;
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_NL:
+ case GGML_TYPE_MXFP4:
break;
default:
return nullptr;
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_NL:
+ case GGML_TYPE_MXFP4:
break;
default:
return nullptr;
std::cerr << "}, (" << wg0 << "," << wg1 << "," << wg2 << "))");
GGML_ASSERT(ctx->descriptor_set_idx < ctx->descriptor_sets.size());
GGML_ASSERT(descriptor_buffer_infos.size() <= MAX_PARAMETER_COUNT);
+ GGML_ASSERT(pipeline->parameter_count == descriptor_buffer_infos.size());
vk::DescriptorSet& descriptor_set = ctx->descriptor_sets[ctx->descriptor_set_idx++];
vk::WriteDescriptorSet write_descriptor_set{ descriptor_set, 0, 0, pipeline->parameter_count, vk::DescriptorType::eStorageBuffer, nullptr, descriptor_buffer_infos.begin() };
break;
}
return nullptr;
+ case GGML_OP_ADD_ID:
+ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && src2->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_F32) {
+ return ctx->device->pipeline_add_id_f32;
+ }
+ return nullptr;
case GGML_OP_CONCAT:
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_concat_f32;
return ctx->device->pipeline_reglu[dst->type == GGML_TYPE_F16];
case GGML_GLU_OP_SWIGLU:
return ctx->device->pipeline_swiglu[dst->type == GGML_TYPE_F16];
+ case GGML_GLU_OP_SWIGLU_OAI:
+ return ctx->device->pipeline_swiglu_oai[dst->type == GGML_TYPE_F16];
case GGML_GLU_OP_GEGLU_ERF:
return ctx->device->pipeline_geglu_erf[dst->type == GGML_TYPE_F16];
case GGML_GLU_OP_GEGLU_QUICK:
return nullptr;
case GGML_OP_SOFT_MAX:
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
+ GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32);
if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_wg512 : ctx->device->pipeline_soft_max_f32;
case GGML_OP_SUB:
case GGML_OP_MUL:
case GGML_OP_DIV:
+ case GGML_OP_ADD_ID:
case GGML_OP_CONCAT:
case GGML_OP_UPSCALE:
case GGML_OP_SQR:
elements = { ne, 1, 1 };
}
} break;
+ case GGML_OP_ADD_ID:
+ {
+ elements = { (uint32_t)ne01, (uint32_t)ne02, 1 };
+ } break;
case GGML_OP_SET_ROWS:
{
uint32_t ne = ggml_nelements(src0);
}
}
- if (op == GGML_OP_SOFT_MAX || op == GGML_OP_GLU) {
- // Empty src1 is possible in soft_max, but the shader needs a buffer
+ if (op == GGML_OP_GLU) {
+ // Empty src1 is possible in glu, but the shader needs a buffer
vk_subbuffer subbuf_y;
if (use_src1) {
subbuf_y = { d_Y, y_buf_offset, y_sz };
ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
+ } else if (op == GGML_OP_SOFT_MAX) {
+ // Empty src1 and src2 is possible in soft_max, but the shader needs a buffer
+ vk_subbuffer subbuf_y;
+ if (use_src1) {
+ subbuf_y = { d_Y, y_buf_offset, y_sz };
+ } else {
+ subbuf_y = { d_X, 0, x_sz };
+ }
+
+ vk_subbuffer subbuf_z;
+ if (use_src2) {
+ subbuf_z = { d_Z, z_buf_offset, z_sz };
+ } else {
+ subbuf_z = { d_X, 0, x_sz };
+ }
+
+ ggml_vk_sync_buffers(subctx);
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
} else if (op == GGML_OP_ROPE || op == GGML_OP_ROPE_BACK) {
// Empty src2 is possible in rope, but the shader needs a buffer
vk_subbuffer subbuf_z;
}, dryrun);
}
+static void ggml_vk_add_id(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
+ const uint32_t src0_type_size = ggml_type_size(src0->type);
+ const uint32_t src1_type_size = ggml_type_size(src1->type);
+ const uint32_t src2_type_size = ggml_type_size(src2->type);
+
+ ggml_vk_op_f32<vk_op_add_id_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_ADD_ID, {
+ (uint32_t)dst->ne[0],
+ (uint32_t)dst->ne[1],
+ (uint32_t)src0->nb[1] / src0_type_size,
+ (uint32_t)src0->nb[2] / src0_type_size,
+ (uint32_t)src1->nb[1] / src1_type_size,
+ (uint32_t)src2->nb[1] / src2_type_size,
+ }, dryrun);
+}
+
static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, int version, bool dryrun = false) {
GGML_ASSERT(version == 6 || version == 7);
int num_srcs = version == 6 ? 6 : 7;
}
static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+ const float * op_params_f = (const float *)dst->op_params;
+
const bool swapped = (bool)dst->op_params[1];
const bool split = src1 != nullptr;
+ const float alpha = op_params_f[2];
+ const float limit = op_params_f[3];
GGML_ASSERT(ggml_is_contiguous(src0));
const uint32_t mode = split ? 2 : (swapped ? 1 : 0);
- ggml_vk_op_f32<vk_op_glu_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GLU, { (uint32_t)ggml_nelements(dst), (uint32_t)src0->ne[0], (uint32_t)dst->ne[0], mode }, dryrun);
+ ggml_vk_op_f32<vk_op_glu_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GLU,
+ {
+ (uint32_t)ggml_nelements(dst),
+ (uint32_t)src0->ne[0],
+ (uint32_t)dst->ne[0],
+ mode,
+ alpha,
+ limit
+ }, dryrun);
}
static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
ggml_vk_op_f32<vk_op_diag_mask_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] }, dryrun);
}
-static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
float * op_params = (float *)dst->op_params;
float scale = op_params[0];
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
- ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX, {
+ ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_SOFT_MAX, {
ncols,
src1 != nullptr ? nrows_y : (uint32_t)0,
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],
m0, m1,
n_head_log2,
nrows_x,
+ src2 != nullptr
}, dryrun);
}
case GGML_GLU_OP_GEGLU:
case GGML_GLU_OP_REGLU:
case GGML_GLU_OP_SWIGLU:
+ case GGML_GLU_OP_SWIGLU_OAI:
case GGML_GLU_OP_GEGLU_ERF:
case GGML_GLU_OP_GEGLU_QUICK:
break;
case GGML_OP_REPEAT_BACK:
case GGML_OP_GET_ROWS:
case GGML_OP_ADD:
+ case GGML_OP_ADD_ID:
case GGML_OP_ACC:
case GGML_OP_SUB:
case GGML_OP_MUL:
case GGML_OP_DIV:
ggml_vk_div(ctx, compute_ctx, src0, src1, node, dryrun);
+ break;
+ case GGML_OP_ADD_ID:
+ ggml_vk_add_id(ctx, compute_ctx, src0, src1, src2, node, dryrun);
+
break;
case GGML_OP_CONCAT:
ggml_vk_concat(ctx, compute_ctx, src0, src1, node, dryrun);
case GGML_GLU_OP_GEGLU:
case GGML_GLU_OP_REGLU:
case GGML_GLU_OP_SWIGLU:
+ case GGML_GLU_OP_SWIGLU_OAI:
case GGML_GLU_OP_GEGLU_ERF:
case GGML_GLU_OP_GEGLU_QUICK:
ggml_vk_glu(ctx, compute_ctx, src0, src1, node, dryrun);
break;
case GGML_OP_SOFT_MAX:
- ggml_vk_soft_max(ctx, compute_ctx, src0, src1, node, dryrun);
+ ggml_vk_soft_max(ctx, compute_ctx, src0, src1, src2, node, dryrun);
break;
case GGML_OP_SOFT_MAX_BACK:
case GGML_OP_SUB:
case GGML_OP_MUL:
case GGML_OP_DIV:
+ case GGML_OP_ADD_ID:
case GGML_OP_CONCAT:
case GGML_OP_UPSCALE:
case GGML_OP_SCALE:
case GGML_GLU_OP_GEGLU:
case GGML_GLU_OP_REGLU:
case GGML_GLU_OP_SWIGLU:
+ case GGML_GLU_OP_SWIGLU_OAI:
case GGML_GLU_OP_GEGLU_ERF:
case GGML_GLU_OP_GEGLU_QUICK:
buf = tensor->buffer;
case GGML_GLU_OP_GEGLU:
case GGML_GLU_OP_REGLU:
case GGML_GLU_OP_SWIGLU:
+ case GGML_GLU_OP_SWIGLU_OAI:
case GGML_GLU_OP_GEGLU_ERF:
case GGML_GLU_OP_GEGLU_QUICK:
return ggml_is_contiguous(op->src[0]) &&
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_NL:
+ case GGML_TYPE_MXFP4:
break;
default:
return false;
if (head_sizes == FA_HEAD_SIZE_UNSUPPORTED) {
return false;
}
+ // TODO: support attention sinks [TAG_ATTN_SINKS]
+ if (op->src[4]) {
+ return false;
+ }
if (op->src[0]->type != GGML_TYPE_F32) {
return false;
}
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_NL:
+ case GGML_TYPE_MXFP4:
return true;
default:
return false;
return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
(op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16) &&
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
+ case GGML_OP_ADD_ID:
+ return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->src[2]->type == GGML_TYPE_I32 &&
+ op->type == GGML_TYPE_F32;
case GGML_OP_SILU_BACK:
case GGML_OP_RMS_NORM_BACK:
case GGML_OP_SQR:
--- /dev/null
+#version 450
+
+#extension GL_EXT_control_flow_attributes : require
+
+#include "types.comp"
+
+layout (push_constant) uniform parameter
+{
+ uint ne0;
+ uint ne1;
+ uint s01;
+ uint s02;
+ uint s11;
+ uint s21;
+} p;
+
+#define BLOCK_SIZE 512
+
+layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) readonly buffer Y {B_TYPE data_b[];};
+layout (binding = 2) readonly buffer Z {int32_t data_c[];};
+layout (binding = 3) writeonly buffer D {D_TYPE data_d[];};
+
+void main() {
+ const uint i1 = gl_WorkGroupID.x;
+ const uint i2 = gl_WorkGroupID.y;
+
+ const uint i11 = data_c[i1 + i2 * p.s21];
+
+ const uint s1 = p.ne0;
+ const uint s2 = p.ne0 * p.ne1;
+
+ const uint d0 = i1 * s1 + i2 * s2;
+ const uint a0 = i1 * p.s01 + i2 * p.s02;
+ const uint b0 = i11 * p.s11;
+
+ for (uint i0 = gl_LocalInvocationID.x; i0 < p.ne0; i0 += BLOCK_SIZE) {
+ data_d[d0 + i0] = data_a[a0 + i0] + data_b[b0 + i0];
+ }
+}
#include "generic_unary_head.comp"
#include "dequant_funcs.comp"
-#if defined(DATA_A_IQ4_NL)
-// 16 invocations needed for init_iq4nl_shmem
+#if defined(DATA_A_IQ4_NL) || defined(DATA_A_MXFP4)
+// 16 invocations needed for init_iq_shmem
layout(local_size_x = 16, local_size_y = 1, local_size_z = 1) in;
#else
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
}
#endif
+#if defined(DATA_A_MXFP4)
+vec2 dequantize(uint ib, uint iqs, uint a_offset) {
+ const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
+ return vec2(kvalues_mxfp4[vui & 0xF], kvalues_mxfp4[vui >> 4]);
+}
+vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
+ vec2 v0 = dequantize(ib, iqs, a_offset);
+ vec2 v1 = dequantize(ib, iqs + 1, a_offset);
+ return vec4(v0.x, v0.y, v1.x, v1.y);
+}
+#endif
+
#if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)
vec2 get_dm(uint ib, uint a_offset) {
return vec2(0, 0);
}
#endif
+#if defined(DATA_A_MXFP4)
+vec2 get_dm(uint ib, uint a_offset) {
+ return vec2(e8m0_to_fp32(data_a[a_offset + ib].e), 0);
+}
+#endif
+
#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
vec2 get_dm(uint ib, uint a_offset) {
return vec2(float(data_a[a_offset + ib].d), float(data_a[a_offset + ib].m));
}
#endif
+#if defined(DATA_A_MXFP4)
+layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufMXFP4 {
+ block_mxfp4 block;
+};
+
+float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
+{
+ const float d = e8m0_to_fp32(bl.block.e);
+ const uint idx = coordInBlock[1];
+ const uint iqs = idx & 0xF;
+ const uint shift = (idx & 0x10) >> 2;
+ uint32_t qs = bl.block.qs[iqs];
+ qs >>= shift;
+ qs &= 0xF;
+ float16_t ret = float16_t(kvalues_mxfp4[qs] * d);
+ return ret;
+}
+#endif
+
#if defined(DATA_A_Q4_0)
#define dequantFuncA dequantFuncQ4_0
#elif defined(DATA_A_Q4_1)
#define dequantFuncA dequantFuncIQ4_XS
#elif defined(DATA_A_IQ4_NL)
#define dequantFuncA dequantFuncIQ4_NL
+#elif defined(DATA_A_MXFP4)
+#define dequantFuncA dequantFuncMXFP4
#endif
--- /dev/null
+#version 450
+
+#include "dequant_head.comp"
+
+layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {block_mxfp4 data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
+
+void main() {
+ const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
+
+ init_iq_shmem(gl_WorkGroupSize);
+
+ const uint tid = gl_LocalInvocationID.x % 64;
+ const uint il = tid/32;
+ const uint ir = tid%32;
+ const uint ib = 32*i + ir;
+ if (ib >= p.nel / 32) {
+ return;
+ }
+
+ const uint q_idx = 8*il;
+ const uint b_idx = 1024*i + 32*ir + q_idx;
+
+ const float d = e8m0_to_fp32(data_a[ib].e);
+
+ [[unroll]] for (uint l = 0; l < 8; ++l) {
+ data_b[b_idx + l + 0] = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] & 0xF]);
+ data_b[b_idx + l + 16] = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] >> 4]);
+ }
+}
uint ne00;
uint ne20;
uint mode;
+ float alpha;
+ float limit;
} p;
buf_a[buf_idx + 1 ] = FLOAT_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 8, 4)]) * d;
buf_a[buf_idx + 16] = FLOAT_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 4, 4)]) * d;
buf_a[buf_idx + 17] = FLOAT_TYPE(kvalues_iq4nl[vui >> 12]) * d;
+#elif defined(DATA_A_MXFP4)
+ const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
+ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a;
+
+ const uint ib = idx / 8;
+ const uint iqs = (idx & 0x07) * 2;
+
+ const float d = e8m0_to_fp32(data_a[ib].e);
+ const uint vui = uint(data_a[ib].qs[iqs]);
+ const uint vui2 = uint(data_a[ib].qs[iqs+1]);
+
+ buf_a[buf_idx ] = FLOAT_TYPE(kvalues_mxfp4[vui & 0xF] * d);
+ buf_a[buf_idx + 16] = FLOAT_TYPE(kvalues_mxfp4[vui >> 4] * d);
+ buf_a[buf_idx + 1] = FLOAT_TYPE(kvalues_mxfp4[vui2 & 0xF] * d);
+ buf_a[buf_idx + 17] = FLOAT_TYPE(kvalues_mxfp4[vui2 >> 4] * d);
#endif
}
[[unroll]] for (uint l = 0; l < BN; l += loadstride_b) {
}
#endif
+#if defined(DATA_A_MXFP4)
+FLOAT_TYPE get_d(uint ib) {
+ return FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e));
+}
+#endif
+
#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
FLOAT_TYPE_VEC2 get_dm(uint ib) {
return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
float m1;
uint n_head_log2;
uint nrows_x;
+ uint has_sinks;
} p;
#include "types.comp"
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) readonly buffer Y {B_TYPE data_b[];};
-layout (binding = 2) buffer D {D_TYPE data_d[];};
+layout (binding = 2) readonly buffer Z {float data_c[];};
+layout (binding = 3) buffer D {D_TYPE data_d[];};
shared FLOAT_TYPE vals[BLOCK_SIZE];
const uint h = (rowx / p.ne01) % p.ne02; // head index
const float base = h < p.n_head_log2 ? p.m0 : p.m1;
- const uint exp = h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1;
+ const uint exp = h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1;
slope = pow(base, exp);
}
// Find max
- FLOAT_TYPE max_val = uintBitsToFloat(0xFF800000);
+ FLOAT_TYPE max_val = p.has_sinks == 0 ? uintBitsToFloat(0xFF800000) : data_c[i02];
// Cache values while we compute the max, so we don't need to read them
// again when we're ready to compute exp(x-max).
}
sum = vals[0];
+ if (p.has_sinks != 0) {
+ sum += FLOAT_TYPE(exp(FLOAT_TYPE(data_c[i02]) - max_val));
+ }
+
FLOAT_TYPE rcpdivisor = 1.0/sum;
[[unroll]] for (uint col0 = 0, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) {
--- /dev/null
+#version 450
+
+#include "glu_head.comp"
+
+float op(float a, float b) {
+ float xi = min(a, p.limit);
+ float gi = max(min(b, p.limit), -p.limit);
+
+ float out_glu = xi / (1.0f + exp(-xi * p.alpha));
+ out_glu = out_glu * (1.0f + gi);
+ return out_glu;
+}
+
+#include "glu_main.comp"
#define A_TYPE_PACKED16 block_iq4_nl_packed16
#endif
+#define QUANT_K_MXFP4 32
+#define QUANT_R_MXFP4 2
+
+struct block_mxfp4
+{
+ uint8_t e;
+ uint8_t qs[QUANT_K_MXFP4/2];
+};
+
+//struct block_mxfp4_packed16
+//{
+// uint8_t e;
+// uint16_t qs[QUANT_K_MXFP4/2/2];
+//};
+
+#if defined(DATA_A_MXFP4)
+#define QUANT_K QUANT_K_MXFP4
+#define QUANT_R QUANT_R_MXFP4
+#define QUANT_AUXF 1
+#define A_TYPE block_mxfp4
+//#define A_TYPE_PACKED16 block_mxfp4_packed16
+#endif
+
#if defined(DATA_A_IQ4_NL) || defined(DATA_A_IQ4_XS)
const int8_t kvalues_iq4nl_const[16] = {
int8_t(-127), int8_t(-104), int8_t(-83), int8_t(-65), int8_t(-49), int8_t(-35), int8_t(-22), int8_t(-10),
}
#endif
+#if defined(DATA_A_MXFP4)
+const FLOAT_TYPE kvalues_mxfp4_const[16] = {
+ FLOAT_TYPE(0.0f), FLOAT_TYPE(0.5f), FLOAT_TYPE(1.0f), FLOAT_TYPE(1.5f), FLOAT_TYPE(2.0f), FLOAT_TYPE(3.0f), FLOAT_TYPE(4.0f), FLOAT_TYPE(6.0f),
+ FLOAT_TYPE(-0.0f), FLOAT_TYPE(-0.5f), FLOAT_TYPE(-1.0f), FLOAT_TYPE(-1.5f), FLOAT_TYPE(-2.0f), FLOAT_TYPE(-3.0f), FLOAT_TYPE(-4.0f), FLOAT_TYPE(-6.0f)
+};
+
+shared FLOAT_TYPE kvalues_mxfp4[16];
+
+#define NEEDS_INIT_IQ_SHMEM
+void init_iq_shmem(uvec3 wgsize)
+{
+ // copy the table into shared memory and sync
+ for (uint i = gl_LocalInvocationIndex.x; i < kvalues_mxfp4.length(); i += wgsize.x) {
+ kvalues_mxfp4[i] = kvalues_mxfp4_const[i];
+ }
+ barrier();
+}
+#endif
+
// returns the bfloat value in the low 16b.
// See ggml_compute_fp32_to_bf16
uint32_t fp32_to_bf16(float f)
return uintBitsToFloat(u << 16);
}
+float e8m0_to_fp32(uint8_t x) {
+ uint32_t bits;
+
+ if (x == 0) {
+ bits = 0x00400000;
+ } else {
+ bits = x;
+ bits = bits << 23;
+ }
+
+ return uintBitsToFloat(bits);
+}
+
#endif // !defined(GGML_TYPES_COMP)
"iq3_s",
"iq4_xs",
"iq4_nl",
+ "mxfp4",
"bf16",
};
CloseHandle(pi.hProcess);
CloseHandle(pi.hThread);
#else
-int stdout_pipe[2];
+ int stdout_pipe[2];
int stderr_pipe[2];
if (pipe(stdout_pipe) != 0 || pipe(stderr_pipe) != 0) {
std::string load_vec_quant = "2";
if ((tname == "q4_0") || (tname == "q4_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s"))
load_vec_quant = "8";
- else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl"))
+ else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl") || (tname == "mxfp4"))
load_vec_quant = "4";
if (tname == "bf16") {
string_to_spv("reglu_f32" + suffix, "reglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
string_to_spv("swiglu_f16" + suffix, "swiglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
string_to_spv("swiglu_f32" + suffix, "swiglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
+ string_to_spv("swiglu_oai_f16" + suffix, "swiglu_oai.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
+ string_to_spv("swiglu_oai_f32" + suffix, "swiglu_oai.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
string_to_spv("geglu_erf_f16" + suffix, "geglu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
string_to_spv("geglu_erf_f32" + suffix, "geglu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
string_to_spv("geglu_quick_f16" + suffix,"geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
string_to_spv("roll_f32", "roll.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
+ string_to_spv("add_id_f32", "add_id.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
+
for (auto &c : compiles) {
c.wait();
}
#endif
}
-static void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * GGML_RESTRICT x, size_t bx, const float * GGML_RESTRICT y, size_t by, int nrc);
-static void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * GGML_RESTRICT x, size_t bx, ggml_fp16_t * GGML_RESTRICT y, size_t by, int nrc);
-static void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t * GGML_RESTRICT x, size_t bx, ggml_bf16_t * GGML_RESTRICT y, size_t by, int nrc);
static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = {
[GGML_TYPE_I8] = {
.is_quantized = true,
.from_float_ref = (ggml_from_float_t) quantize_row_q8_1_ref,
},
+ [GGML_TYPE_MXFP4] = {
+ .type_name = "mxfp4",
+ .blck_size = QK_MXFP4,
+ .type_size = sizeof(block_mxfp4),
+ .is_quantized = true,
+ .to_float = (ggml_to_float_t) dequantize_row_mxfp4,
+ .from_float_ref = (ggml_from_float_t)quantize_row_mxfp4_ref,
+ },
[GGML_TYPE_Q2_K] = {
.type_name = "q2_K",
.blck_size = QK_K,
"DUP",
"ADD",
+ "ADD_ID",
"ADD1",
"ACC",
"SUB",
"GLU",
};
-static_assert(GGML_OP_COUNT == 86, "GGML_OP_COUNT != 86");
+static_assert(GGML_OP_COUNT == 87, "GGML_OP_COUNT != 87");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
"x",
"x+y",
+ "x[i]+y",
"x+y",
"view(x,nb,offset)+=y->x",
"x-y",
"glu(x)",
};
-static_assert(GGML_OP_COUNT == 86, "GGML_OP_COUNT != 86");
+static_assert(GGML_OP_COUNT == 87, "GGML_OP_COUNT != 87");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
"REGLU",
"GEGLU",
"SWIGLU",
+ "SWIGLU_OAI",
"GEGLU_ERF",
"GEGLU_QUICK",
};
-static_assert(GGML_GLU_OP_COUNT == 5, "GGML_GLU_OP_COUNT != 5");
+static_assert(GGML_GLU_OP_COUNT == 6, "GGML_GLU_OP_COUNT != 6");
static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
case GGML_FTYPE_MOSTLY_Q5_0: wtype = GGML_TYPE_Q5_0; break;
case GGML_FTYPE_MOSTLY_Q5_1: wtype = GGML_TYPE_Q5_1; break;
case GGML_FTYPE_MOSTLY_Q8_0: wtype = GGML_TYPE_Q8_0; break;
+ case GGML_FTYPE_MOSTLY_MXFP4: wtype = GGML_TYPE_MXFP4; break;
case GGML_FTYPE_MOSTLY_Q2_K: wtype = GGML_TYPE_Q2_K; break;
case GGML_FTYPE_MOSTLY_Q3_K: wtype = GGML_TYPE_Q3_K; break;
case GGML_FTYPE_MOSTLY_Q4_K: wtype = GGML_TYPE_Q4_K; break;
return ggml_add_cast_impl(ctx, a, b, type);
}
+struct ggml_tensor * ggml_add_id(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ struct ggml_tensor * ids) {
+
+ GGML_ASSERT(a->ne[0] == b->ne[0]);
+ GGML_ASSERT(a->ne[1] == ids->ne[0]);
+ GGML_ASSERT(a->ne[2] == ids->ne[1]);
+ GGML_ASSERT(ids->type == GGML_TYPE_I32);
+
+ struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
+
+ result->op = GGML_OP_ADD_ID;
+ result->src[0] = a;
+ result->src[1] = b;
+ result->src[2] = ids;
+
+ return result;
+}
+
// ggml_add1
static struct ggml_tensor * ggml_add1_impl(
return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_GEGLU_QUICK, false);
}
+struct ggml_tensor * ggml_swiglu_oai(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ float alpha,
+ float limit) {
+ struct ggml_tensor * result = ggml_glu_impl(ctx, a, b, GGML_GLU_OP_SWIGLU_OAI, false);
+ ggml_set_op_params_f32(result, 2, alpha);
+ ggml_set_op_params_f32(result, 3, limit);
+
+ return result;
+}
+
// ggml_norm
static struct ggml_tensor * ggml_norm_impl(
return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false);
}
+void ggml_soft_max_add_sinks(
+ struct ggml_tensor * a,
+ struct ggml_tensor * sinks) {
+ if (!sinks) {
+ a->src[2] = NULL;
+ return;
+ }
+
+ GGML_ASSERT(a->op == GGML_OP_SOFT_MAX);
+ GGML_ASSERT(a->src[2] == NULL);
+ GGML_ASSERT(a->src[0]->ne[2] == sinks->ne[0]);
+ GGML_ASSERT(sinks->type == GGML_TYPE_F32);
+
+ a->src[2] = sinks;
+}
+
// ggml_soft_max_ext_back
static struct ggml_tensor * ggml_soft_max_ext_back_impl(
return (enum ggml_prec) prec_i32;
}
+void ggml_flash_attn_ext_add_sinks(
+ struct ggml_tensor * a,
+ struct ggml_tensor * sinks) {
+ if (!sinks) {
+ a->src[4] = NULL;
+ return;
+ }
+
+ GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT);
+ GGML_ASSERT(a->src[4] == NULL);
+ GGML_ASSERT(a->src[0]->ne[2] == sinks->ne[0]);
+ GGML_ASSERT(sinks->type == GGML_TYPE_F32);
+
+ a->src[4] = sinks;
+}
+
// ggml_flash_attn_back
struct ggml_tensor * ggml_flash_attn_back(
case GGML_TYPE_Q5_0: result = quantize_q5_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q5_1: result = quantize_q5_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q8_0: result = quantize_q8_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_MXFP4: result = quantize_mxfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q2_K: result = quantize_q2_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q3_K: result = quantize_q3_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q4_K: result = quantize_q4_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
#define VARS_TO_STR10(a, b, c, d, e, f, g, h, i, j) VAR_TO_STR(a) + "," + VARS_TO_STR9(b, c, d, e, f, g, h, i, j)
#define VARS_TO_STR11(a, b, c, d, e, f, g, h, i, j, k) VAR_TO_STR(a) + "," + VARS_TO_STR10(b, c, d, e, f, g, h, i, j, k)
#define VARS_TO_STR12(a, b, c, d, e, f, g, h, i, j, k, l) VAR_TO_STR(a) + "," + VARS_TO_STR11(b, c, d, e, f, g, h, i, j, k, l)
+#define VARS_TO_STR13(a, b, c, d, e, f, g, h, i, j, k, l, m) VAR_TO_STR(a) + "," + VARS_TO_STR12(b, c, d, e, f, g, h, i, j, k, l, m)
#ifdef GGML_USE_SYCL
static bool inline _isinf(float f) {
}
};
+struct test_swiglu_oai : public test_case {
+ const ggml_type type;
+ const std::array<int64_t, 4> ne_a;
+ int v; // view (1 : non-contiguous a)
+ float alpha;
+ float limit;
+
+ std::string vars() override {
+ return VARS_TO_STR5(type, ne_a, v, alpha, limit);
+ }
+
+ test_swiglu_oai(ggml_type type = GGML_TYPE_F32,
+ std::array<int64_t, 4> ne_a = {128, 2, 2, 2},
+ int v = 0,
+ float alpha = 1.702f,
+ float limit = 7.0f)
+ : type(type), ne_a(ne_a), v(v), alpha(alpha), limit(limit) {}
+
+ ggml_tensor * build_graph(ggml_context * ctx) override {
+ ggml_tensor * a;
+ ggml_tensor * b;
+ if (v & 1) {
+ auto ne = ne_a; ne[0] *= 3;
+ a = ggml_new_tensor(ctx, type, 4, ne.data());
+ ggml_set_param(a);
+ ggml_set_name(a, "a");
+
+ a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0);
+ ggml_set_name(a, "view_of_a");
+
+ b = ggml_new_tensor(ctx, type, 4, ne.data());
+ ggml_set_param(b);
+ ggml_set_name(b, "b");
+
+ b = ggml_view_4d(ctx, b, ne_a[0], ne_a[1], ne_a[2], ne_a[3], b->nb[1], b->nb[2], b->nb[3], 0);
+ ggml_set_name(a, "view_of_b");
+ } else {
+ a = ggml_new_tensor(ctx, type, 4, ne_a.data());
+ ggml_set_param(a);
+ ggml_set_name(a, "a");
+
+ b = ggml_new_tensor(ctx, type, 4, ne_a.data());
+ ggml_set_param(b);
+ ggml_set_name(b, "b");
+ }
+
+ ggml_tensor * out = ggml_swiglu_oai(ctx, a, b, alpha, limit);
+ ggml_set_name(out, "out");
+
+ return out;
+ }
+
+ void initialize_tensors(ggml_context * ctx) override {
+ for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+ // test extended range of values to check for NaNs in GELU
+ init_tensor_uniform(t, -150.f, 150.f);
+ }
+ }
+};
+
// GGML_OP_GET_ROWS
struct test_get_rows : public test_case {
const ggml_type type;
}
};
+// GGML_OP_ADD_ID
+struct test_add_id : public test_case {
+ const ggml_type type_a;
+ const ggml_type type_b;
+ const int64_t n_embd;
+ const int64_t n_experts;
+ const int64_t n_experts_used;
+ const int64_t n_token;
+
+ std::string vars() override {
+ return VARS_TO_STR6(type_a, type_b, n_embd, n_experts, n_experts_used, n_token);
+ }
+
+ size_t op_size(ggml_tensor * t) override {
+ return ggml_nbytes(t) + ggml_nbytes(t->src[0]) + ggml_nbytes(t->src[2]);
+ }
+
+ test_add_id(ggml_type type_a = GGML_TYPE_F32,
+ ggml_type type_b = GGML_TYPE_F32,
+ int64_t n_embd = 128,
+ int64_t n_experts = 16,
+ int64_t n_experts_used = 8,
+ int64_t n_token = 10)
+ : type_a(type_a), type_b(type_b), n_embd(n_embd),
+ n_experts(n_experts), n_experts_used(n_experts_used), n_token(n_token) {}
+
+ ggml_tensor * build_graph(ggml_context * ctx) override {
+ ggml_tensor * a = ggml_new_tensor_3d(ctx, type_a, n_embd, n_experts_used, n_token);
+ ggml_tensor * b = ggml_new_tensor_2d(ctx, type_b, n_embd, n_experts);
+ ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_experts, n_token);
+ if (n_experts_used != n_experts) {
+ ids = ggml_view_2d(ctx, ids, n_experts_used, n_token, ids->nb[1], 0);
+ ggml_set_name(ids, "view_of_ids");
+ }
+
+ ggml_tensor * out = ggml_add_id(ctx, a, b, ids);
+ ggml_set_name(out, "out");
+ return out;
+ }
+
+ void initialize_tensors(ggml_context * ctx) override {
+ for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+ if (t->type == GGML_TYPE_I32) {
+ if (ggml_is_view_op(t->op)) { continue; }
+ std::random_device rd;
+ std::default_random_engine rng(rd());
+ // ids
+ for (int64_t r = 0; r < ggml_nrows(t); r++) {
+ std::vector<int32_t> data(t->ne[0]);
+ for (int i = 0; i < t->ne[0]; i++) {
+ data[i] = i % n_experts;
+ }
+ std::shuffle(data.begin(), data.end(), rng);
+ ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(int32_t));
+ }
+ } else {
+ init_tensor_uniform(t);
+ }
+ }
+ }
+};
+
// GGML_OP_ADD1
struct test_add1 : public test_case {
const ggml_type type;
}
void initialize_tensors(ggml_context * ctx) override {
- std::random_device rd;
- std::default_random_engine rng(rd());
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
if (t->type == GGML_TYPE_I32) {
if (ggml_is_view_op(t->op)) { continue; }
+ std::random_device rd;
+ std::default_random_engine rng(rd());
// ids
for (int64_t r = 0; r < ggml_nrows(t); r++) {
std::vector<int32_t> data(t->ne[0]);
const ggml_type type;
const std::array<int64_t, 4> ne;
const bool mask;
+ const bool sinks;
const ggml_type m_prec;
const std::array<int64_t, 2> nr23; // broadcast only dims 2 and 3
const float scale;
const float max_bias;
std::string vars() override {
- return VARS_TO_STR7(type, ne, mask, m_prec, nr23, scale, max_bias);
+ return VARS_TO_STR8(type, ne, mask, sinks, m_prec, nr23, scale, max_bias);
}
// the 1024 test with bias occasionally fails:
test_soft_max(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {10, 5, 4, 3},
bool mask = false,
+ bool sinks = false,
ggml_type m_prec = GGML_TYPE_F32,
std::array<int64_t, 2> nr23 = {1, 1},
float scale = 1.0f,
float max_bias = 0.0f)
- : type(type), ne(ne), mask(mask), m_prec(m_prec), nr23(nr23), scale(scale), max_bias(max_bias) {}
+ : type(type), ne(ne), mask(mask), sinks(sinks), m_prec(m_prec), nr23(nr23), scale(scale), max_bias(max_bias) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2]*nr23[0], ne[3]*nr23[1]);
ggml_set_name(mask, "mask");
}
+ ggml_tensor * sinks = nullptr;
+ if (this->sinks) {
+ sinks = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ne[2]*nr23[0]);
+ ggml_set_name(sinks, "sinks");
+ }
+
ggml_tensor * out = ggml_soft_max_ext(ctx, a, mask, scale, max_bias);
+ ggml_soft_max_add_sinks(out, sinks);
ggml_set_name(out, "out");
return out;
const int64_t nb; // batch size
const bool mask; // use mask
+ const bool sinks; // use sinks
const float max_bias; // ALiBi
const float logit_softcap; // Gemma 2
std::array<int32_t, 4> permute;
std::string vars() override {
- return VARS_TO_STR12(hsk, hsv, nh, nr23, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, permute);
+ return VARS_TO_STR13(hsk, hsv, nh, nr23, kv, nb, mask, sinks, max_bias, logit_softcap, prec, type_KV, permute);
}
double max_nmse_err() override {
}
test_flash_attn_ext(int64_t hsk = 128, int64_t hsv = 128, int64_t nh = 32, std::array<int64_t, 2> nr23 = {1, 1}, int64_t kv = 96, int64_t nb = 8,
- bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_prec prec = GGML_PREC_F32,
+ bool mask = true, bool sinks = false, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_prec prec = GGML_PREC_F32,
ggml_type type_KV = GGML_TYPE_F16, std::array<int32_t, 4> permute = {0, 1, 2, 3})
- : hsk(hsk), hsv(hsv), nh(nh), nr23(nr23), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), prec(prec), type_KV(type_KV), permute(permute) {}
+ : hsk(hsk), hsv(hsv), nh(nh), nr23(nr23), kv(kv), nb(nb), mask(mask), sinks(sinks), max_bias(max_bias), logit_softcap(logit_softcap), prec(prec), type_KV(type_KV), permute(permute) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
const int64_t hsk_padded = GGML_PAD(hsk, ggml_blck_size(type_KV));
ggml_set_name(m, "m");
}
+ ggml_tensor * s = nullptr;
+ if (sinks) {
+ s = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, q->ne[2]);
+ ggml_set_name(s, "s");
+ }
+
ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hsk), max_bias, logit_softcap);
- ggml_flash_attn_ext_set_prec(out, prec);
+ ggml_flash_attn_ext_add_sinks(out, s);
+ ggml_flash_attn_ext_set_prec (out, prec);
ggml_set_name(out, "out");
return out;
}
+ void initialize_tensors(ggml_context * ctx) override {
+ for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+ if (strcmp(t->name, "s") == 0) {
+ // make the sink values more noticable in order to trigger a test failure when the implementation is wrong
+ init_tensor_uniform(t, -10.0f, 10.0f);
+ } else {
+ init_tensor_uniform(t);
+ }
+ }
+ }
+
bool grad_precise() override {
return true;
}
GGML_TYPE_Q4_0, GGML_TYPE_Q4_1,
GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
GGML_TYPE_Q8_0,
+ GGML_TYPE_MXFP4,
GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
GGML_TYPE_Q4_K, GGML_TYPE_Q5_K,
GGML_TYPE_Q6_K,
GGML_TYPE_Q4_0,
GGML_TYPE_Q4_1, // for I8MM tests
GGML_TYPE_Q4_K,
+ GGML_TYPE_MXFP4, // TODO: or "other"
GGML_TYPE_IQ2_XXS
};
for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
for (int v : {0, 1}) {
for (int op = 0; op < GGML_GLU_OP_COUNT; op++) {
+ if (op == GGML_GLU_OP_SWIGLU_OAI) {
+ // SWIGLU_OAI is handled separately
+ continue;
+ }
+
for (bool swapped : {false, true}) {
test_cases.emplace_back(new test_glu((ggml_glu_op) op, type, { 128, 2, 2, 2 }, v, swapped));
test_cases.emplace_back(new test_glu((ggml_glu_op) op, type, { 5, 7, 11, 13 }, v, swapped));
}
}
+ for (int v : {0, 1}) {
+ for (float alpha : {.5f, 1.702f}) {
+ for (float limit : {2.0f, 7.0f}) {
+ test_cases.emplace_back(new test_swiglu_oai(GGML_TYPE_F32, { 128, 2, 2, 2 }, v, alpha, limit));
+ }
+ }
+ }
+
test_cases.emplace_back(new test_get_rows(GGML_TYPE_F32, 1, 8, 2, 1, false));
for (ggml_type type : all_types) {
for (int b : {1, 7}) {
}
}
+ // add_id
+ for (ggml_type type_a : {GGML_TYPE_F32}) {
+ for (ggml_type type_b : {GGML_TYPE_F32}) {
+ for (int n_mats : {4, 8}) {
+ for (int n_used : {1, 2, 4}) {
+ for (int n_embd : {32, 129}) {
+ for (int n_token : {1, 32, 129}) {
+ test_cases.emplace_back(new test_add_id(type_a, type_b, n_embd, n_mats, n_used, n_token));
+ }
+ }
+ }
+ }
+ }
+ }
+
for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
test_cases.emplace_back(new test_sqr(type));
test_cases.emplace_back(new test_sqrt(type));
}
#endif
for (bool mask : {false, true}) {
- for (float max_bias : {0.0f, 8.0f}) {
- if (!mask && max_bias > 0.0f) continue;
- for (float scale : {1.0f, 0.1f}) {
- for (int64_t ne0 : {16, 1024}) {
- for (int64_t ne1 : {16, 1024}) {
- if (mask) {
- for (ggml_type m_prec : {GGML_TYPE_F32, GGML_TYPE_F16}) {
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, m_prec, {1, 1}, scale, max_bias));
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, m_prec, {1, 1}, scale, max_bias));
-
- if (ne0 <= 32 && ne1 <= 32) {
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 3}, mask, m_prec, {3, 1}, scale, max_bias));
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, m_prec, {2, 3}, scale, max_bias));
+ for (bool sinks : {false, true}) {
+ for (float max_bias : {0.0f, 8.0f}) {
+ if (!mask && max_bias > 0.0f) continue;
+ for (float scale : {1.0f, 0.1f}) {
+ for (int64_t ne0 : {16, 1024}) {
+ for (int64_t ne1 : {16, 1024}) {
+ if (mask) {
+ for (ggml_type m_prec : {GGML_TYPE_F32, GGML_TYPE_F16}) {
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, sinks, m_prec, {1, 1}, scale, max_bias));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, sinks, m_prec, {1, 1}, scale, max_bias));
+
+ if (ne0 <= 32 && ne1 <= 32) {
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 3}, mask, sinks, m_prec, {3, 1}, scale, max_bias));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, sinks, m_prec, {2, 3}, scale, max_bias));
+ }
}
+ } else {
+ /* The precision of mask here doesn't matter as boolean mask is false */
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, sinks, GGML_TYPE_F32, {1, 1}, scale, max_bias));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, sinks, GGML_TYPE_F32, {1, 1}, scale, max_bias));
}
- } else {
- /* The precision of mask here doesn't matter as boolean mask is false */
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, GGML_TYPE_F32, {1, 1}, scale, max_bias));
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, GGML_TYPE_F32, {1, 1}, scale, max_bias));
}
}
}
}
}
}
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true, GGML_TYPE_F32, {1, 1}, 0.1f, 0.0f));
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true, GGML_TYPE_F16, {1, 1}, 0.1f, 0.0f));
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, GGML_TYPE_F32, {1, 1}, 0.1f, 0.0f));
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F32, {1, 1}, 0.1f, 0.0f));
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F16, {1, 1}, 0.1f, 0.0f));
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F32, {1, 1}, 0.1f, 8.0f));
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F16, {1, 1}, 0.1f, 8.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true, true, GGML_TYPE_F32, {1, 1}, 0.1f, 0.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true, false, GGML_TYPE_F16, {1, 1}, 0.1f, 0.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, true, GGML_TYPE_F32, {1, 1}, 0.1f, 0.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, true, GGML_TYPE_F32, {1, 1}, 0.1f, 0.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, false, GGML_TYPE_F16, {1, 1}, 0.1f, 0.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, true, GGML_TYPE_F32, {1, 1}, 0.1f, 8.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, true, GGML_TYPE_F16, {1, 1}, 0.1f, 8.0f));
for (float max_bias : {0.0f, 8.0f}) {
for (float scale : {1.0f, 0.1f}) {
if (hsk == 576 && hsv != 512) continue; // DeepSeek MLA
for (bool mask : { true, false } ) {
- for (float max_bias : { 0.0f, 8.0f }) {
- if (!mask && max_bias > 0.0f) continue;
- for (float logit_softcap : {0.0f, 10.0f}) {
- if (hsk != 128 && logit_softcap != 0.0f) continue;
- for (int nh : { 4, }) {
- for (int nr3 : { 1, 3, }) {
- if (hsk > 64 && nr3 > 1) continue; // skip broadcast for large head sizes
- for (int nr2 : { 1, 4, 16 }) {
- if (nr2 == 16 && hsk != 128) continue;
- for (int kv : { 512, 1024, }) {
- if (nr2 != 1 && kv != 512) continue;
- for (int nb : { 1, 3, 32, 35, }) {
- for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {
- if (hsk != 128 && prec == GGML_PREC_DEFAULT) continue;
- for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
- test_cases.emplace_back(new test_flash_attn_ext(
- hsk, hsv, nh, {nr2, nr3}, kv, nb, mask, max_bias, logit_softcap, prec, type_KV));
- // run fewer test cases permuted
- if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
+ for (bool sinks : { true, false } ) {
+ for (float max_bias : { 0.0f, 8.0f }) {
+ if (!mask && max_bias > 0.0f) continue;
+ for (float logit_softcap : {0.0f, 10.0f}) {
+ if (hsk != 128 && logit_softcap != 0.0f) continue;
+ for (int nh : { 4, }) {
+ for (int nr3 : { 1, 3, }) {
+ if (hsk > 64 && nr3 > 1) continue; // skip broadcast for large head sizes
+ for (int nr2 : { 1, 4, 16 }) {
+ if (nr2 == 16 && hsk != 128) continue;
+ for (int kv : { 512, 1024, }) {
+ if (nr2 != 1 && kv != 512) continue;
+ for (int nb : { 1, 3, 32, 35, }) {
+ for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {
+ if (hsk != 128 && prec == GGML_PREC_DEFAULT) continue;
+ for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
test_cases.emplace_back(new test_flash_attn_ext(
- hsk, hsv, nh, {nr2, nr3}, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, {0, 2, 1, 3}));
+ hsk, hsv, nh, {nr2, nr3}, kv, nb, mask, sinks, max_bias, logit_softcap, prec, type_KV));
+ // run fewer test cases permuted
+ if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
+ test_cases.emplace_back(new test_flash_attn_ext(
+ hsk, hsv, nh, {nr2, nr3}, kv, nb, mask, sinks, max_bias, logit_softcap, prec, type_KV, {0, 2, 1, 3}));
+ }
}
}
}
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {8192, 512, 2, 1}, {0, 2, 1, 3}));
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {3072, 512, 2, 1}, {0, 2, 1, 3}));
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {12888, 256, 5, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 4096, 5, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {1024, 1024, 10, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 1024, 10, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {256, 256, 20, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {64, 64, 20, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 64, 20, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {12888, 256, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 4096, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {1024, 1024, 10, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 1024, 10, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {256, 256, 20, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {64, 64, 20, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 64, 20, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 10, 1, 1}));
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1}));
for (int kv : { 4096, 8192, 16384, }) {
for (int hs : { 64, 128, }) {
for (int nr : { 1, 4, }) {
- test_cases.emplace_back(new test_flash_attn_ext(hs, hs, 8, {nr, 1}, kv, 1, true, 0, 0, GGML_PREC_F32, GGML_TYPE_F16));
+ test_cases.emplace_back(new test_flash_attn_ext(hs, hs, 8, {nr, 1}, kv, 1, true, false, 0, 0, GGML_PREC_F32, GGML_TYPE_F16));
}
}
}
test_cases.emplace_back(new test_mean(GGML_TYPE_F32, {256, 256, 3, 1}));
+
+ for (int n_token : {1, 512}) {
+ test_cases.emplace_back(new test_add_id(GGML_TYPE_F32, GGML_TYPE_F32, 2880, 128, 4, n_token));
+ test_cases.emplace_back(new test_add_id(GGML_TYPE_F32, GGML_TYPE_F32, 2880, 32, 4, n_token));
+ }
+
return test_cases;
}