"controls whether thought tags are allowed and/or extracted from the response, and in which format they're returned; one of:\n"
"- none: leaves thoughts unparsed in `message.content`\n"
"- deepseek: puts thoughts in `message.reasoning_content` (except in streaming mode, which behaves as `none`)\n"
- "(default: deepseek)",
+ "(default: auto)",
[](common_params & params, const std::string & value) {
/**/ if (value == "deepseek") { params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; }
else if (value == "deepseek-legacy") { params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY; }
else if (value == "none") { params.reasoning_format = COMMON_REASONING_FORMAT_NONE; }
+ else if (value == "auto") { params.reasoning_format = COMMON_REASONING_FORMAT_AUTO; }
else { throw std::invalid_argument("invalid value"); }
}
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_THINK"));
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1";
case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro";
case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B";
+ case COMMON_CHAT_FORMAT_GPT_OSS: return "GPT-OSS";
default:
throw std::runtime_error("Unknown chat format");
}
const char * common_reasoning_format_name(common_reasoning_format format) {
switch (format) {
case COMMON_REASONING_FORMAT_NONE: return "none";
+ case COMMON_REASONING_FORMAT_AUTO: return "auto";
case COMMON_REASONING_FORMAT_DEEPSEEK: return "deepseek";
case COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY: return "deepseek-legacy";
default:
tool_calls_end);
}
+static common_chat_params common_chat_params_init_gpt_oss(const common_chat_template & tmpl, const struct templates_params & inputs) {
+ common_chat_params data;
+ auto prompt = apply(tmpl, inputs);
+
+ data.prompt = prompt;
+ data.format = COMMON_CHAT_FORMAT_GPT_OSS;
+
+ // TODO: support tool calls in GPT-OSS?
+
+ return data;
+}
+static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) {
+ // TODO @ngxson : this won't work with --special enabled, we should fix that
+ builder.try_parse_reasoning("<|channel|>analysis<|message|>", "<|start|>assistant<|channel|>final<|message|>");
+ if (!builder.syntax().parse_tool_calls) {
+ builder.add_content(builder.consume_rest());
+ return;
+ }
+}
+
static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
LOG_DBG("%s\n", __func__);
common_chat_params data;
return common_chat_params_init_hermes_2_pro(tmpl, params);
}
+ // GPT-OSS
+ if (src.find("<|channel|>") != std::string::npos && params.json_schema.is_null()) {
+ return common_chat_params_init_gpt_oss(tmpl, params);
+ }
+
// Use generic handler when mixing tools + JSON schema.
// TODO: support that mix in handlers below.
if ((params.tools.is_array() && params.json_schema.is_object())) {
case COMMON_CHAT_FORMAT_COMMAND_R7B:
common_chat_parse_command_r7b(builder);
break;
+ case COMMON_CHAT_FORMAT_GPT_OSS:
+ common_chat_parse_gpt_oss(builder);
+ break;
default:
throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format));
}
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
COMMON_CHAT_FORMAT_HERMES_2_PRO,
COMMON_CHAT_FORMAT_COMMAND_R7B,
+ COMMON_CHAT_FORMAT_GPT_OSS,
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
};
enum common_reasoning_format {
COMMON_REASONING_FORMAT_NONE,
+ COMMON_REASONING_FORMAT_AUTO,
COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY, // Extract thinking tag contents and return as `message.reasoning_content`, or leave inline in <think> tags in stream mode
COMMON_REASONING_FORMAT_DEEPSEEK, // Extract thinking tag contents and return as `message.reasoning_content`, including in streaming deltas.
};
std::string chat_template = ""; // NOLINT
bool use_jinja = false; // NOLINT
bool enable_chat_template = true;
- common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
+ common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_AUTO;
int reasoning_budget = -1;
bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response
self.gguf_writer.add_chat_template(chat_template)
+@ModelBase.register("GptOssForCausalLM")
+class GptOssModel(TextModel):
+ model_arch = gguf.MODEL_ARCH.GPT_OSS
+
+ def transform_nibble_layout(self, tensor):
+ assert tensor.dtype == torch.uint8
+ assert tensor.shape[-1] == 16
+ # swap nibbles
+ t_lo = tensor & 0x0F
+ t_hi = tensor & 0xF0
+ t_swapped = (t_lo << 4) | (t_hi >> 4)
+ tensor = t_swapped
+ # transform aaaa...bbbb... to abababab...
+ blk_a, blk_b = tensor.chunk(2, dim=-1)
+ # get a_
+ blk_a0 = (blk_a & 0xF0).view(-1, 1)
+ blk_a1 = (blk_a << 4).view(-1, 1)
+ blk_a = torch.stack((blk_a0, blk_a1), dim=2).view(tensor.shape)
+ # get _b
+ blk_b0 = (blk_b >> 4).view(-1, 1)
+ blk_b1 = (blk_b & 0x0F).view(-1, 1)
+ blk_b = torch.stack((blk_b0, blk_b1), dim=2).view(tensor.shape)
+ # swap once more
+ out = blk_a | blk_b
+ out_h = out & 0xF0
+ out_l = out & 0x0F
+ out = (out_h >> 4) | (out_l << 4)
+ return out
+
+ def repack_mxfp4(self, new_name: str, blocks: Tensor, scales: Tensor):
+ assert blocks.dtype == torch.uint8
+ assert scales.dtype == torch.uint8
+ scales = scales.unsqueeze(-1)
+ assert len(blocks.shape) == 4
+ assert len(scales.shape) == 4
+ blocks = self.transform_nibble_layout(blocks)
+ new_data = torch.concat((scales, blocks), dim=-1)
+ new_shape = [new_data.shape[0], new_data.shape[1], new_data.shape[2] * 32]
+ logger.info(f"Repacked {new_name} with shape {new_shape} and quantization MXFP4")
+ # flatten last dim
+ new_data = new_data.view(new_data.shape[0], new_data.shape[1], new_data.shape[2] * new_data.shape[3])
+ new_data = new_data.numpy()
+ self.gguf_writer.add_tensor(new_name, new_data, raw_dtype=gguf.GGMLQuantizationType.MXFP4)
+
+ def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
+ blocks0: Tensor = torch.zeros(1)
+ blocks1: Tensor = torch.zeros(1)
+ found_mxfp4_tensors = False
+ # we assume that tensors are loaded in the correct order
+ for name, data_torch in self.get_tensors():
+ if "mlp.experts.down_proj_blocks" in name:
+ blocks0 = data_torch
+ elif "mlp.experts.down_proj_scales" in name:
+ new_name = self.map_tensor_name(name.replace("_scales", ".weight"))
+ self.repack_mxfp4(new_name, blocks0, data_torch)
+ found_mxfp4_tensors = True
+ elif "mlp.experts.gate_up_proj_blocks" in name:
+ blocks0, blocks1 = data_torch[:, ::2, :, :], data_torch[:, 1::2, :, :]
+ elif "mlp.experts.gate_up_proj_scales" in name:
+ scales0, scales1 = data_torch[:, ::2, :], data_torch[:, 1::2, :]
+ new_name_gate = self.map_tensor_name(name.replace("gate_up_proj_scales", "gate_proj.weight"))
+ new_name_up = self.map_tensor_name(name.replace("gate_up_proj_scales", "up_proj.weight"))
+ self.repack_mxfp4(new_name_gate, blocks0, scales0)
+ self.repack_mxfp4(new_name_up, blocks1, scales1)
+ found_mxfp4_tensors = True
+ if not found_mxfp4_tensors:
+ raise ValueError("No MXFP4 tensors found in the model. Please make sure you are using MXFP4 model.")
+ return []
+
+ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+ del bid # unused
+
+ if "sinks" in name:
+ name += ".weight"
+
+ # correct naming for down_proj
+ if "down_proj" in name:
+ if name.endswith("_bias"):
+ name = name.replace("down_proj_bias", "down_proj.bias")
+ else:
+ return []
+
+ # split the gate_up into gate and up
+ if "gate_up_proj" in name:
+ if name.endswith("_bias"):
+ name_up = name.replace("gate_up_proj_bias", "up_proj.bias")
+ name_gate = name.replace("gate_up_proj_bias", "gate_proj.bias")
+ gate_proj_bias, up_proj_bias = data_torch[..., ::2], data_torch[..., 1::2]
+ return [
+ (self.map_tensor_name(name_gate), gate_proj_bias),
+ (self.map_tensor_name(name_up), up_proj_bias)
+ ]
+ else:
+ return []
+
+ return [(self.map_tensor_name(name), data_torch)]
+
+ def set_vocab(self):
+ self._set_vocab_gpt2()
+
+ def set_gguf_parameters(self):
+ super().set_gguf_parameters()
+ self.gguf_writer.add_sliding_window(self.hparams["sliding_window"])
+ self.gguf_writer.add_expert_feed_forward_length(self.hparams["intermediate_size"])
+
+ rope_scaling = self.hparams.get("rope_scaling") or {}
+ rope_type = rope_scaling.get("rope_type", rope_scaling.get("type"))
+ assert rope_type == "yarn", f"GPT-OSS only supports yarn rope scaling, got {rope_type}"
+ self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
+ self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
+ self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling.get("original_max_position_embeddings", 4096))
+
+
@ModelBase.register("Lfm2ForCausalLM")
@ModelBase.register("LFM2ForCausalLM")
class LFM2Model(TextModel):
_dtype_map: dict[torch.dtype, type] = {
torch.float16: np.float16,
torch.float32: np.float32,
+ torch.uint8: np.uint8,
}
# used for safetensors slices
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
+#define GGML_TENSOR_TERNARY_OP_LOCALS \
+ GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
+ GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
+ GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
+ GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) \
+ GGML_TENSOR_LOCALS(int64_t, ne2, src2, ne) \
+ GGML_TENSOR_LOCALS(size_t, nb2, src2, nb) \
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
+
#define GGML_TENSOR_BINARY_OP_LOCALS01 \
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
// GGML_TYPE_IQ4_NL_4_4 = 36,
// GGML_TYPE_IQ4_NL_4_8 = 37,
// GGML_TYPE_IQ4_NL_8_8 = 38,
- GGML_TYPE_COUNT = 39,
+ GGML_TYPE_MXFP4 = 39, // MXFP4 (1 block)
+ GGML_TYPE_COUNT = 40,
};
// precision
GGML_FTYPE_MOSTLY_IQ4_XS = 22, // except 1d tensors
GGML_FTYPE_MOSTLY_IQ1_M = 23, // except 1d tensors
GGML_FTYPE_MOSTLY_BF16 = 24, // except 1d tensors
+ GGML_FTYPE_MOSTLY_MXFP4 = 25, // except 1d tensors
};
// available tensor operations:
GGML_OP_DUP,
GGML_OP_ADD,
+ GGML_OP_ADD_ID,
GGML_OP_ADD1,
GGML_OP_ACC,
GGML_OP_SUB,
GGML_GLU_OP_REGLU,
GGML_GLU_OP_GEGLU,
GGML_GLU_OP_SWIGLU,
+ GGML_GLU_OP_SWIGLU_OAI,
GGML_GLU_OP_GEGLU_ERF,
GGML_GLU_OP_GEGLU_QUICK,
struct ggml_tensor * b,
enum ggml_type type);
+ // dst[i0, i1, i2] = a[i0, i1, i2] + b[i0, ids[i1, i2]]
+ GGML_API struct ggml_tensor * ggml_add_id(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ struct ggml_tensor * ids);
+
GGML_API struct ggml_tensor * ggml_add1(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * a,
struct ggml_tensor * b);
+ GGML_API struct ggml_tensor * ggml_swiglu_oai(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ float alpha,
+ float limit);
+
// normalize along rows
GGML_API struct ggml_tensor * ggml_norm(
struct ggml_context * ctx,
float scale,
float max_bias);
+ GGML_API void ggml_soft_max_add_sinks(
+ struct ggml_tensor * a,
+ struct ggml_tensor * sinks);
+
GGML_API struct ggml_tensor * ggml_soft_max_ext_back(
struct ggml_context * ctx,
struct ggml_tensor * a,
GGML_API enum ggml_prec ggml_flash_attn_ext_get_prec(
const struct ggml_tensor * a);
+ GGML_API void ggml_flash_attn_ext_add_sinks(
+ struct ggml_tensor * a,
+ struct ggml_tensor * sinks);
+
// TODO: needs to be adapted to ggml_flash_attn_ext
GGML_API struct ggml_tensor * ggml_flash_attn_back(
struct ggml_context * ctx,
case GGML_OP_DIAG_MASK_ZERO:
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_ADD:
+ case GGML_OP_ADD_ID:
case GGML_OP_ADD1:
case GGML_OP_SUB:
case GGML_OP_MUL:
memcpy(&bias, (float*)op->op_params + 1, sizeof(float));
return bias == 0.0f; // TODO: support bias != 0.0f
case GGML_OP_SOFT_MAX:
+ // TODO: support attention sinks [TAG_ATTN_SINKS]
+ if (op->src[2]) {
+ return false;
+ }
// TODO: support broadcast
// ref: https://github.com/ggml-org/llama.cpp/pull/14435
return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1);
if(op->type != GGML_TYPE_F16 && op->type != GGML_TYPE_F32 && op->type != GGML_TYPE_BF16){
return false;
}
+ // TODO: support attention sinks [TAG_ATTN_SINKS]
+ if (op->src[4]) {
+ return false;
+ }
if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
// different head sizes of K and V are not supported yet
return false;
#define QI4_1 (QK4_1 / (4 * QR4_1))
#define QR4_1 2
+#define QI_MXFP4 (QK_MXFP4 / (4 * QR_MXFP4))
+#define QR_MXFP4 2
+
#define QI5_0 (QK5_0 / (4 * QR5_0))
#define QR5_0 2
} block_q4_1;
static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_half) + QK4_1 / 2, "wrong q4_1 block size/padding");
+#define QK_MXFP4 32
+typedef struct {
+ uint8_t e; // E8M0
+ uint8_t qs[QK_MXFP4/2];
+} block_mxfp4;
+static_assert(sizeof(block_mxfp4) == sizeof(uint8_t) + QK_MXFP4/2, "wrong mxfp4 block size/padding");
+
#define QK5_0 32
typedef struct {
ggml_half d; // delta
0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101,
GGML_TABLE_END()
+// TODO: fix name to kvalues_iq4_nl
GGML_TABLE_BEGIN(int8_t, kvalues_iq4nl, 16)
-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113,
GGML_TABLE_END()
+// e2m1 values (doubled)
+// ref: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
+GGML_TABLE_BEGIN(int8_t, kvalues_mxfp4, 16)
+ 0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12,
+GGML_TABLE_END()
+
#define NGRID_IQ1S 2048
#define IQ1S_DELTA 0.125f
#define IQ1M_DELTA 0.125f
#define ggml_vec_dot_q5_0_q8_0_generic ggml_vec_dot_q5_0_q8_0
#define ggml_vec_dot_q5_1_q8_1_generic ggml_vec_dot_q5_1_q8_1
#define ggml_vec_dot_q8_0_q8_0_generic ggml_vec_dot_q8_0_q8_0
+#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
#define ggml_vec_dot_q2_K_q8_K_generic ggml_vec_dot_q2_K_q8_K
#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
#define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
+#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
// repack.cpp
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
#define ggml_vec_dot_tq1_0_q8_K_generic ggml_vec_dot_tq1_0_q8_K
#define ggml_vec_dot_tq2_0_q8_K_generic ggml_vec_dot_tq2_0_q8_K
#define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
+#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
// repack.cpp
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
#define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
#define ggml_vec_dot_iq4_nl_q8_0_generic ggml_vec_dot_iq4_nl_q8_0
#define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K
+#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
// repack.cpp
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
#define ggml_vec_dot_iq3_s_q8_K_generic ggml_vec_dot_iq3_s_q8_K
#define ggml_vec_dot_iq1_s_q8_K_generic ggml_vec_dot_iq1_s_q8_K
#define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
+#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
// repack.cpp
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
#define ggml_vec_dot_iq1_m_q8_K_generic ggml_vec_dot_iq1_m_q8_K
#define ggml_vec_dot_iq4_nl_q8_0_generic ggml_vec_dot_iq4_nl_q8_0
#define ggml_vec_dot_iq4_xs_q8_K_generic ggml_vec_dot_iq4_xs_q8_K
+#define ggml_vec_dot_mxfp4_q8_0_generic ggml_vec_dot_mxfp4_q8_0
// repack.cpp
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
*s = sumf;
}
+void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+ assert(nrc == 1);
+ UNUSED(nrc);
+ UNUSED(bx);
+ UNUSED(by);
+ UNUSED(bs);
+ assert(n % QK_MXFP4 == 0);
+ static_assert(QK_MXFP4 == QK8_0, "QK_MXFP4 and QK8_0 must be the same");
+
+ const block_mxfp4 * GGML_RESTRICT x = vx;
+ const block_q8_0 * GGML_RESTRICT y = vy;
+
+ const int nb = n / QK_MXFP4;
+
+ int ib = 0;
+ float sumf = 0;
+
+#if defined __ARM_NEON
+ const int8x16_t values = vld1q_s8(kvalues_mxfp4);
+ const uint8x16_t m4b = vdupq_n_u8(0x0f);
+ uint8x16x2_t q4bits;
+ int8x16x4_t q4b;
+ int8x16x4_t q8b;
+ int32x4_t prod_1;
+ int32x4_t prod_2;
+
+ for (; ib + 1 < nb; ib += 2) {
+ q4bits.val[0] = vld1q_u8(x[ib + 0].qs);
+ q4bits.val[1] = vld1q_u8(x[ib + 1].qs);
+ q8b.val[0] = vld1q_s8(y[ib + 0].qs);
+ q8b.val[1] = vld1q_s8(y[ib + 0].qs + 16);
+ q8b.val[2] = vld1q_s8(y[ib + 1].qs);
+ q8b.val[3] = vld1q_s8(y[ib + 1].qs + 16);
+
+ q4b.val[0] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[0], m4b));
+ q4b.val[1] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[0], 4));
+ q4b.val[2] = ggml_vqtbl1q_s8(values, vandq_u8 (q4bits.val[1], m4b));
+ q4b.val[3] = ggml_vqtbl1q_s8(values, vshrq_n_u8(q4bits.val[1], 4));
+
+ prod_1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[0], q8b.val[0]), q4b.val[1], q8b.val[1]);
+ prod_2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q4b.val[2], q8b.val[2]), q4b.val[3], q8b.val[3]);
+
+ sumf +=
+ GGML_E8M0_TO_FP32_HALF(x[ib + 0].e) * GGML_CPU_FP16_TO_FP32(y[ib + 0].d) * vaddvq_s32(prod_1) +
+ GGML_E8M0_TO_FP32_HALF(x[ib + 1].e) * GGML_CPU_FP16_TO_FP32(y[ib + 1].d) * vaddvq_s32(prod_2);
+ }
+
+#endif
+ for (; ib < nb; ++ib) {
+ const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_E8M0_TO_FP32_HALF(x[ib].e);
+ int sumi1 = 0;
+ int sumi2 = 0;
+ for (int j = 0; j < QK_MXFP4/2; ++j) {
+ sumi1 += y[ib].qs[j + 0] * kvalues_mxfp4[x[ib].qs[j] & 0xf];
+ sumi2 += y[ib].qs[j + QK_MXFP4/2] * kvalues_mxfp4[x[ib].qs[j] >> 4];
+ }
+ sumf += d * (sumi1 + sumi2);
+ }
+ *s = sumf;
+}
+
void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
const int qk = QK8_0;
const int nb = n / qk;
}
#if defined(__AVX2__) || defined(__AVX512F__)
+static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
+ const __m256i ax = _mm256_sign_epi8(x, x);
+ const __m256i sy = _mm256_sign_epi8(y, x);
+ return _mm256_maddubs_epi16(ax, sy);
+}
+
// spread 32 bits to 32 bytes { 0x00, 0xFF }
static inline __m256i bytes_from_bits_32(const uint8_t * x) {
uint32_t x32;
return _mm256_set_m128(_mm_set1_ps(GGML_CPU_FP16_TO_FP32(x1) * GGML_CPU_FP16_TO_FP32(y1)),
_mm_set1_ps(GGML_CPU_FP16_TO_FP32(x0) * GGML_CPU_FP16_TO_FP32(y0)));
}
+
+static inline __m256 quad_mx_delta_float(const int8_t x0, const float y0, const int8_t x1, const float y1) {
+ return _mm256_set_m128(_mm_set1_ps(GGML_E8M0_TO_FP32_HALF(x1) * GGML_CPU_FP16_TO_FP32(y1)),
+ _mm_set1_ps(GGML_E8M0_TO_FP32_HALF(x0) * GGML_CPU_FP16_TO_FP32(y0)));
+}
#endif
#elif defined(__SSSE3__)
// horizontally add 4x4 floats
#endif
}
+void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+ assert(nrc == 1);
+ UNUSED(nrc);
+ UNUSED(bx);
+ UNUSED(by);
+ UNUSED(bs);
+ assert(n % QK_MXFP4 == 0);
+ static_assert(QK_MXFP4 == QK8_0, "QK_MXFP4 and QK8_0 must be the same");
+
+ const block_mxfp4 * GGML_RESTRICT x = vx;
+ const block_q8_0 * GGML_RESTRICT y = vy;
+
+ const int nb = n / QK_MXFP4;
+
+ int ib = 0;
+ float sumf = 0;
+
+#if defined __AVX2__
+
+ const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_mxfp4);
+ const __m128i m4b = _mm_set1_epi8(0x0f);
+ const __m256i mone = _mm256_set1_epi16(1);
+
+ __m256 accum1 = _mm256_setzero_ps();
+ __m256 accum2 = _mm256_setzero_ps();
+ for (; ib + 1 < nb; ib += 2) {
+ const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)x[ib + 0].qs);
+ const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)x[ib + 1].qs);
+ const __m256i q8b_1 = _mm256_loadu_si256((const __m256i *)y[ib + 0].qs);
+ const __m256i q8b_2 = _mm256_loadu_si256((const __m256i *)y[ib + 1].qs);
+ const __m256i q4b_1 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b)),
+ _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b)));
+ const __m256i q4b_2 = MM256_SET_M128I(_mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b)),
+ _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b)));
+ const __m256i p16_1 = mul_add_epi8(q4b_1, q8b_1);
+ const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
+ const __m256i p_1 = _mm256_madd_epi16(p16_1, mone);
+ const __m256i p_2 = _mm256_madd_epi16(p16_2, mone);
+ accum1 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 0].d)*GGML_E8M0_TO_FP32_HALF(x[ib + 0].e)),
+ _mm256_cvtepi32_ps(p_1), accum1);
+ accum2 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 1].d)*GGML_E8M0_TO_FP32_HALF(x[ib + 1].e)),
+ _mm256_cvtepi32_ps(p_2), accum2);
+ }
+
+ sumf = hsum_float_8(_mm256_add_ps(accum1, accum2));
+
+#elif defined __AVX__
+ const __m128i values128 = _mm_loadu_si128((const __m128i*)kvalues_mxfp4);
+ const __m128i m4b = _mm_set1_epi8(0x0f);
+
+ __m256 accum = _mm256_setzero_ps();
+ for (; ib + 1 < nb; ib += 2) {
+ const __m128i q4bits_1 = _mm_loadu_si128((const __m128i *)x[ib + 0].qs);
+ const __m128i q4bits_2 = _mm_loadu_si128((const __m128i *)x[ib + 1].qs);
+ const __m128i q8b_1_0 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs);
+ const __m128i q8b_1_1 = _mm_loadu_si128((const __m128i *)y[ib + 0].qs + 1);
+ const __m128i q8b_2_0 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs);
+ const __m128i q8b_2_1 = _mm_loadu_si128((const __m128i *)y[ib + 1].qs + 1);
+
+ const __m128i q4b_1_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_1, m4b));
+ const __m128i q4b_1_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_1, 4), m4b));
+ const __m128i q4b_2_0 = _mm_shuffle_epi8(values128, _mm_and_si128(q4bits_2, m4b));
+ const __m128i q4b_2_1 = _mm_shuffle_epi8(values128, _mm_and_si128(_mm_srli_epi16(q4bits_2, 4), m4b));
+
+ const __m256 p = mul_sum_i8_quad_float(q4b_1_0, q4b_1_1, q4b_2_0, q4b_2_1, q8b_1_0, q8b_1_1, q8b_2_0, q8b_2_1);
+ const __m256 deltas = quad_mx_delta_float(x[ib].e, y[ib].d, x[ib + 1].e, y[ib + 1].d);
+ accum = _mm256_add_ps(_mm256_mul_ps(deltas, p), accum);
+ }
+
+ sumf = hsum_float_8(accum);
+
+#endif
+ for (; ib < nb; ++ib) {
+ const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_E8M0_TO_FP32_HALF(x[ib].e);
+ int sumi1 = 0;
+ int sumi2 = 0;
+ for (int j = 0; j < QK_MXFP4/2; ++j) {
+ sumi1 += y[ib].qs[j + 0] * kvalues_mxfp4[x[ib].qs[j] & 0xf];
+ sumi2 += y[ib].qs[j + QK_MXFP4/2] * kvalues_mxfp4[x[ib].qs[j] >> 4];
+ }
+ sumf += d * (sumi1 + sumi2);
+ }
+ *s = sumf;
+}
+
void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
const int qk = QK8_0;
const int nb = n / qk;
#endif
}
-#if defined(__AVX2__)
-static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
- const __m256i ax = _mm256_sign_epi8(x, x);
- const __m256i sy = _mm256_sign_epi8(y, x);
- return _mm256_maddubs_epi16(ax, sy);
-}
-#endif
-
void ggml_vec_dot_iq1_s_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
assert(n % QK_K == 0);
assert(nrc == 1);
.vec_dot_type = GGML_TYPE_Q8_1,
.nrows = 1,
},
+ [GGML_TYPE_MXFP4] = {
+ .from_float = quantize_row_mxfp4,
+ .vec_dot = ggml_vec_dot_mxfp4_q8_0,
+ .vec_dot_type = GGML_TYPE_Q8_0,
+ .nrows = 1,
+ },
[GGML_TYPE_Q2_K] = {
.from_float = quantize_row_q2_K,
.vec_dot = ggml_vec_dot_q2_K_q8_K,
{
ggml_compute_forward_add(params, tensor);
} break;
+ case GGML_OP_ADD_ID:
+ {
+ ggml_compute_forward_add_id(params, tensor);
+ } break;
case GGML_OP_ADD1:
{
ggml_compute_forward_add1(params, tensor);
} break;
case GGML_OP_FLASH_ATTN_EXT:
{
- ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
+ ggml_compute_forward_flash_attn_ext(params, tensor);
} break;
case GGML_OP_FLASH_ATTN_BACK:
{
case GGML_OP_DUP:
case GGML_OP_CONT:
case GGML_OP_ADD:
+ case GGML_OP_ADD_ID:
case GGML_OP_ADD1:
case GGML_OP_ACC:
{
case GGML_GLU_OP_REGLU:
case GGML_GLU_OP_GEGLU:
case GGML_GLU_OP_SWIGLU:
+ case GGML_GLU_OP_SWIGLU_OAI:
case GGML_GLU_OP_GEGLU_ERF:
case GGML_GLU_OP_GEGLU_QUICK:
{
}
} break;
case GGML_OP_ADD:
+ case GGML_OP_ADD_ID:
case GGML_OP_ADD1:
{
if (ggml_is_quantized(node->src[0]->type)) {
#include "vec.h"
#include <float.h>
+#include <algorithm>
// ggml_compute_forward_dup
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
+ case GGML_TYPE_MXFP4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
}
}
+// ggml_compute_forward_add_id
+
+static void ggml_compute_forward_add_id_f32(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+ const ggml_tensor * src2 = dst->src[2];
+
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT(src2->type == GGML_TYPE_I32);
+
+ GGML_ASSERT(src0->nb[0] == sizeof(float));
+ GGML_ASSERT(src1->nb[0] == sizeof(float));
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nr = ggml_nrows(src0);
+
+ GGML_TENSOR_TERNARY_OP_LOCALS
+
+ GGML_ASSERT( nb0 == sizeof(float));
+ GGML_ASSERT(nb10 == sizeof(float));
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int ir = ir0; ir < ir1; ++ir) {
+ // src0 indices
+ const int i3 = ir/(ne2*ne1);
+ const int i2 = (ir - i3*ne2*ne1)/ne1;
+ const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
+
+ // src1 indices
+ const int i11 = *(int32_t *) ((char *) src2->data + i1*nb20 + i2*nb21);
+
+ GGML_ASSERT(i11 >= 0 && i11 < ne11);
+
+ ggml_vec_add_f32(ne0,
+ (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ),
+ (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
+ (float *) ((char *) src1->data + i11*nb11));
+ }
+}
+
+void ggml_compute_forward_add_id(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_add_id_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ABORT("unsupported type for ggml_compute_forward_add_id: %s", ggml_type_name(src0->type));
+ }
+ }
+}
+
// ggml_compute_forward_add1
static void ggml_compute_forward_add1_f32(
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
+ case GGML_TYPE_MXFP4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
+ case GGML_TYPE_MXFP4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
}
}
+// ggml_compute_forward_swiglu_oai
+
+static void ggml_compute_forward_swiglu_oai_f32(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+ char * src0_d = (char *) src0->data;
+ char * src1_d = (char *) (src1 ? src1->data : src0->data);
+ const size_t src0_o = src0->nb[1];
+ const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
+ GGML_ASSERT(ggml_is_contiguous_1(dst));
+
+ if (src1) {
+ GGML_ASSERT(ggml_is_contiguous_1(src1));
+ GGML_ASSERT(src0->type == src1->type);
+ }
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+ const int nr = ggml_nrows(src0);
+
+ GGML_ASSERT(dst->ne[0] == nc);
+ GGML_ASSERT(ggml_nrows(dst) == nr);
+
+ const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+ const float alpha = ggml_get_op_params_f32(dst, 2);
+ const float limit = ggml_get_op_params_f32(dst, 3);
+
+ // rows per thread
+ const int dr = (nr + nth - 1)/nth;
+
+ // row range for this thread
+ const int ir0 = dr*ith;
+ const int ir1 = MIN(ir0 + dr, nr);
+
+ for (int i1 = ir0; i1 < ir1; i1++) {
+ float * src0_p = (float *) (src0_d + i1*src0_o);
+ float * src1_p = (float *) (src1_d + i1*src1_o);
+ float * dst_p = (float *) ((char *) dst->data + i1*(dst->nb[1]));
+
+ if (!src1) {
+ src0_p += swapped ? nc : 0;
+ src1_p += swapped ? 0 : nc;
+ }
+
+ for (int k = 0; k < nc; k++) {
+ const float x = std::min(src0_p[k], limit);
+ const float y = std::clamp(src1_p[k], -limit, limit);
+ const float out_glu = x / (1.f + expf(alpha * (-x)));
+ dst_p[k] = out_glu * (y + 1.f);
+ }
+
+#ifndef NDEBUG
+ for (int k = 0; k < nc; k++) {
+ const float x = dst_p[k];
+ GGML_UNUSED(x);
+ assert(!isnan(x));
+ assert(!isinf(x));
+ }
+#endif
+ }
+}
+
+static void ggml_compute_forward_swiglu_oai(
+ const ggml_compute_params * params,
+ ggml_tensor * dst) {
+
+ const ggml_tensor * src0 = dst->src[0];
+
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_swiglu_oai_f32(params, dst);
+ } break;
+ default:
+ {
+ GGML_ABORT("fatal error");
+ }
+ }
+}
+
// ggml_compute_forward_geglu_erf
static void ggml_compute_forward_geglu_erf_f32(
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
+ case GGML_TYPE_MXFP4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
+ case GGML_TYPE_MXFP4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
+ case GGML_TYPE_MXFP4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
+ const ggml_tensor * src2 = dst->src[2];
assert(ggml_is_contiguous(dst));
assert(ggml_are_same_shape(src0, dst));
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
+ // sinks
+ const float * sk = src2 ? (float *)((char *) src2->data) : nullptr;
+
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
float max = -INFINITY;
ggml_vec_max_f32(ne00, &max, wp);
+ // if we have sinks, make a correction as if they were included in the softmax
+ if (sk) {
+ max = MAX(max, sk[i02]);
+ }
+
ggml_float sum = ggml_vec_soft_max_f32(ne00, dp, wp, max);
assert(sum > 0.0);
+ if (sk) {
+ sum += (ggml_float) expf(sk[i02] - max);
+ }
+
sum = 1.0/sum;
ggml_vec_scale_f32(ne00, dp, sum);
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_Q8_1:
+ case GGML_TYPE_MXFP4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
static void ggml_compute_forward_flash_attn_ext_f16(
const ggml_compute_params * params,
- const ggml_tensor * q,
- const ggml_tensor * k,
- const ggml_tensor * v,
- const ggml_tensor * mask,
ggml_tensor * dst) {
+ const ggml_tensor * q = dst->src[0];
+ const ggml_tensor * k = dst->src[1];
+ const ggml_tensor * v = dst->src[2];
+ const ggml_tensor * mask = dst->src[3];
+ const ggml_tensor * sinks = dst->src[4];
+
GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
}
}
+ // sinks
+ if (sinks) {
+ const float s = ((float *)((char *) sinks->data))[h];
+
+ float ms = 1.0f;
+ float vs = 1.0f;
+
+ if (s > M) {
+ ms = expf(M - s);
+ ggml_vec_scale_f32(DV, VKQ32, ms);
+ } else {
+ vs = expf(s - M);
+ }
+
+ S = S*ms + vs;
+ }
+
// V /= S
const float S_inv = 1.0f/S;
ggml_vec_scale_f32(DV, VKQ32, S_inv);
void ggml_compute_forward_flash_attn_ext(
const ggml_compute_params * params,
- const ggml_tensor * q,
- const ggml_tensor * k,
- const ggml_tensor * v,
- const ggml_tensor * mask,
ggml_tensor * dst) {
switch (dst->op_params[3]) {
case GGML_PREC_DEFAULT:
case GGML_PREC_F32:
{
// uses F32 accumulators
- ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst);
+ ggml_compute_forward_flash_attn_ext_f16(params, dst);
} break;
default:
{
{
ggml_compute_forward_swiglu(params, dst);
} break;
+ case GGML_GLU_OP_SWIGLU_OAI:
+ {
+ ggml_compute_forward_swiglu_oai(params, dst);
+ } break;
case GGML_GLU_OP_GEGLU_ERF:
{
ggml_compute_forward_geglu_erf(params, dst);
void ggml_compute_forward_dup(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_add(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_add_id(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_add1(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_acc(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_sum(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_timestep_embedding(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_argsort(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_leaky_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
-void ggml_compute_forward_flash_attn_ext(
- const struct ggml_compute_params * params,
- const struct ggml_tensor * q,
- const struct ggml_tensor * k,
- const struct ggml_tensor * v,
- const struct ggml_tensor * mask,
- struct ggml_tensor * dst);
+void ggml_compute_forward_flash_attn_ext(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_flash_attn_back(
const struct ggml_compute_params * params,
const bool masked,
quantize_row_q8_1_ref(x, y, k);
}
+void quantize_row_mxfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k) {
+ quantize_row_mxfp4_ref(x, y, k);
+}
+
//
// 2-6 bit quantization in super-blocks
//
*s = sumf;
}
+void ggml_vec_dot_mxfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
+ assert(nrc == 1);
+ UNUSED(nrc);
+ UNUSED(bx);
+ UNUSED(by);
+ UNUSED(bs);
+ assert(n % QK_MXFP4 == 0);
+ static_assert(QK_MXFP4 == QK8_0, "QK_MXFP4 and QK8_0 must be the same");
+
+ const block_mxfp4 * GGML_RESTRICT x = vx;
+ const block_q8_0 * GGML_RESTRICT y = vy;
+
+ const int nb = n / QK_MXFP4;
+
+ int ib = 0;
+ float sumf = 0;
+
+ for (; ib < nb; ++ib) {
+ const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_E8M0_TO_FP32_HALF(x[ib].e);
+
+ int sumi1 = 0;
+ int sumi2 = 0;
+ for (int j = 0; j < QK_MXFP4/2; ++j) {
+ sumi1 += y[ib].qs[j + 0] * kvalues_mxfp4[x[ib].qs[j] & 0xf];
+ sumi2 += y[ib].qs[j + QK_MXFP4/2] * kvalues_mxfp4[x[ib].qs[j] >> 4];
+ }
+ sumf += d * (sumi1 + sumi2);
+ }
+ *s = sumf;
+}
+
void ggml_vec_dot_q5_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc) {
const int qk = QK8_0;
const int nb = n / qk;
void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_mxfp4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+
void quantize_row_q2_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q3_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q4_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q5_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q5_1_q8_1_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q8_0_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
+void ggml_vec_dot_mxfp4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
void ggml_vec_dot_tq1_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_tq2_0_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
void ggml_vec_dot_q2_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q3_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q4_K_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const ggml_fp16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
inline static void ggml_vec_set_bf16(const int n, ggml_bf16_t * x, const ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
-inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; }
+
+inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) {
+ int i = 0;
+#if defined(__AVX2__)
+ for (; i + 7 < n; i += 8) {
+ __m256 vx = _mm256_loadu_ps(x + i);
+ __m256 vy = _mm256_loadu_ps(y + i);
+ __m256 vz = _mm256_add_ps(vx, vy);
+ _mm256_storeu_ps(z + i, vz);
+ }
+#endif
+ for (; i < n; ++i) {
+ z[i] = x[i] + y[i];
+ }
+}
+
inline static void ggml_vec_add_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) {
for (int i = 0; i < n; ++i) {
z[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(x[i]) + GGML_CPU_FP16_TO_FP32(y[i]));
inline static void ggml_vec_swiglu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
for (int i = 0; i < n; ++i) {
- float v = GGML_CPU_FP16_TO_FP32(x[i]);
- float w = GGML_CPU_FP16_TO_FP32(g[i]);
- y[i] = GGML_CPU_FP32_TO_FP16((v/(1.0f + expf(-v))) * w);
+ float xi = GGML_CPU_FP16_TO_FP32(x[i]);
+ float gi = GGML_CPU_FP16_TO_FP32(g[i]);
+ y[i] = GGML_CPU_FP32_TO_FP16((xi/(1.0f + expf(-xi))) * gi);
}
}
--- /dev/null
+#include "add-id.cuh"
+
+static __global__ void add_id_kernel(
+ const float * src0, const float * src1, const int32_t * src2, float * dst,
+ int64_t ne0, int64_t ne1,
+ size_t nb01, size_t nb02,
+ size_t nb11,
+ size_t nb21
+ ) {
+
+ const int64_t i1 = blockIdx.x;
+ const int64_t i2 = blockIdx.y;
+
+ const int i11 = *(int32_t *) ((char *) src2 + i1*sizeof(int32_t) + i2*nb21);
+
+ const size_t nb1 = ne0 * sizeof(float);
+ const size_t nb2 = ne1 * nb1;
+
+ float * dst_row = (float *)((char *)dst + i1*nb1 + i2*nb2);
+ const float * src0_row = (const float *)((char *)src0 + i1*nb01 + i2*nb02);
+ const float * src1_row = (const float *)((char *)src1 + i11*nb11);
+
+ for (int64_t i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) {
+ dst_row[i0] = src0_row[i0] + src1_row[i0];
+ }
+}
+
+void ggml_cuda_op_add_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+ const ggml_tensor * src2 = dst->src[2];
+
+ GGML_TENSOR_TERNARY_OP_LOCALS
+
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT(src2->type == GGML_TYPE_I32);
+
+ GGML_ASSERT(nb00 == sizeof(float));
+ GGML_ASSERT(nb10 == sizeof(float));
+ GGML_ASSERT(nb20 == sizeof(int32_t));
+
+ const float * src0_d = (const float *)src0->data;
+ const float * src1_d = (const float *)src1->data;
+ const int32_t * src2_d = (const int32_t *)src2->data;
+ float * dst_d = (float *)dst->data;
+
+ int threads = std::min((int)ne00, 768); // cols
+ dim3 blocks(ne01, ne02); // n_experts_used, n_tokens
+ add_id_kernel<<<blocks, threads, 0, ctx.stream()>>>(
+ src0_d, src1_d, src2_d, dst_d,
+ ne0, ne1,
+ nb01, nb02,
+ nb11,
+ nb21
+ );
+}
--- /dev/null
+#include "common.cuh"
+
+void ggml_cuda_op_add_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
#pragma once
#include "ggml.h"
+#include "ggml-impl.h"
#include "ggml-cuda.h"
#include <cstdint>
#endif // defined(GGML_USE_HIP)
}
+static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
+#if CUDART_VERSION >= 12080
+ const nv_bfloat16 e = __nv_cvt_e8m0_to_bf16raw(x);
+ return (float) e;
+#else
+ uint32_t bits;
+ if (x == 0) {
+ bits = 0x00400000;
+ } else {
+ bits = (uint32_t) x << 23;
+ }
+
+ float result;
+ memcpy(&result, &bits, sizeof(float));
+ return result;
+#endif // CUDART_VERSION >= 12050
+}
+
typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);
static __device__ __forceinline__ float get_alibi_slope(
static constexpr int qi = QI8_0;
};
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_MXFP4> {
+ static constexpr int qk = QK_MXFP4;
+ static constexpr int qr = QR_MXFP4;
+ static constexpr int qi = QI_MXFP4;
+};
+
template<>
struct ggml_cuda_type_traits<GGML_TYPE_Q2_K> {
static constexpr int qk = QK_K;
}
}
+template<typename dst_t>
+static __global__ void dequantize_block_mxfp4(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+ const int64_t i = blockIdx.x;
+ const block_mxfp4 * x = (const block_mxfp4 *) vx + i*(QK_K/QK_MXFP4);
+
+ const int64_t tid = threadIdx.x;
+ const int64_t il = tid/8; // 0...3
+ const int64_t ib = tid%8; // 0...7
+ dst_t * y = yy + i*QK_K + 32*ib + 4*il;
+ const uint8_t * q4 = x[ib].qs + 4*il;
+ const float d = ggml_cuda_e8m0_to_fp32(x[ib].e);
+ for (int j = 0; j < 4; ++j) {
+ y[j+ 0] = d * kvalues_mxfp4[q4[j] & 0xf]*0.5f;
+ y[j+16] = d * kvalues_mxfp4[q4[j] >> 4]*0.5f;
+ }
+}
+
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static void dequantize_block_cuda(const void * vx, dst_t * y,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y);
}
+template<typename dst_t>
+static void dequantize_row_mxfp4_cuda(const void * vx, dst_t * y, const int64_t k, cudaStream_t stream) {
+ const int nb = (k + QK_K - 1) / QK_K;
+ dequantize_block_mxfp4<<<nb, 32, 0, stream>>>(vx, y);
+}
+
template <typename src_t, typename dst_t>
static __global__ void convert_unary(
const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, const int64_t ne02,
return dequantize_row_iq4_xs_cuda;
case GGML_TYPE_IQ3_S:
return dequantize_row_iq3_s_cuda;
+ case GGML_TYPE_MXFP4:
+ return dequantize_row_mxfp4_cuda;
case GGML_TYPE_F32:
return convert_unary_cont_cuda<float>;
case GGML_TYPE_BF16:
return dequantize_row_iq4_xs_cuda;
case GGML_TYPE_IQ3_S:
return dequantize_row_iq3_s_cuda;
+ case GGML_TYPE_MXFP4:
+ return dequantize_row_mxfp4_cuda;
case GGML_TYPE_F16:
return convert_unary_cont_cuda<half>;
case GGML_TYPE_BF16:
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
+ const char * __restrict__ sinks,
const int * __restrict__ KV_max,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
GGML_ASSERT(V || is_mla);
- const ggml_tensor * mask = dst->src[3];
+ const ggml_tensor * mask = dst->src[3];
+ const ggml_tensor * sinks = dst->src[4];
ggml_tensor * KQV = dst;
K_data,
V_data,
mask ? ((const char *) mask->data) : nullptr,
+ sinks ? ((const char *) sinks->data) : nullptr,
KV_max.ptr,
!stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
scale, max_bias, m0, m1, n_head_log2, logit_softcap,
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
+ const char * __restrict__ sinks,
const int * __restrict__ KV_max,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
// kb0 == k start index when in the output tile.
int kb0_start = kbc % iter_k;
int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc);
+
while (kbc < kbc_stop && kb0_stop == iter_k) {
const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
#else
- GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
+ GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks);
GGML_UNUSED(dst); GGML_UNUSED(dst_meta);
GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
+ const char * __restrict__ sinks,
const int * __restrict__ KV_max,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
}
}
#else
- GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
+ GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks);
GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
+ const char * __restrict__ sinks,
const int * __restrict__ KV_max,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
return;
#endif // FP16_MMA_AVAILABLE
if (use_logit_softcap && !(D == 128 || D == 256)) {
- GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
+ GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks);
GGML_UNUSED(dst); GGML_UNUSED(dst_meta);
GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
+ const char * __restrict__ sinks,
const int * __restrict__ KV_max,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
K += nb13*sequence + nb12*(head / gqa_ratio);
V += nb23*sequence + nb22*(head / gqa_ratio);
- const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
+ const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
+ const float * sinksf = (const float *) (sinks);
const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
const half slopeh = __float2half(slopef);
half2 * KQ2 = (half2 *) KQ;
half kqmax[ncols];
+ half kqsum[ncols];
#pragma unroll
for (int j = 0; j < ncols; ++j) {
kqmax[j] = -HALF_MAX_HALF;
+ kqsum[j] = 0.0f;
}
- half kqsum[ncols] = {0.0f};
__shared__ half kqmax_shared[ncols][WARP_SIZE];
__shared__ half kqsum_shared[ncols][WARP_SIZE];
__syncthreads();
}
+ if (sinksf && blockIdx.y == 0) {
+ const half sink = __float2half(sinksf[head]);
+
+#pragma unroll
+ for (int j = 0; j < ncols; ++j) {
+ if (threadIdx.x == 0) {
+ kqmax_shared[j][threadIdx.y] = fmaxf(kqmax[j], sink);
+ }
+ }
+
+ __syncthreads();
+
+#pragma unroll
+ for (int j = 0; j < ncols; ++j) {
+ half kqmax_new_j = kqmax_shared[j][threadIdx.x];
+ kqmax_new_j = warp_reduce_max(kqmax_new_j);
+
+ const half KQ_max_scale = hexp(kqmax[j] - kqmax_new_j);
+ kqmax[j] = kqmax_new_j;
+
+ const half val = hexp(sink - kqmax[j]);
+ kqsum[j] = kqsum[j]*KQ_max_scale;
+
+ if (tid == 0) {
+ kqsum[j] += val;
+ }
+
+ VKQ[j] *= __half2half2(KQ_max_scale);
+ }
+
+ __syncthreads();
+ }
+
#pragma unroll
for (int j = 0; j < ncols; ++j) {
kqsum[j] = warp_reduce_sum((float)kqsum[j]);
dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
}
#else
- GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
+ GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks);
GGML_UNUSED(dst); GGML_UNUSED(dst_meta);
GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
+ const char * __restrict__ sinks,
const int * __restrict__ KV_max,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
K += nb13*sequence + nb12*(head / gqa_ratio);
V += nb23*sequence + nb22*(head / gqa_ratio);
- const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
+ const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
+ const float * sinksf = (const float *) (sinks);
const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
}
float kqmax[ncols];
+ float kqsum[ncols];
#pragma unroll
for (int j = 0; j < ncols; ++j) {
kqmax[j] = -FLT_MAX/2.0f;
+ kqsum[j] = 0.0f;
}
- float kqsum[ncols] = {0.0f};
__shared__ float kqmax_shared[ncols][WARP_SIZE];
__shared__ float kqsum_shared[ncols][WARP_SIZE];
__syncthreads();
}
+ if (sinksf && blockIdx.y == 0) {
+ const float sink = sinksf[head];
+
+#pragma unroll
+ for (int j = 0; j < ncols; ++j) {
+ if (threadIdx.x == 0) {
+ kqmax_shared[j][threadIdx.y] = fmaxf(kqmax[j], sink);
+ }
+ }
+
+ __syncthreads();
+
+#pragma unroll
+ for (int j = 0; j < ncols; ++j) {
+ float kqmax_new_j = kqmax_shared[j][threadIdx.x];
+ kqmax_new_j = warp_reduce_max(kqmax_new_j);
+
+ const float KQ_max_scale = expf(kqmax[j] - kqmax_new_j);
+ kqmax[j] = kqmax_new_j;
+
+ const float val = expf(sink - kqmax[j]);
+ kqsum[j] = kqsum[j]*KQ_max_scale;
+
+ if (tid == 0) {
+ kqsum[j] += val;
+ }
+
+ VKQ[j] *= KQ_max_scale;
+ }
+
+ __syncthreads();
+ }
+
#pragma unroll
for (int j = 0; j < ncols; ++j) {
kqsum[j] = warp_reduce_sum(kqsum[j]);
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
+ const char * __restrict__ sinks,
const int * __restrict__ KV_max,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
dst_meta[j_dst_unrolled] = dst_meta_val;
}
#else
- GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
+ GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks);
GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
}
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
- const ggml_tensor * KQV = dst;
- const ggml_tensor * Q = dst->src[0];
- const ggml_tensor * K = dst->src[1];
- const ggml_tensor * V = dst->src[2];
- const ggml_tensor * mask = dst->src[3];
+ const ggml_tensor * KQV = dst;
+ const ggml_tensor * Q = dst->src[0];
+ const ggml_tensor * K = dst->src[1];
+ const ggml_tensor * V = dst->src[2];
+ const ggml_tensor * mask = dst->src[3];
+ const ggml_tensor * sinks = dst->src[4];
ggml_cuda_set_device(ctx.device);
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
+ // TODO: currently only vec implementation for sinks is supported [TAG_ATTN_SINKS]
+ if (sinks) {
+ if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
+ ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
+ } else {
+ ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
+ }
+ return;
+ }
+
#if defined(GGML_HIP_ROCWMMA_FATTN)
if (GGML_CUDA_CC_IS_AMD(cc) && fp16_mma_available(cc)) {
ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
#include "ggml-cuda/common.cuh"
#include "ggml-cuda/acc.cuh"
+#include "ggml-cuda/add-id.cuh"
#include "ggml-cuda/arange.cuh"
#include "ggml-cuda/argmax.cuh"
#include "ggml-cuda/argsort.cuh"
case GGML_OP_ADD1: // TODO: more efficient implementation
ggml_cuda_op_add(ctx, dst);
break;
+ case GGML_OP_ADD_ID:
+ ggml_cuda_op_add_id(ctx, dst);
+ break;
case GGML_OP_SUB:
ggml_cuda_op_sub(ctx, dst);
break;
case GGML_GLU_OP_SWIGLU:
ggml_cuda_op_swiglu(ctx, dst);
break;
+ case GGML_GLU_OP_SWIGLU_OAI:
+ ggml_cuda_op_swiglu_oai(ctx, dst);
+ break;
case GGML_GLU_OP_GEGLU_ERF:
ggml_cuda_op_geglu_erf(ctx, dst);
break;
const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected";
const std::string gemma3n_per_layer_proj_src1_name = "per_layer_proj";
+ const std::string ffn_moe_gate_bias_prefix = "ffn_moe_gate_biased";
+ const std::string ffn_moe_up_bias_prefix = "ffn_moe_up_biased";
+ const std::string ffn_moe_down_bias_prefix = "ffn_moe_down_biased";
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];
#endif
}
- if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1 && (node->src[0] ? node->src[0]->name != gemma3n_per_layer_proj_src0_name : true) && (node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true)) {
+ if (node->op == GGML_OP_ADD &&
+ node->src[1] && node->src[1]->ne[1] > 1 &&
+ (node->src[0] ? node->src[0]->name != gemma3n_per_layer_proj_src0_name : true) &&
+ (node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true) &&
+ strncmp(node->name, ffn_moe_gate_bias_prefix.c_str(), ffn_moe_gate_bias_prefix.size()) != 0 &&
+ strncmp(node->name, ffn_moe_up_bias_prefix.c_str(), ffn_moe_up_bias_prefix.size()) != 0 &&
+ strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0) {
// disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation
// by means of matching node names. See
// https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and
case GGML_GLU_OP_REGLU:
case GGML_GLU_OP_GEGLU:
case GGML_GLU_OP_SWIGLU:
+ case GGML_GLU_OP_SWIGLU_OAI:
case GGML_GLU_OP_GEGLU_ERF:
case GGML_GLU_OP_GEGLU_QUICK:
return ggml_is_contiguous_1(op->src[0]);
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
+ case GGML_TYPE_MXFP4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
case GGML_OP_PERMUTE:
case GGML_OP_TRANSPOSE:
case GGML_OP_ADD:
+ case GGML_OP_ADD_ID:
case GGML_OP_ADD1:
case GGML_OP_SUB:
case GGML_OP_MUL:
const int gqa_ratio = op->src[0]->ne[2] / op->src[1]->ne[2];
return op->src[1]->ne[0] == 576 && op->src[2]->ne[0] == 512 && op->src[3] && gqa_ratio % 16 == 0;
}
+ // TODO: more general-purpose attention sink support [TAG_ATTN_SINKS]
+ if (op->src[4] && op->src[0]->ne[0] != 64 && op->src[0]->ne[0] != 128) { // currently only sinks for head_size 64 and 128 are supported
+ return false;
+ }
if (op->src[0]->ne[0] == 192) {
return false;
}
#include "im2col.cuh"
-#define MIN(a, b) (a) < (b) ? (a) : (b)
-
#define MAX_GRIDDIM_Z 65535
template <typename T>
dst[offset_dst] = x[offset_src + iih * IW + iiw];
}
}
+
+ GGML_UNUSED(IC);
+ GGML_UNUSED(KH);
}
// im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
case GGML_TYPE_Q8_0:
mul_mat_q_case<GGML_TYPE_Q8_0>(ctx, args, stream);
break;
+ case GGML_TYPE_MXFP4:
+ mul_mat_q_case<GGML_TYPE_MXFP4>(ctx, args, stream);
+ break;
case GGML_TYPE_Q2_K:
mul_mat_q_case<GGML_TYPE_Q2_K>(ctx, args, stream);
break;
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
+ case GGML_TYPE_MXFP4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
return MMQ_Q8_1_DS_LAYOUT_DS4;
case GGML_TYPE_Q8_0:
return MMQ_Q8_1_DS_LAYOUT_D4;
+ case GGML_TYPE_MXFP4:
+ return MMQ_Q8_1_DS_LAYOUT_D4;
case GGML_TYPE_Q2_K:
return MMQ_Q8_1_DS_LAYOUT_D2S6;
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q5_0: return MMQ_DP4A_TXS_Q8_0;
case GGML_TYPE_Q5_1: return MMQ_DP4A_TXS_Q8_1;
case GGML_TYPE_Q8_0: return MMQ_DP4A_TXS_Q8_0;
+ case GGML_TYPE_MXFP4: return MMQ_DP4A_TXS_Q8_1;
case GGML_TYPE_Q2_K: return MMQ_DP4A_TXS_Q2_K;
case GGML_TYPE_Q3_K: return MMQ_DP4A_TXS_Q3_K;
case GGML_TYPE_Q4_K: return MMQ_DP4A_TXS_Q4_K;
case GGML_TYPE_Q5_0: return MMQ_MMA_TILE_X_K_Q8_0;
case GGML_TYPE_Q5_1: return MMQ_MMA_TILE_X_K_Q8_1;
case GGML_TYPE_Q8_0: return MMQ_MMA_TILE_X_K_Q8_0;
+ case GGML_TYPE_MXFP4: return MMQ_MMA_TILE_X_K_Q8_1;
case GGML_TYPE_Q2_K: return MMQ_MMA_TILE_X_K_Q2_K;
case GGML_TYPE_Q3_K: return MMQ_MMA_TILE_X_K_Q3_K;
case GGML_TYPE_Q4_K: return MMQ_MMA_TILE_X_K_Q8_1;
}
}
+template <int mmq_y, bool need_check> static __device__ __forceinline__ void load_tiles_mxfp4(
+ const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
+ constexpr int nwarps = mmq_get_nwarps_device();
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
+#else
+ constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_MXFP4, mmq_y);
+ int * x_qs = (int *) x_tile;
+ float * x_df = (float *) (x_qs + txs.qs);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+
+ constexpr int threads_per_row = MMQ_ITER_K / (4 * QR_MXFP4);
+ constexpr int nrows = warp_size / threads_per_row;
+ const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x;
+ const int kbx = txi / QI_MXFP4;
+ const int kqsx = txi % QI_MXFP4;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) {
+ int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row);
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbx;
+
+ const int aux_q4 = get_int_b1(bxi->qs, kqsx);
+ const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4);
+ const int k0 = kbx * (2 * QI_MXFP4) + kqsx;
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + 0] = v.x;
+ x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + k0 + QI_MXFP4] = v.y;
+#else
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x;
+ x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI_MXFP4] = v.y;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+ }
+
+ constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI_MXFP4;
+ constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row;
+ const int kbxd = threadIdx.x % blocks_per_tile_x_row;
+
+#pragma unroll
+ for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) {
+ int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row;
+
+ if (need_check) {
+ i = min(i, i_max);
+ }
+
+ const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i*stride + kbxd;
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+ x_df[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f;
+#else
+ x_df[i*(MMQ_TILE_NE_K/QI_MXFP4) + i/QI_MXFP4 + kbxd] = ggml_cuda_e8m0_to_fp32(bxi->e)*0.5f;
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
+ }
+}
+
template <int mmq_x, int mmq_y>
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbx;
const int aux_q4 = get_int_b2(bxi->qs, kqsx);
- const int2 v = get_int_from_table_16(aux_q4);
+ const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
const int k0 = kbx * (2 * QI4_NL) + kqsx;
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride;
const int aux_q4 = get_int_b4(bxi->qs, kqsx);
- const int2 v = get_int_from_table_16(aux_q4);
+ const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
const int k0 = 8 * (kqsx / 4) + kqsx % 4;
#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)
static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
};
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_MXFP4> {
+ static constexpr int vdr = VDR_MXFP4_Q8_1_MMQ;
+ static constexpr load_tiles_mmq_t load_tiles = load_tiles_mxfp4<mmq_y, need_check>;
+ static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
+ static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
+};
+
template <int mmq_x, int mmq_y, bool need_check>
struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q2_K> {
static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ;
extern DECL_MMQ_CASE(GGML_TYPE_Q5_0);
extern DECL_MMQ_CASE(GGML_TYPE_Q5_1);
extern DECL_MMQ_CASE(GGML_TYPE_Q8_0);
+extern DECL_MMQ_CASE(GGML_TYPE_MXFP4);
extern DECL_MMQ_CASE(GGML_TYPE_Q2_K);
extern DECL_MMQ_CASE(GGML_TYPE_Q3_K);
extern DECL_MMQ_CASE(GGML_TYPE_Q4_K);
case GGML_TYPE_Q5_0: return vec_dot_q5_0_q8_1;
case GGML_TYPE_Q5_1: return vec_dot_q5_1_q8_1;
case GGML_TYPE_Q8_0: return vec_dot_q8_0_q8_1;
+ case GGML_TYPE_MXFP4: return vec_dot_mxfp4_q8_1;
case GGML_TYPE_Q2_K: return vec_dot_q2_K_q8_1;
case GGML_TYPE_Q3_K: return vec_dot_q3_K_q8_1;
case GGML_TYPE_Q4_K: return vec_dot_q4_K_q8_1;
case GGML_TYPE_Q5_0: return VDR_Q5_0_Q8_1_MMVQ;
case GGML_TYPE_Q5_1: return VDR_Q5_1_Q8_1_MMVQ;
case GGML_TYPE_Q8_0: return VDR_Q8_0_Q8_1_MMVQ;
+ case GGML_TYPE_MXFP4: return VDR_MXFP4_Q8_1_MMVQ;
case GGML_TYPE_Q2_K: return VDR_Q2_K_Q8_1_MMVQ;
case GGML_TYPE_Q3_K: return VDR_Q3_K_Q8_1_MMVQ;
case GGML_TYPE_Q4_K: return VDR_Q4_K_Q8_1_MMVQ;
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
stream);
break;
+ case GGML_TYPE_MXFP4:
+ mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_MXFP4>
+ (vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+ nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+ nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
+ stream);
+ break;
case GGML_TYPE_Q2_K:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q2_K>
(vx, vy, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
#endif // __clang__
template <bool use_shared, int ncols_template, int block_size_template, typename T>
static __global__ void soft_max_f32(
- const float * x, const T * mask, float * dst, const soft_max_params p) {
+ const float * x, const T * mask, const float * sinks, float * dst, const soft_max_params p) {
const int ncols = ncols_template == 0 ? p.ncols : ncols_template;
const int tid = threadIdx.x;
// shared memory buffer to cache values between iterations:
float * vals = use_shared ? buf_iw + WARP_SIZE : dst;
- float max_val = -INFINITY;
+ float max_val = sinks ? sinks[i02] : -INFINITY;
#pragma unroll
for (int col0 = 0; col0 < ncols; col0 += block_size) {
tmp = warp_reduce_sum(tmp);
}
+ if (sinks) {
+ tmp += expf(sinks[i02] - max_val);
+ }
+
const float inv_sum = 1.0f / tmp;
#pragma unroll
}
template<int... Ns, typename T>
-static void launch_soft_max_kernels(const float * x, const T * mask, float * dst,
+static void launch_soft_max_kernels(const float * x, const T * mask, const float * sinks, float * dst,
const soft_max_params & p, cudaStream_t stream, dim3 block_dims, dim3 block_nums, size_t nbytes_shared)
{
const int id = ggml_cuda_get_device();
if (p.ncols == ncols) {
CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, ncols, block, T>), smpbo);
soft_max_f32<true, ncols, block><<<block_nums, block_dims, nbytes_shared, stream>>>
- (x, mask, dst, p);
+ (x, mask, sinks, dst, p);
return true;
}
return false;
//default case
CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, 0, 0, T>), smpbo);
- soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, dst, p);
+ soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, sinks, dst, p);
}
template<typename T>
-static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const soft_max_params & params, cudaStream_t stream) {
+static void soft_max_f32_cuda(const float * x, const T * mask, const float * sinks, float * dst, const soft_max_params & params, cudaStream_t stream) {
int nth = WARP_SIZE;
const int64_t ncols_x = params.ncols;
if (nbytes_shared <= smpbo) {
- launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, dst, params, stream, block_dims, block_nums, nbytes_shared);
+ launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, sinks, dst, params, stream, block_dims, block_nums, nbytes_shared);
} else {
const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
- soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, params);
+ soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, sinks, dst, params);
}
}
void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
+ const ggml_tensor * src2 = dst->src[2];
const float * src0_d = (const float *) src0->data;
const void * src1_d = src1 ? (const void *) src1->data : nullptr;
+ const void * src2_d = src2 ? (const void *) src2->data : nullptr;
float * dst_d = (float *) dst->data;
cudaStream_t stream = ctx.stream();
params.m1 = m1;
if (use_f16) {
- soft_max_f32_cuda(src0_d, (const half *) src1_d, dst_d, params, stream);
+ soft_max_f32_cuda(src0_d, (const half *) src1_d, (const float *) src2_d, dst_d, params, stream);
} else {
- soft_max_f32_cuda(src0_d, (const float *) src1_d, dst_d, params, stream);
+ soft_max_f32_cuda(src0_d, (const float *) src1_d, (const float *) src2_d, dst_d, params, stream);
}
}
--- /dev/null
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_MXFP4);
ggml_cuda_op_unary_gated<op_gelu_quick>(ctx, dst);
}
+// swiglu_oai
+
+template <typename T>
+static __global__ void swiglu_oai_kernel(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, float alpha, float limit) {
+ const int64_t i = int64_t(blockDim.x)*blockIdx.x + threadIdx.x;
+
+ if (i >= k) {
+ return;
+ }
+
+ // perform base op and multiply with gate (either offset in same tensor or a separate one)
+ const int64_t j0 = (i / n) * o0 + (i % n);
+ const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
+
+ float xi = x[j0];
+ float gi = g[j1];
+ xi = fminf(xi, limit);
+ gi = fmaxf(fminf(gi, limit), -limit);
+
+ float out_glu = xi / (1.0f + expf(-xi * alpha));
+ out_glu = out_glu * (1.0f + gi);
+
+ dst[i] = out_glu;
+}
+
+template <typename T>
+static void swiglu_oai_cuda(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, const float alpha, const float limit, cudaStream_t stream) {
+ const int64_t num_blocks = (k + CUDA_GLU_BLOCK_SIZE - 1) / CUDA_GLU_BLOCK_SIZE;
+ swiglu_oai_kernel<<<num_blocks, CUDA_GLU_BLOCK_SIZE, 0, stream>>>(x, g, dst, k, n, o0, o1, alpha, limit);
+}
+
+void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+ void * src0_d = src0->data;
+ void * src1_d = src1 ? src1->data : src0->data;
+ const int64_t src0_o = src0->nb[1];
+ const int64_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
+ void * dst_d = dst->data;
+ const int64_t nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
+ cudaStream_t stream = ctx.stream();
+
+ GGML_ASSERT(ggml_is_contiguous_1(src0));
+ GGML_ASSERT(src0->nb[0] == ggml_element_size(src0));
+ GGML_ASSERT(ggml_is_contiguous(dst));
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+ GGML_ASSERT(src0->type == dst->type);
+ GGML_ASSERT(dst->ne[0] == nc);
+ GGML_ASSERT(ggml_nrows(dst) == ggml_nrows(src0));
+
+ if (src1) {
+ GGML_ASSERT(ggml_is_contiguous_1(src1));
+ GGML_ASSERT(src1->nb[0] == ggml_element_size(src1));
+ GGML_ASSERT(src1->ne[0] == nc);
+ GGML_ASSERT(src0->type == src1->type);
+ }
+
+ //const int32_t swapped = ((const int32_t *) dst->op_params)[1];
+ const int32_t swapped = ggml_get_op_params_i32(dst, 1);
+ const float alpha = ggml_get_op_params_f32(dst, 2);
+ const float limit = ggml_get_op_params_f32(dst, 3);
+
+ float * src0_p = (float *) src0_d;
+ float * src1_p = (float *) src1_d;
+
+ if (!src1) {
+ src0_p += swapped ? nc : 0;
+ src1_p += swapped ? 0 : nc;
+ }
+
+ swiglu_oai_cuda(src0_p, src1_p, (float *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(float), src1_o / sizeof(float), alpha, limit, stream);
+}
+
/* silu_back */
static __device__ __forceinline__ float op_silu_back(float grad, float x) {
void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
#pragma once
#include "common.cuh"
+
#include <cstdint>
+static __device__ __forceinline__ int get_int_b1(const void * x, const int & i32) {
+ const uint8_t * x8 = (const uint8_t *) x;
+
+ int x32 = x8[4*i32 + 0] << 0;
+ x32 |= x8[4*i32 + 1] << 8;
+ x32 |= x8[4*i32 + 2] << 16;
+ x32 |= x8[4*i32 + 3] << 24;
+
+ return x32;
+}
+
static __device__ __forceinline__ int get_int_b2(const void * x, const int & i32) {
const uint16_t * x16 = (const uint16_t *) x; // assume at least 2 byte alignment
return ((const int *) x)[i32]; // assume at least 4 byte alignment
}
+static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, const int8_t * table) {
+ const int q0_32 = (q4 >> 0) & 0x0F0F0F0F;
+ const int8_t * q0_8 = (const int8_t *) &q0_32;
+ const char4 val0_8 = make_char4(
+ table[q0_8[0]], table[q0_8[1]], table[q0_8[2]], table[q0_8[3]]);
+
+ const int q1_32 = (q4 >> 4) & 0x0F0F0F0F;
+ const int8_t * q1_8 = (const int8_t *) &q1_32;
+ const char4 val1_8 = make_char4(
+ table[q1_8[0]], table[q1_8[1]], table[q1_8[2]], table[q1_8[3]]);
+
+ return make_int2(*((const int *) &val0_8), *((const int *) &val1_8));
+}
+
// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
return d8_1*sumf;
}
+#define VDR_MXFP4_Q8_1_MMVQ 2
+#define VDR_MXFP4_Q8_1_MMQ 4
+
+static __device__ __forceinline__ float vec_dot_mxfp4_q8_1(
+ const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+
+ const block_mxfp4 * bq4 = (const block_mxfp4 *) vbq + kbx;
+
+ const int * q8 = (const int *) bq8_1->qs + iqs;
+
+ int sumi = 0;
+#pragma unroll
+ for (int l = 0; l < VDR_MXFP4_Q8_1_MMVQ; ++l) {
+ const int aux_q4 = get_int_b1(bq4->qs, iqs + l);
+ const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4);
+
+ sumi = ggml_cuda_dp4a(v.x, q8[l + 0], sumi);
+ sumi = ggml_cuda_dp4a(v.y, q8[l + 4], sumi);
+ }
+
+ const float d = ggml_cuda_e8m0_to_fp32(bq4->e) * 0.5f * __low2float(bq8_1->ds);
+ return d * sumi;
+}
+
#define VDR_Q2_K_Q8_1_MMVQ 1
#define VDR_Q2_K_Q8_1_MMQ 4
return d * ((sumi[0] + sumf[0]) * sc0 + (sumi[1] + sumf[1]) * sc1);
}
-static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4) {
- const int q0_32 = (q4 >> 0) & 0x0F0F0F0F;
- const int8_t * q0_8 = (const int8_t *) &q0_32;
- const char4 val0_8 = make_char4(
- kvalues_iq4nl[q0_8[0]], kvalues_iq4nl[q0_8[1]], kvalues_iq4nl[q0_8[2]], kvalues_iq4nl[q0_8[3]]);
-
- const int q1_32 = (q4 >> 4) & 0x0F0F0F0F;
- const int8_t * q1_8 = (const int8_t *) &q1_32;
- const char4 val1_8 = make_char4(
- kvalues_iq4nl[q1_8[0]], kvalues_iq4nl[q1_8[1]], kvalues_iq4nl[q1_8[2]], kvalues_iq4nl[q1_8[3]]);
-
- return make_int2(*((const int *) &val0_8), *((const int *) &val1_8));
-}
-
#define VDR_IQ4_NL_Q8_1_MMVQ 2
#define VDR_IQ4_NL_Q8_1_MMQ 4
#pragma unroll
for (int l = 0; l < VDR_Q4_0_Q8_1_MMVQ; ++l) {
const int aux_q4 = get_int_b2(bq4->qs, iqs + l);
- const int2 v = get_int_from_table_16(aux_q4);
+ const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
sumi = ggml_cuda_dp4a(v.x, q8[l + 0], sumi);
sumi = ggml_cuda_dp4a(v.y, q8[l + 4], sumi);
#pragma unroll
for (int j = 0; j < 4; ++j) {
const int aux_q4 = get_int_b4(bq4->qs, iqs + j);
- const int2 v = get_int_from_table_16(aux_q4);
+ const int2 v = get_int_from_table_16(aux_q4, kvalues_iq4nl);
const int u0 = get_int_b4(bq8_1[iqs/4].qs, j + 0);
const int u1 = get_int_b4(bq8_1[iqs/4].qs, j + 4);
#include <cuda_bf16.h>
#include <cuda_fp16.h>
+#if CUDART_VERSION >= 12050
+#include <cuda_fp8.h>
+#endif // CUDART_VERSION >= 12050
+
#if CUDART_VERSION < 11020
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
#define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH
#define GGML_FP16_TO_FP32(x) GGML_COMPUTE_FP16_TO_FP32(x)
#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
+static inline float ggml_e8m0_to_fp32(uint8_t x) {
+ uint32_t bits; // Stores the raw bit representation of the float
+
+ // Handle special case for minimum exponent (denormalized float)
+ if (x == 0) {
+ // Bit pattern for 2^(-127):
+ // - Sign bit: 0 (positive)
+ // - Exponent: 0 (denormalized number)
+ // - Mantissa: 0x400000 (0.5 in fractional form)
+ // Value = 0.5 * 2^(-126) = 2^(-127)
+ bits = 0x00400000;
+ }
+ // note: disabled as we don't need to handle NaNs
+ //// Handle special case for NaN (all bits set)
+ //else if (x == 0xFF) {
+ // // Standard quiet NaN pattern:
+ // // - Sign bit: 0
+ // // - Exponent: all 1s (0xFF)
+ // // - Mantissa: 0x400000 (quiet NaN flag)
+ // bits = 0x7FC00000;
+ //}
+ // Normalized values (most common case)
+ else {
+ // Construct normalized float by shifting exponent into position:
+ // - Exponent field: 8 bits (positions 30-23)
+ // - Mantissa: 0 (implicit leading 1)
+ // Value = 2^(x - 127)
+ bits = (uint32_t) x << 23;
+ }
+
+ float result; // Final float value
+ // Safely reinterpret bit pattern as float without type-punning issues
+ memcpy(&result, &bits, sizeof(float));
+ return result;
+}
+
+// Equal to ggml_e8m0_to_fp32/2
+// Useful with MXFP4 quantization since the E0M2 values are doubled
+static inline float ggml_e8m0_to_fp32_half(uint8_t x) {
+ uint32_t bits;
+
+ // For x < 2: use precomputed denormal patterns
+ if (x < 2) {
+ // 0x00200000 = 2^(-128), 0x00400000 = 2^(-127)
+ bits = 0x00200000 << x;
+ }
+ // For x >= 2: normalized exponent adjustment
+ else {
+ // 0.5 * 2^(x-127) = 2^(x-128) = normalized with exponent (x-1)
+ bits = (uint32_t)(x - 1) << 23;
+ }
+ // Note: NaNs are not handled here
+
+ float result;
+ memcpy(&result, &bits, sizeof(float));
+ return result;
+}
+
+#define GGML_E8M0_TO_FP32(x) ggml_e8m0_to_fp32(x)
+#define GGML_E8M0_TO_FP32_HALF(x) ggml_e8m0_to_fp32_half(x)
+
/**
* Converts brain16 to float32.
*
#define N_R0_Q8_0 4
#define N_SG_Q8_0 2
+#define N_R0_MXFP4 2
+#define N_SG_MXFP4 2
+
#define N_R0_Q2_K 4
#define N_SG_Q2_K 2
uint64_t o1[8];
} ggml_metal_kargs_bin;
+typedef struct {
+ int64_t ne0;
+ int64_t ne1;
+ size_t nb01;
+ size_t nb02;
+ size_t nb11;
+ size_t nb21;
+} ggml_metal_kargs_add_id;
+
typedef struct {
int32_t ne00;
int32_t ne01;
uint64_t nb1;
int32_t i00;
int32_t i10;
+ float alpha;
+ float limit;
} ggml_metal_kargs_glu;
typedef struct {
GGML_METAL_KERNEL_TYPE_MUL_ROW_C4,
GGML_METAL_KERNEL_TYPE_DIV,
GGML_METAL_KERNEL_TYPE_DIV_ROW_C4,
+ GGML_METAL_KERNEL_TYPE_ADD_ID,
GGML_METAL_KERNEL_TYPE_REPEAT_F32,
GGML_METAL_KERNEL_TYPE_REPEAT_F16,
GGML_METAL_KERNEL_TYPE_REPEAT_I32,
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0,
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1,
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0,
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_MXFP4,
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K,
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K,
GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K,
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_2,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_3,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_4,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_5,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16,
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16,
GGML_METAL_KERNEL_TYPE_REGLU,
GGML_METAL_KERNEL_TYPE_GEGLU,
GGML_METAL_KERNEL_TYPE_SWIGLU,
+ GGML_METAL_KERNEL_TYPE_SWIGLU_OAI,
GGML_METAL_KERNEL_TYPE_GEGLU_ERF,
GGML_METAL_KERNEL_TYPE_GEGLU_QUICK,
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW_C4, mul_row_c4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW_C4, div_row_c4, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ID, add_id, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_I32, repeat_i32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_MXFP4, get_rows_mxfp4, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32, mul_mv_mxfp4_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, mul_mv_ext_f16_f32_r1_2, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, mul_mv_ext_f16_f32_r1_3, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, mul_mv_ext_f16_f32_r1_4, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3, mul_mv_ext_q8_0_f32_r1_3, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4, mul_mv_ext_q8_0_f32_r1_4, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5, mul_mv_ext_q8_0_f32_r1_5, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_2, mul_mv_ext_mxfp4_f32_r1_2, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_3, mul_mv_ext_mxfp4_f32_r1_3, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_4, mul_mv_ext_mxfp4_f32_r1_4, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_5, mul_mv_ext_mxfp4_f32_r1_5, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2, mul_mv_ext_q4_K_f32_r1_2, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_3, mul_mv_ext_q4_K_f32_r1_3, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_4, mul_mv_ext_q4_K_f32_r1_4, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, has_simdgroup_reduction);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32, mul_mv_id_mxfp4_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, has_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32, mul_mm_mxfp4_f32, has_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32, mul_mm_mxfp4_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16, mul_mm_id_q5_0_f16, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16, mul_mm_id_q5_1_f16, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16, mul_mm_id_q8_0_f16, has_simdgroup_mm);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F16, mul_mm_id_mxfp4_f16, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16, mul_mm_id_q2_K_f16, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16, mul_mm_id_q3_K_f16, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16, mul_mm_id_q4_K_f16, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REGLU, reglu, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU, geglu, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU, swiglu, true);
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU_OAI, swiglu_oai, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_ERF, geglu_erf, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_QUICK, geglu_quick, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
case GGML_GLU_OP_REGLU:
case GGML_GLU_OP_GEGLU:
case GGML_GLU_OP_SWIGLU:
+ case GGML_GLU_OP_SWIGLU_OAI:
case GGML_GLU_OP_GEGLU_ERF:
case GGML_GLU_OP_GEGLU_QUICK:
return ggml_is_contiguous_1(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_SUB:
case GGML_OP_MUL:
case GGML_OP_DIV:
+ case GGML_OP_ADD_ID:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_ACC:
case GGML_OP_REPEAT:
const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
+ const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT;
const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT;
size_t offs_src0 = 0;
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
}
} break;
+ case GGML_OP_ADD_ID:
+ {
+ GGML_ASSERT(src0t == GGML_TYPE_F32);
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
+ GGML_ASSERT(src2t == GGML_TYPE_I32);
+ GGML_ASSERT(dstt == GGML_TYPE_F32);
+
+ GGML_ASSERT(ggml_is_contiguous_rows(src0));
+
+ id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ID].pipeline;
+
+ ggml_metal_kargs_add_id args = {
+ /*.ne0 =*/ ne0,
+ /*.ne1 =*/ ne1,
+ /*.nb01 =*/ nb01,
+ /*.nb02 =*/ nb02,
+ /*.nb11 =*/ nb11,
+ /*.nb21 =*/ nb21,
+
+ };
+
+ [encoder setComputePipelineState:pipeline];
+ [encoder setBytes:&args length:sizeof(args) atIndex:0];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:4];
+
+ const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
case GGML_OP_REPEAT:
{
id<MTLComputePipelineState> pipeline;
case GGML_GLU_OP_SWIGLU:
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline;
break;
+ case GGML_GLU_OP_SWIGLU_OAI:
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU_OAI].pipeline;
+ break;
case GGML_GLU_OP_GEGLU_ERF:
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU_ERF].pipeline;
break;
GGML_ABORT("fatal error");
}
- const int32_t swp = ((const int32_t *) dst->op_params)[1];
+ const int32_t swp = ggml_get_op_params_i32(dst, 1);
+ const float alpha = ggml_get_op_params_f32(dst, 2);
+ const float limit = ggml_get_op_params_f32(dst, 3);
const int32_t i00 = swp ? ne0 : 0;
const int32_t i10 = swp ? 0 : ne0;
/*.nb1 =*/ nb1,
/*.i00 =*/ src1 ? 0 : i00,
/*.i10 =*/ src1 ? 0 : i10,
+ /*.alpha=*/ alpha,
+ /*.limit=*/ limit
};
[encoder setComputePipelineState:pipeline];
} else {
[encoder setBuffer:h_src0 offset:offs_src0 atIndex:1];
}
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBytes:&args length:sizeof(args) atIndex:3];
+ if (id_src2) {
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
+ } else {
+ [encoder setBuffer:h_src0 offset:offs_src0 atIndex:2];
+ }
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
+ [encoder setBytes:&args length:sizeof(args) atIndex:4];
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
src0t == GGML_TYPE_Q5_0 ||
src0t == GGML_TYPE_Q5_1 ||
src0t == GGML_TYPE_Q8_0 ||
+ src0t == GGML_TYPE_MXFP4 ||
src0t == GGML_TYPE_IQ4_NL ||
false) && (ne11 >= 2 && ne11 <= 8)
) ||
case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5].pipeline; break;
default: GGML_ABORT("not implemented");
} break;
+ case GGML_TYPE_MXFP4:
+ switch (r1ptg) {
+ case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_2].pipeline; break;
+ case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_3].pipeline; break;
+ case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_4].pipeline; break;
+ case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_MXFP4_F32_R1_5].pipeline; break;
+ default: GGML_ABORT("not implemented");
+ } break;
case GGML_TYPE_Q4_K:
switch (r1ptg) {
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_K_F32_R1_2].pipeline; break;
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break;
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break;
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break;
+ case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32 ].pipeline; break;
case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break;
case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break;
case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break;
nr0 = N_R0_Q8_0;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline;
} break;
+ case GGML_TYPE_MXFP4:
+ {
+ nsg = N_SG_MXFP4;
+ nr0 = N_R0_MXFP4;
+ smem = 32*sizeof(float);
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32].pipeline;
+ } break;
case GGML_TYPE_Q2_K:
{
nsg = N_SG_Q2_K;
case GGML_OP_MUL_MAT_ID:
{
// src2 = ids
- const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t);
-
GGML_ASSERT(src2t == GGML_TYPE_I32);
GGML_ASSERT(!ggml_is_transposed(src0));
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F16 ].pipeline; break;
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F16 ].pipeline; break;
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F16 ].pipeline; break;
+ case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MXFP4_F16 ].pipeline; break;
case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F16 ].pipeline; break;
case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F16 ].pipeline; break;
case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F16 ].pipeline; break;
nr0 = N_R0_Q8_0;
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline;
} break;
+ case GGML_TYPE_MXFP4:
+ {
+ nsg = N_SG_MXFP4;
+ nr0 = N_R0_MXFP4;
+ smem = 32*sizeof(float);
+ pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_MXFP4_F32].pipeline;
+ } break;
case GGML_TYPE_Q2_K:
{
nsg = N_SG_Q2_K;
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0 ].pipeline; break;
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1 ].pipeline; break;
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0 ].pipeline; break;
+ case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_MXFP4 ].pipeline; break;
case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K ].pipeline; break;
case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K ].pipeline; break;
case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K ].pipeline; break;
GGML_ASSERT(ne11 == ne21);
GGML_ASSERT(ne12 == ne22);
- struct ggml_tensor * src3 = node->src[3];
+ struct ggml_tensor * src3 = node->src[3]; // mask
+ struct ggml_tensor * src4 = node->src[4]; // sinks
size_t offs_src3 = 0;
+ size_t offs_src4 = 0;
id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
+ id<MTLBuffer> id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil;
GGML_ASSERT(!src3 || src3->type == GGML_TYPE_F16);
GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) &&
const uint64_t nb32 = src3 ? src3->nb[2] : 0; GGML_UNUSED(nb32);
const uint64_t nb33 = src3 ? src3->nb[3] : 0; GGML_UNUSED(nb33);
- const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
-
float scale;
float max_bias;
float logit_softcap;
} else {
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:4];
}
- [encoder setBuffer:id_dst offset:offs_dst atIndex:5];
+ if (id_src4) {
+ [encoder setBuffer:id_src4 offset:offs_src4 atIndex:5];
+ } else {
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:5];
+ }
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:6];
if (!use_vec_kernel) {
// half8x8 kernel
-127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
};
+constexpr constant static float kvalues_mxfp4_f[16] = {
+ 0, .5f, 1.f, 1.5f, 2.f, 3.f, 4.f, 6.f, -0, -.5f, -1.f, -1.5f, -2.f, -3.f, -4.f, -6.f
+};
+
static inline int best_index_int8(int n, constant float * val, float x) {
if (x <= val[0]) return 0;
if (x >= val[n-1]) return n-1;
return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
}
+static inline float e8m0_to_fp32(uint8_t x) {
+ uint32_t bits;
+
+ if (x == 0) {
+ bits = 0x00400000;
+ } else {
+ bits = (uint32_t) x << 23;
+ }
+
+ return as_type<float>(bits);
+}
+
// NOTE: this is not dequantizing - we are simply fitting the template
template <typename type4x4>
void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
}
}
+void quantize_q8_0(device const float * src, device block_q8_0 & dst) {
+#pragma METAL fp math_mode(safe)
+ float amax = 0.0f; // absolute max
+
+ for (int j = 0; j < QK8_0; j++) {
+ const float v = src[j];
+ amax = MAX(amax, fabs(v));
+ }
+
+ const float d = amax / ((1 << 7) - 1);
+ const float id = d ? 1.0f/d : 0.0f;
+
+ dst.d = d;
+
+ for (int j = 0; j < QK8_0; ++j) {
+ const float x0 = src[j]*id;
+
+ dst.qs[j] = round(x0);
+ }
+}
+
void quantize_iq4_nl(device const float * src, device block_iq4_nl & dst) {
#pragma METAL fp math_mode(safe)
float amax = 0.0f; // absolute max
}
}
-void quantize_q8_0(device const float * src, device block_q8_0 & dst) {
-#pragma METAL fp math_mode(safe)
- float amax = 0.0f; // absolute max
+template <typename type4x4>
+void dequantize_mxfp4(device const block_mxfp4 * xb, short il, thread type4x4 & reg) {
+ device const uint8_t * q2 = (device const uint8_t *)xb->qs;
- for (int j = 0; j < QK8_0; j++) {
- const float v = src[j];
- amax = MAX(amax, fabs(v));
+ const float d = e8m0_to_fp32(xb->e);
+ const uint8_t shr = il >= 1 ? 4 : 0;
+
+ for (int i = 0; i < 4; ++i) {
+ reg[i][0] = d * kvalues_mxfp4_f[(q2[4*i + 0] >> shr) & 0x0F];
+ reg[i][1] = d * kvalues_mxfp4_f[(q2[4*i + 1] >> shr) & 0x0F];
+ reg[i][2] = d * kvalues_mxfp4_f[(q2[4*i + 2] >> shr) & 0x0F];
+ reg[i][3] = d * kvalues_mxfp4_f[(q2[4*i + 3] >> shr) & 0x0F];
}
+}
- const float d = amax / ((1 << 7) - 1);
- const float id = d ? 1.0f/d : 0.0f;
+template <typename type4>
+void dequantize_mxfp4_t4(device const block_mxfp4 * xb, short il, thread type4 & reg) {
+ device const uint8_t * q2 = (device const uint8_t *)xb->qs;
- dst.d = d;
+ const float d = e8m0_to_fp32(xb->e);
+ const short il4 = il%4;
- for (int j = 0; j < QK8_0; ++j) {
- const float x0 = src[j]*id;
+ const uint8_t shr = il >= 4 ? 4 : 0;
- dst.qs[j] = round(x0);
- }
+ reg[0] = d * kvalues_mxfp4_f[(q2[4*il4 + 0] >> shr) & 0x0F];
+ reg[1] = d * kvalues_mxfp4_f[(q2[4*il4 + 1] >> shr) & 0x0F];
+ reg[2] = d * kvalues_mxfp4_f[(q2[4*il4 + 2] >> shr) & 0x0F];
+ reg[3] = d * kvalues_mxfp4_f[(q2[4*il4 + 3] >> shr) & 0x0F];
}
template <typename type4x4>
}
}
+kernel void kernel_add_id(
+ constant ggml_metal_kargs_add_id & args,
+ device const char * src0,
+ device const char * src1,
+ device const char * src2,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort3 tpitg[[thread_position_in_threadgroup]],
+ ushort3 ntg[[threads_per_threadgroup]]) {
+ const int i1 = tgpig.x;
+ const int i2 = tgpig.y;
+
+ const int i11 = *((device const int32_t *) (src2 + i1*sizeof(int32_t) + i2*args.nb21));
+
+ const size_t nb1 = args.ne0 * sizeof(float);
+ const size_t nb2 = args.ne1 * nb1;
+
+ device float * dst_row = (device float *)((device char *)dst + i1*nb1 + i2*nb2);
+ device const float * src0_row = (device const float *)((device char *)src0 + i1*args.nb01 + i2*args.nb02);
+ device const float * src1_row = (device const float *)((device char *)src1 + i11*args.nb11);
+
+ for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+ dst_row[i0] = src0_row[i0] + src1_row[i0];
+ }
+}
+
template<typename T>
kernel void kernel_repeat(
constant ggml_metal_kargs_repeat & args,
}
}
+kernel void kernel_swiglu_oai(
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ constant ggml_metal_kargs_glu & args,
+ uint tgpig[[threadgroup_position_in_grid]],
+ uint tpitg[[thread_position_in_threadgroup]],
+ uint ntg[[threads_per_threadgroup]]) {
+ device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
+ device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
+ device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
+
+ for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
+ float x0 = src0_row[i0];
+ float x1 = src1_row[i0];
+
+ x0 = min(x0, args.limit);
+ x1 = max(min(x1, args.limit), -args.limit);
+
+ float out_glu = x0 / (1.0f + exp(-x0 * args.alpha));
+ out_glu = out_glu * (1.0f + x1);
+
+ dst_row[i0] = out_glu;
+ }
+}
+
kernel void kernel_geglu_erf(
device const char * src0,
device const char * src1,
kernel void kernel_soft_max(
device const char * src0,
device const char * src1,
+ device const char * src2,
device char * dst,
constant ggml_metal_kargs_soft_max & args,
threadgroup float * buf [[threadgroup(0)]],
device const float * psrc0 = (device const float *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
device const T * pmask = src1 != src0 ? (device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
+ device const float * psrc2 = src2 != src0 ? (device const float *) (src2) : nullptr;
device float * pdst = (device float *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3);
float slope = 1.0f;
}
// parallel max
- float lmax = -INFINITY;
+ float lmax = psrc2 ? psrc2[i02] : -INFINITY;
for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
lmax = MAX(lmax, psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f));
sum = simd_sum(sum);
}
+ if (psrc2) {
+ sum += exp(psrc2[i02] - max_val);
+ }
+
const float inv_sum = 1.0f/sum;
for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
kernel void kernel_soft_max_4(
device const char * src0,
device const char * src1,
+ device const char * src2,
device char * dst,
constant ggml_metal_kargs_soft_max & args,
threadgroup float * buf [[threadgroup(0)]],
device const float4 * psrc4 = (device const float4 *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
device const T * pmask = src1 != src0 ? (device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
+ device const float * psrc2 = src2 != src0 ? (device const float * ) (src2) : nullptr;
device float4 * pdst4 = (device float4 *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3);
float slope = 1.0f;
}
// parallel max
- float4 lmax4 = -INFINITY;
+ float4 lmax4 = psrc2 ? psrc2[i02] : -INFINITY;
for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
lmax4 = fmax(lmax4, psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
sum = simd_sum(sum);
}
+ if (psrc2) {
+ sum += exp(psrc2[i02] - max_val);
+ }
+
const float inv_sum = 1.0f/sum;
for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_q8_0, 32, dequantize_q8_0_t4>;
template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_q8_0, 32, dequantize_q8_0_t4>;
+template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_mxfp4, 32, dequantize_mxfp4_t4>;
+template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_mxfp4, 32, dequantize_mxfp4_t4>;
+template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_mxfp4, 32, dequantize_mxfp4_t4>;
+template [[host_name("kernel_mul_mv_ext_mxfp4_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, block_mxfp4, 32, dequantize_mxfp4_t4>;
+
template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
template [[host_name("kernel_mul_mv_ext_iq4_nl_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, block_iq4_nl, 32, dequantize_iq4_nl_t4>;
device const char * k,
device const char * v,
device const char * mask,
+ device const char * sinks,
device char * dst,
threadgroup half * shmem_f16 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
}
}
+ if (sinks != q && sgitg == 0) {
+ for (ushort j = 0; j < Q; ++j) {
+ const float m = M[j];
+ const float s = tiisg == 0 ? ((device const float *) sinks)[iq2] : -FLT_MAX/2;
+
+ M[j] = simd_max(max(M[j], s));
+
+ const float ms = exp(m - M[j]);
+ const float vs = exp(s - M[j]);
+
+ S[j] = S[j]*ms + simd_sum(vs);
+
+ if (tiisg == j) {
+ ss[j*TS + 2*C + j] = ms;
+ }
+ }
+
+ // O = diag(ms)*O
+ {
+ s8x8_t ms;
+ simdgroup_load(ms, ss + 2*C, TS, 0, false);
+
+ #pragma unroll(DV8)
+ for (short i = 0; i < DV8; ++i) {
+ simdgroup_multiply(lo[i], ms, lo[i]);
+ }
+ }
+ }
+
// these are needed for reducing the results from the simdgroups (reuse the ss buffer)
for (short j = tiisg; j < Q; j += NW) {
ss[j*TS + 0] = S[j];
device const char * k,
device const char * v,
device const char * mask,
+ device const char * sinks,
device char * dst,
threadgroup half * shmem_f16 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
}
}
+ if (sinks != q && sgitg == 0) {
+ const float m = M;
+ const float s = tiisg == 0 ? ((device const float *) sinks)[iq2] : -FLT_MAX/2;
+
+ M = simd_max(max(M, s));
+
+ const float ms = exp(m - M);
+ const float vs = exp(s - M);
+
+ S = S*ms + simd_sum(vs);
+
+#pragma unroll(DV4/NL)
+ for (short ii = 0; ii < DV4; ii += NL) {
+ lo[ii/NL] *= ms;
+ }
+ }
+
// these are needed for reducing the results from the simdgroups (reuse the ss buffer)
if (tiisg == 0) {
ss[0] = (s_t) S;
kernel_mul_mv_iq4_xs_f32_impl<N_R0_IQ4_XS, N_SG_IQ4_XS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
+template<int nr0, int nsg, int nw, typename args_t>
+void kernel_mul_mv_mxfp4_f32_impl(
+ args_t args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ threadgroup char * shmem,
+ uint3 tgpig,
+ ushort tiisg,
+ ushort sgitg) {
+
+ threadgroup float * shmem_f32 = (threadgroup float *) shmem;
+ const int nb = args.ne00/QK_MXFP4;
+
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
+ const int im = tgpig.z;
+
+ const int first_row = (r0 * nsg + sgitg) * nr0;
+
+ const uint i12 = im%args.ne12;
+ const uint i13 = im/args.ne12;
+
+ const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+ const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
+
+ device const block_mxfp4 * x = (device const block_mxfp4 *) (src0 + offset0);
+ device const float * y = (device const float *) (src1 + offset1);
+
+ const short ix = tiisg/2; // 0...15
+ const short it = tiisg%2; // 0 or 1
+
+ shmem_f32[tiisg] = kvalues_mxfp4_f[tiisg%16];
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ float4 yl[4];
+ float sumf[nr0]={0.f};
+
+ device const float * yb = y + ix * QK_MXFP4 + it * 8;
+
+ for (int ib = ix; ib < nb; ib += 16) {
+ device const float4 * y4 = (device const float4 *)yb;
+ yl[0] = y4[0];
+ yl[1] = y4[4];
+ yl[2] = y4[1];
+ yl[3] = y4[5];
+
+#pragma unroll(nr0)
+ for (short row = 0; row < nr0; row++) {
+ device const block_mxfp4 & xb = x[row*nb + ib];
+ device const uint8_t * q2 = (device const uint8_t *)(xb.qs + 8*it);
+
+ float4 acc1 = yl[0]*float4(shmem_f32[q2[0] & 0x0F], shmem_f32[q2[1] & 0x0F], shmem_f32[q2[2] & 0x0F], shmem_f32[q2[3] & 0x0F]);
+ float4 acc2 = yl[1]*float4(shmem_f32[q2[0] >> 4 ], shmem_f32[q2[1] >> 4 ], shmem_f32[q2[2] >> 4 ], shmem_f32[q2[3] >> 4 ]);
+ float4 acc3 = yl[2]*float4(shmem_f32[q2[4] & 0x0F], shmem_f32[q2[5] & 0x0F], shmem_f32[q2[6] & 0x0F], shmem_f32[q2[7] & 0x0F]);
+ float4 acc4 = yl[3]*float4(shmem_f32[q2[4] >> 4 ], shmem_f32[q2[5] >> 4 ], shmem_f32[q2[6] >> 4 ], shmem_f32[q2[7] >> 4 ]);
+
+ acc1 = (acc1 + acc3) + (acc2 + acc4);
+
+ sumf[row] += e8m0_to_fp32(xb.e) * ((acc1[0] + acc1[1]) + (acc1[2] + acc1[3]));
+ }
+
+ yb += 16 * QK_MXFP4;
+ }
+
+ device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
+ for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) {
+ float sum_all = simd_sum(sumf[row]);
+ if (tiisg == 0) {
+ dst_f32[first_row + row] = sum_all;
+ }
+ }
+}
+
+[[host_name("kernel_mul_mv_mxfp4_f32")]]
+kernel void kernel_mul_mv_mxfp4_f32(
+ constant ggml_metal_kargs_mul_mv & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ threadgroup char * shmem [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ kernel_mul_mv_mxfp4_f32_impl<N_R0_MXFP4, N_SG_MXFP4, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+}
+
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
kernel void kernel_get_rows_q(
constant ggml_metal_kargs_get_rows & args,
template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_0, 2, dequantize_q5_0>;
template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q5_1, 2, dequantize_q5_1>;
template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q8_0, 2, dequantize_q8_0>;
+template [[host_name("kernel_get_rows_mxfp4")]] kernel get_rows_q_t kernel_get_rows_q<block_mxfp4, 2, dequantize_mxfp4>;
template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q2_K, QK_NL, dequantize_q2_K>;
template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q3_K, QK_NL, dequantize_q3_K>;
template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_K, QK_NL, dequantize_q4_K>;
template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1>;
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0>;
+template [[host_name("kernel_mul_mm_mxfp4_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4>;
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K>;
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K>;
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K>;
template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0>;
template [[host_name("kernel_mul_mm_id_q5_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q5_1, 2, dequantize_q5_1>;
template [[host_name("kernel_mul_mm_id_q8_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q8_0, 2, dequantize_q8_0>;
+template [[host_name("kernel_mul_mm_id_mxfp4_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_mxfp4, 2, dequantize_mxfp4>;
template [[host_name("kernel_mul_mm_id_q2_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K>;
template [[host_name("kernel_mul_mm_id_q3_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K>;
template [[host_name("kernel_mul_mm_id_q4_K_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K>;
template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0, N_SG_Q5_0, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1, N_SG_Q5_1, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_mxfp4_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_mxfp4_f32_impl<N_R0_MXFP4, N_SG_MXFP4, N_SIMDWIDTH>>>;
+
template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl <N_R0_Q2_K, N_SG_Q2_K, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl <N_R0_Q3_K, N_SG_Q3_K, N_SIMDWIDTH>>>;
template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl <N_R0_Q4_K, N_SG_Q4_K, N_SIMDWIDTH>>>;
case GGML_OP_CLAMP:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_SOFT_MAX:
+ // TODO: support attention sinks [TAG_ATTN_SINKS]
+ return op->src[2] == nullptr;
case GGML_OP_NORM:
case GGML_OP_RMS_NORM:
return true;
#define UNUSED GGML_UNUSED
+static inline int best_index_int8(int n, const int8_t * val, float x) {
+ if (x <= val[0]) return 0;
+ if (x >= val[n-1]) return n-1;
+ int ml = 0, mu = n-1;
+ while (mu-ml > 1) {
+ int mav = (ml+mu)/2;
+ if (x < val[mav]) mu = mav; else ml = mav;
+ }
+ return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
+}
+
// reference implementation for deterministic creation of model files
void quantize_row_q4_0_ref(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k) {
static const int qk = QK4_0;
}
}
+static inline int best_index_mxfp4(float x, float e) {
+ int best_index = 0;
+ float best_err = fabsf(kvalues_mxfp4[0]*e - x);
+ for (int i = 1; i < 16; i++) {
+ float err = fabsf(kvalues_mxfp4[i]*e - x);
+ if (err < best_err) {
+ best_index = i;
+ best_err = err;
+ }
+ }
+ return best_index;
+}
+
+void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k) {
+ static const int qk = QK_MXFP4;
+
+ assert(k % qk == 0);
+
+ const int nb = k / qk;
+
+ for (int i = 0; i < nb; i++) {
+ float amax = 0.0f; // absolute max
+
+ for (int j = 0; j < qk; j++) {
+ const float v = x[i*qk + j];
+
+ if (amax < fabsf(v)) {
+ amax = fabsf(v);
+ }
+ }
+
+ const uint8_t e = (uint8_t) (floorf(log2f(amax)) - 2 + 127);
+
+ const float d = GGML_E8M0_TO_FP32_HALF(e);
+
+ y[i].e = e;
+
+ for (int j = 0; j < qk/2; ++j) {
+ const uint8_t x0 = best_index_mxfp4(x[i*qk + 0 + j], d);
+ const uint8_t x1 = best_index_mxfp4(x[i*qk + qk/2 + j], d);
+
+ y[i].qs[j] = x0;
+ y[i].qs[j] |= x1 << 4;
+ }
+ }
+}
+
void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
static const int qk = QK4_0;
}
}
+void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
+ static const int qk = QK_MXFP4;
+
+ assert(k % qk == 0);
+
+ const int nb = k / qk;
+
+ for (int i = 0; i < nb; i++) {
+ const float d = GGML_E8M0_TO_FP32_HALF(x[i].e);
+
+ for (int j = 0; j < qk/2; ++j) {
+ const int8_t x0 = kvalues_mxfp4[x[i].qs[j] & 0x0F];
+ const int8_t x1 = kvalues_mxfp4[x[i].qs[j] >> 4];
+
+ y[i*qk + j + 0 ] = x0*d;
+ y[i*qk + j + qk/2] = x1*d;
+ }
+ }
+}
+
//
// 2-6 bit quantization in super-blocks
//
return nrow * row_size;
}
+size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
+ GGML_UNUSED(quant_weights);
+ quantize_row_mxfp4_ref(src, dst, (int64_t)nrow*n_per_row);
+ return nrow * ggml_row_size(GGML_TYPE_MXFP4, n_per_row);
+}
+
// ====================== Ternary (de)-quantization (BitNet b1.58 and TriLMs)
void quantize_row_tq1_0_ref(const float * GGML_RESTRICT x, block_tq1_0 * GGML_RESTRICT y, int64_t k) {
// ============================ 4-bit non-linear quants
-static inline int best_index_int8(int n, const int8_t * val, float x) {
- if (x <= val[0]) return 0;
- if (x >= val[n-1]) return n-1;
- int ml = 0, mu = n-1;
- while (mu-ml > 1) {
- int mav = (ml+mu)/2;
- if (x < val[mav]) mu = mav; else ml = mav;
- }
- return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
-}
-
static void quantize_row_iq4_nl_impl(const int super_block_size, const int block_size, const float * GGML_RESTRICT x,
ggml_fp16_t * dh, uint8_t * q4, uint16_t * scales_h, uint8_t * scales_l,
float * scales, float * weight, uint8_t * L,
return true;
}
+static bool validate_e_e8m0(uint8_t e, size_t i) {
+ if (e == 0xff) {
+ fprintf(stderr, "ggml_validate_row_data: found invalid e value %d at block %zu\n", e, i);
+ return false;
+ }
+
+ return true;
+}
+
#define VALIDATE_ROW_DATA_D_F16_IMPL(type, data, nb) \
const type * q = (const type *) (data); \
for (size_t i = 0; i < (nb); ++i) { \
} \
}
+#define VALIDATE_ROW_DATA_E_E8M0_IMPL(type, data, nb) \
+ const type * q = (const type *) (data); \
+ for (size_t i = 0; i < (nb); ++i) { \
+ if (!validate_e_e8m0(q[i].e, i)) { \
+ return false; \
+ } \
+ }
+
#define VALIDATE_ROW_DATA_DVEC_F16_IMPL(type, data, nb, nr) \
const type * q = (const type *) (data); \
for (size_t i = 0; i < (nb); ++i) { \
{
VALIDATE_ROW_DATA_D_F16_IMPL(block_q8_0, data, nb);
} break;
+ case GGML_TYPE_MXFP4:
+ {
+ VALIDATE_ROW_DATA_E_E8M0_IMPL(block_mxfp4, data, nb);
+ } break;
case GGML_TYPE_Q2_K:
{
VALIDATE_ROW_DATA_DM_F16_IMPL(block_q2_K, data, nb, d, dmin);
GGML_API void quantize_row_q8_0_ref(const float * GGML_RESTRICT x, block_q8_0 * GGML_RESTRICT y, int64_t k);
GGML_API void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 * GGML_RESTRICT y, int64_t k);
+GGML_API void quantize_row_mxfp4_ref(const float * GGML_RESTRICT x, block_mxfp4 * GGML_RESTRICT y, int64_t k);
+
GGML_API void quantize_row_q2_K_ref(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k);
GGML_API void quantize_row_q3_K_ref(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t k);
GGML_API void quantize_row_q4_K_ref(const float * GGML_RESTRICT x, block_q4_K * GGML_RESTRICT y, int64_t k);
GGML_API void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
//GGML_API void dequantize_row_q8_1(const block_q8_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+GGML_API void dequantize_row_mxfp4(const block_mxfp4 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
+
GGML_API void dequantize_row_q2_K(const block_q2_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
GGML_API void dequantize_row_q3_K(const block_q3_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
GGML_API void dequantize_row_q4_K(const block_q4_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
GGML_API size_t quantize_q5_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
GGML_API size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+GGML_API size_t quantize_mxfp4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
+
GGML_API void iq2xs_init_impl(enum ggml_type type);
GGML_API void iq2xs_free_impl(enum ggml_type type);
GGML_API void iq3xs_init_impl(int grid_size);
case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID:
{
- struct ggml_tensor * a;
- struct ggml_tensor * b;
- if (op->op == GGML_OP_MUL_MAT) {
- a = op->src[0];
- b = op->src[1];
- } else {
- a = op->src[2];
- b = op->src[1];
- }
+ struct ggml_tensor * a = op->src[0];
+ struct ggml_tensor * b = op->src[1];
+
if (a->ne[3] != b->ne[3]) {
return false;
}
}
}
ggml_type src0_type = op->src[0]->type;
- if (src0_type == GGML_TYPE_BF16) {
+ if (src0_type == GGML_TYPE_BF16 || src0_type == GGML_TYPE_MXFP4) {
+ // TODO: support MXFP4
+ // FIXME: keep a list of supported types to avoid breaking the backend when a new type is added
return false;
}
return true;
if (op->src[0]->ne[3] != 1) {
return false;
}
+ // TODO: support attention sinks [TAG_ATTN_SINKS]
+ if (op->src[2]) {
+ return false;
+ }
// TODO: support broadcast
// ref: https://github.com/ggml-org/llama.cpp/pull/14435
return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1);
vk_pipeline pipeline_div[2][2][2];
vk_pipeline pipeline_div_norepeat[2][2][2];
+ vk_pipeline pipeline_add_id_f32;
+
vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bilinear_ac_f32;
vk_pipeline pipeline_scale_f32;
vk_pipeline pipeline_geglu[2];
vk_pipeline pipeline_reglu[2];
vk_pipeline pipeline_swiglu[2];
+ vk_pipeline pipeline_swiglu_oai[2];
vk_pipeline pipeline_geglu_erf[2];
vk_pipeline pipeline_geglu_quick[2];
uint32_t ne00;
uint32_t ne20;
uint32_t mode; // 0: default, 1: swapped, 2: split
+ float alpha; // for swiglu_oai
+ float limit;
};
struct vk_op_unary_push_constants {
float param1; float param2; int32_t param3;
};
+struct vk_op_add_id_push_constants {
+ uint32_t ne0;
+ uint32_t ne1;
+ uint32_t s01;
+ uint32_t s02;
+ uint32_t s11;
+ uint32_t s21;
+};
+
struct vk_op_diag_mask_push_constants {
uint32_t ncols;
uint32_t rows_per_channel;
float m1;
uint32_t n_head_log2;
uint32_t nrows_x;
+ uint32_t has_sinks;
};
struct vk_op_argsort_push_constants {
break;
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_IQ4_XS:
+ case GGML_TYPE_MXFP4:
lut_size = 4*16;
break;
default:
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_S], matmul_iq3_s_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_MXFP4], matmul_mxfp4_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f16acc, matmul_id_mxfp4_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
#undef CREATE_MM
#undef CREATE_MM2
} else
CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S], matmul_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
} else {
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
}
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f16acc, matmul_id_mxfp4_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
} else {
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f16acc, matmul_id_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
}
#undef CREATE_MM2
#undef CREATE_MM
CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S], matmul_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (device->integer_dot_product) {
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f16acc, matmul_id_mxfp4_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
#undef CREATE_MM2
#undef CREATE_MMQ
#undef CREATE_MM
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
+ CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (device->integer_dot_product) {
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
+ CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
}
// reusing CREATE_MM from the fp32 path
if ((device->coopmat2 || device->coopmat_support)
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f32_f32_len, mul_mat_vec_iq3_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f32_f32_len, mul_mat_vec_iq4_xs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f32_f32_"+std::to_string(i+1), mul_mat_vec_mxfp4_f32_f32_len, mul_mat_vec_mxfp4_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32_"+std::to_string(i+1), mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32_"+std::to_string(i+1), mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f16_f32_len, mul_mat_vec_iq3_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f16_f32_len, mul_mat_vec_iq4_xs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f16_f32_"+std::to_string(i+1), mul_mat_vec_mxfp4_f16_f32_len, mul_mat_vec_mxfp4_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
}
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_S], "mul_mat_vec_id_iq3_s_f32", mul_mat_vec_id_iq3_s_f32_len, mul_mat_vec_id_iq3_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32", mul_mat_vec_id_iq4_xs_f32_len, mul_mat_vec_id_iq4_xs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_f32", mul_mat_vec_id_mxfp4_f32_len, mul_mat_vec_id_mxfp4_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
// dequant shaders
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ3_S], "dequant_iq3_s", dequant_iq3_s_len, dequant_iq3_s_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_XS], "dequant_iq4_xs", dequant_iq4_xs_len, dequant_iq4_xs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_NL], "dequant_iq4_nl", dequant_iq4_nl_len, dequant_iq4_nl_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_MXFP4], "dequant_mxfp4", dequant_mxfp4_len, dequant_mxfp4_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
// get_rows
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32", get_rows_f32_len, get_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ3_S], "get_rows_iq3_s", get_rows_iq3_s_len, get_rows_iq3_s_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs", get_rows_iq4_xs_len, get_rows_iq4_xs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl", get_rows_iq4_nl_len, get_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_MXFP4], "get_rows_mxfp4", get_rows_mxfp4_len, get_rows_mxfp4_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ3_S], "get_rows_iq3_s_f32", get_rows_iq3_s_f32_len, get_rows_iq3_s_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs_f32", get_rows_iq4_xs_f32_len, get_rows_iq4_xs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_MXFP4], "get_rows_mxfp4_f32", get_rows_mxfp4_f32_len, get_rows_mxfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 4 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
CREATE_BINARY(div, _norepeat, {1})
#undef CREATE_BINARY
+ ggml_vk_create_pipeline(device, device->pipeline_add_id_f32, "add_id_f32", add_id_f32_len, add_id_f32_data, "main", 4, sizeof(vk_op_add_id_push_constants), {1, 1, 1}, {}, 1);
+
ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
CREATE_GLU(geglu)
CREATE_GLU(reglu)
CREATE_GLU(swiglu)
+ CREATE_GLU(swiglu_oai)
CREATE_GLU(geglu_erf)
CREATE_GLU(geglu_quick)
#undef CREATE_GLU
ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
- ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
- ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
- ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_NL:
+ case GGML_TYPE_MXFP4:
break;
default:
return nullptr;
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_NL:
+ case GGML_TYPE_MXFP4:
break;
default:
return nullptr;
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_NL:
+ case GGML_TYPE_MXFP4:
break;
default:
return nullptr;
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_NL:
+ case GGML_TYPE_MXFP4:
break;
default:
return nullptr;
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_NL:
+ case GGML_TYPE_MXFP4:
break;
default:
return nullptr;
std::cerr << "}, (" << wg0 << "," << wg1 << "," << wg2 << "))");
GGML_ASSERT(ctx->descriptor_set_idx < ctx->descriptor_sets.size());
GGML_ASSERT(descriptor_buffer_infos.size() <= MAX_PARAMETER_COUNT);
+ GGML_ASSERT(pipeline->parameter_count == descriptor_buffer_infos.size());
vk::DescriptorSet& descriptor_set = ctx->descriptor_sets[ctx->descriptor_set_idx++];
vk::WriteDescriptorSet write_descriptor_set{ descriptor_set, 0, 0, pipeline->parameter_count, vk::DescriptorType::eStorageBuffer, nullptr, descriptor_buffer_infos.begin() };
break;
}
return nullptr;
+ case GGML_OP_ADD_ID:
+ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && src2->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_F32) {
+ return ctx->device->pipeline_add_id_f32;
+ }
+ return nullptr;
case GGML_OP_CONCAT:
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_concat_f32;
return ctx->device->pipeline_reglu[dst->type == GGML_TYPE_F16];
case GGML_GLU_OP_SWIGLU:
return ctx->device->pipeline_swiglu[dst->type == GGML_TYPE_F16];
+ case GGML_GLU_OP_SWIGLU_OAI:
+ return ctx->device->pipeline_swiglu_oai[dst->type == GGML_TYPE_F16];
case GGML_GLU_OP_GEGLU_ERF:
return ctx->device->pipeline_geglu_erf[dst->type == GGML_TYPE_F16];
case GGML_GLU_OP_GEGLU_QUICK:
return nullptr;
case GGML_OP_SOFT_MAX:
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
+ GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32);
if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_wg512 : ctx->device->pipeline_soft_max_f32;
case GGML_OP_SUB:
case GGML_OP_MUL:
case GGML_OP_DIV:
+ case GGML_OP_ADD_ID:
case GGML_OP_CONCAT:
case GGML_OP_UPSCALE:
case GGML_OP_SQR:
elements = { ne, 1, 1 };
}
} break;
+ case GGML_OP_ADD_ID:
+ {
+ elements = { (uint32_t)ne01, (uint32_t)ne02, 1 };
+ } break;
case GGML_OP_SET_ROWS:
{
uint32_t ne = ggml_nelements(src0);
}
}
- if (op == GGML_OP_SOFT_MAX || op == GGML_OP_GLU) {
- // Empty src1 is possible in soft_max, but the shader needs a buffer
+ if (op == GGML_OP_GLU) {
+ // Empty src1 is possible in glu, but the shader needs a buffer
vk_subbuffer subbuf_y;
if (use_src1) {
subbuf_y = { d_Y, y_buf_offset, y_sz };
ggml_vk_sync_buffers(subctx);
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
+ } else if (op == GGML_OP_SOFT_MAX) {
+ // Empty src1 and src2 is possible in soft_max, but the shader needs a buffer
+ vk_subbuffer subbuf_y;
+ if (use_src1) {
+ subbuf_y = { d_Y, y_buf_offset, y_sz };
+ } else {
+ subbuf_y = { d_X, 0, x_sz };
+ }
+
+ vk_subbuffer subbuf_z;
+ if (use_src2) {
+ subbuf_z = { d_Z, z_buf_offset, z_sz };
+ } else {
+ subbuf_z = { d_X, 0, x_sz };
+ }
+
+ ggml_vk_sync_buffers(subctx);
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
} else if (op == GGML_OP_ROPE || op == GGML_OP_ROPE_BACK) {
// Empty src2 is possible in rope, but the shader needs a buffer
vk_subbuffer subbuf_z;
}, dryrun);
}
+static void ggml_vk_add_id(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
+ const uint32_t src0_type_size = ggml_type_size(src0->type);
+ const uint32_t src1_type_size = ggml_type_size(src1->type);
+ const uint32_t src2_type_size = ggml_type_size(src2->type);
+
+ ggml_vk_op_f32<vk_op_add_id_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_ADD_ID, {
+ (uint32_t)dst->ne[0],
+ (uint32_t)dst->ne[1],
+ (uint32_t)src0->nb[1] / src0_type_size,
+ (uint32_t)src0->nb[2] / src0_type_size,
+ (uint32_t)src1->nb[1] / src1_type_size,
+ (uint32_t)src2->nb[1] / src2_type_size,
+ }, dryrun);
+}
+
static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, int version, bool dryrun = false) {
GGML_ASSERT(version == 6 || version == 7);
int num_srcs = version == 6 ? 6 : 7;
}
static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+ const float * op_params_f = (const float *)dst->op_params;
+
const bool swapped = (bool)dst->op_params[1];
const bool split = src1 != nullptr;
+ const float alpha = op_params_f[2];
+ const float limit = op_params_f[3];
GGML_ASSERT(ggml_is_contiguous(src0));
const uint32_t mode = split ? 2 : (swapped ? 1 : 0);
- ggml_vk_op_f32<vk_op_glu_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GLU, { (uint32_t)ggml_nelements(dst), (uint32_t)src0->ne[0], (uint32_t)dst->ne[0], mode }, dryrun);
+ ggml_vk_op_f32<vk_op_glu_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GLU,
+ {
+ (uint32_t)ggml_nelements(dst),
+ (uint32_t)src0->ne[0],
+ (uint32_t)dst->ne[0],
+ mode,
+ alpha,
+ limit
+ }, dryrun);
}
static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
ggml_vk_op_f32<vk_op_diag_mask_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] }, dryrun);
}
-static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
float * op_params = (float *)dst->op_params;
float scale = op_params[0];
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
- ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX, {
+ ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_SOFT_MAX, {
ncols,
src1 != nullptr ? nrows_y : (uint32_t)0,
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],
m0, m1,
n_head_log2,
nrows_x,
+ src2 != nullptr
}, dryrun);
}
case GGML_GLU_OP_GEGLU:
case GGML_GLU_OP_REGLU:
case GGML_GLU_OP_SWIGLU:
+ case GGML_GLU_OP_SWIGLU_OAI:
case GGML_GLU_OP_GEGLU_ERF:
case GGML_GLU_OP_GEGLU_QUICK:
break;
case GGML_OP_REPEAT_BACK:
case GGML_OP_GET_ROWS:
case GGML_OP_ADD:
+ case GGML_OP_ADD_ID:
case GGML_OP_ACC:
case GGML_OP_SUB:
case GGML_OP_MUL:
case GGML_OP_DIV:
ggml_vk_div(ctx, compute_ctx, src0, src1, node, dryrun);
+ break;
+ case GGML_OP_ADD_ID:
+ ggml_vk_add_id(ctx, compute_ctx, src0, src1, src2, node, dryrun);
+
break;
case GGML_OP_CONCAT:
ggml_vk_concat(ctx, compute_ctx, src0, src1, node, dryrun);
case GGML_GLU_OP_GEGLU:
case GGML_GLU_OP_REGLU:
case GGML_GLU_OP_SWIGLU:
+ case GGML_GLU_OP_SWIGLU_OAI:
case GGML_GLU_OP_GEGLU_ERF:
case GGML_GLU_OP_GEGLU_QUICK:
ggml_vk_glu(ctx, compute_ctx, src0, src1, node, dryrun);
break;
case GGML_OP_SOFT_MAX:
- ggml_vk_soft_max(ctx, compute_ctx, src0, src1, node, dryrun);
+ ggml_vk_soft_max(ctx, compute_ctx, src0, src1, src2, node, dryrun);
break;
case GGML_OP_SOFT_MAX_BACK:
case GGML_OP_SUB:
case GGML_OP_MUL:
case GGML_OP_DIV:
+ case GGML_OP_ADD_ID:
case GGML_OP_CONCAT:
case GGML_OP_UPSCALE:
case GGML_OP_SCALE:
case GGML_GLU_OP_GEGLU:
case GGML_GLU_OP_REGLU:
case GGML_GLU_OP_SWIGLU:
+ case GGML_GLU_OP_SWIGLU_OAI:
case GGML_GLU_OP_GEGLU_ERF:
case GGML_GLU_OP_GEGLU_QUICK:
buf = tensor->buffer;
case GGML_GLU_OP_GEGLU:
case GGML_GLU_OP_REGLU:
case GGML_GLU_OP_SWIGLU:
+ case GGML_GLU_OP_SWIGLU_OAI:
case GGML_GLU_OP_GEGLU_ERF:
case GGML_GLU_OP_GEGLU_QUICK:
return ggml_is_contiguous(op->src[0]) &&
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_NL:
+ case GGML_TYPE_MXFP4:
break;
default:
return false;
if (head_sizes == FA_HEAD_SIZE_UNSUPPORTED) {
return false;
}
+ // TODO: support attention sinks [TAG_ATTN_SINKS]
+ if (op->src[4]) {
+ return false;
+ }
if (op->src[0]->type != GGML_TYPE_F32) {
return false;
}
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_NL:
+ case GGML_TYPE_MXFP4:
return true;
default:
return false;
return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
(op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16) &&
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
+ case GGML_OP_ADD_ID:
+ return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->src[2]->type == GGML_TYPE_I32 &&
+ op->type == GGML_TYPE_F32;
case GGML_OP_SILU_BACK:
case GGML_OP_RMS_NORM_BACK:
case GGML_OP_SQR:
--- /dev/null
+#version 450
+
+#extension GL_EXT_control_flow_attributes : require
+
+#include "types.comp"
+
+layout (push_constant) uniform parameter
+{
+ uint ne0;
+ uint ne1;
+ uint s01;
+ uint s02;
+ uint s11;
+ uint s21;
+} p;
+
+#define BLOCK_SIZE 512
+
+layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) readonly buffer Y {B_TYPE data_b[];};
+layout (binding = 2) readonly buffer Z {int32_t data_c[];};
+layout (binding = 3) writeonly buffer D {D_TYPE data_d[];};
+
+void main() {
+ const uint i1 = gl_WorkGroupID.x;
+ const uint i2 = gl_WorkGroupID.y;
+
+ const uint i11 = data_c[i1 + i2 * p.s21];
+
+ const uint s1 = p.ne0;
+ const uint s2 = p.ne0 * p.ne1;
+
+ const uint d0 = i1 * s1 + i2 * s2;
+ const uint a0 = i1 * p.s01 + i2 * p.s02;
+ const uint b0 = i11 * p.s11;
+
+ for (uint i0 = gl_LocalInvocationID.x; i0 < p.ne0; i0 += BLOCK_SIZE) {
+ data_d[d0 + i0] = data_a[a0 + i0] + data_b[b0 + i0];
+ }
+}
#include "generic_unary_head.comp"
#include "dequant_funcs.comp"
-#if defined(DATA_A_IQ4_NL)
-// 16 invocations needed for init_iq4nl_shmem
+#if defined(DATA_A_IQ4_NL) || defined(DATA_A_MXFP4)
+// 16 invocations needed for init_iq_shmem
layout(local_size_x = 16, local_size_y = 1, local_size_z = 1) in;
#else
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
}
#endif
+#if defined(DATA_A_MXFP4)
+vec2 dequantize(uint ib, uint iqs, uint a_offset) {
+ const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
+ return vec2(kvalues_mxfp4[vui & 0xF], kvalues_mxfp4[vui >> 4]);
+}
+vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
+ vec2 v0 = dequantize(ib, iqs, a_offset);
+ vec2 v1 = dequantize(ib, iqs + 1, a_offset);
+ return vec4(v0.x, v0.y, v1.x, v1.y);
+}
+#endif
+
#if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)
vec2 get_dm(uint ib, uint a_offset) {
return vec2(0, 0);
}
#endif
+#if defined(DATA_A_MXFP4)
+vec2 get_dm(uint ib, uint a_offset) {
+ return vec2(e8m0_to_fp32(data_a[a_offset + ib].e), 0);
+}
+#endif
+
#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
vec2 get_dm(uint ib, uint a_offset) {
return vec2(float(data_a[a_offset + ib].d), float(data_a[a_offset + ib].m));
}
#endif
+#if defined(DATA_A_MXFP4)
+layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufMXFP4 {
+ block_mxfp4 block;
+};
+
+float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
+{
+ const float d = e8m0_to_fp32(bl.block.e);
+ const uint idx = coordInBlock[1];
+ const uint iqs = idx & 0xF;
+ const uint shift = (idx & 0x10) >> 2;
+ uint32_t qs = bl.block.qs[iqs];
+ qs >>= shift;
+ qs &= 0xF;
+ float16_t ret = float16_t(kvalues_mxfp4[qs] * d);
+ return ret;
+}
+#endif
+
#if defined(DATA_A_Q4_0)
#define dequantFuncA dequantFuncQ4_0
#elif defined(DATA_A_Q4_1)
#define dequantFuncA dequantFuncIQ4_XS
#elif defined(DATA_A_IQ4_NL)
#define dequantFuncA dequantFuncIQ4_NL
+#elif defined(DATA_A_MXFP4)
+#define dequantFuncA dequantFuncMXFP4
#endif
--- /dev/null
+#version 450
+
+#include "dequant_head.comp"
+
+layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer A {block_mxfp4 data_a[];};
+layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
+
+void main() {
+ const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
+
+ init_iq_shmem(gl_WorkGroupSize);
+
+ const uint tid = gl_LocalInvocationID.x % 64;
+ const uint il = tid/32;
+ const uint ir = tid%32;
+ const uint ib = 32*i + ir;
+ if (ib >= p.nel / 32) {
+ return;
+ }
+
+ const uint q_idx = 8*il;
+ const uint b_idx = 1024*i + 32*ir + q_idx;
+
+ const float d = e8m0_to_fp32(data_a[ib].e);
+
+ [[unroll]] for (uint l = 0; l < 8; ++l) {
+ data_b[b_idx + l + 0] = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] & 0xF]);
+ data_b[b_idx + l + 16] = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] >> 4]);
+ }
+}
uint ne00;
uint ne20;
uint mode;
+ float alpha;
+ float limit;
} p;
buf_a[buf_idx + 1 ] = FLOAT_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 8, 4)]) * d;
buf_a[buf_idx + 16] = FLOAT_TYPE(kvalues_iq4nl[bitfieldExtract(vui, 4, 4)]) * d;
buf_a[buf_idx + 17] = FLOAT_TYPE(kvalues_iq4nl[vui >> 12]) * d;
+#elif defined(DATA_A_MXFP4)
+ const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
+ const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + 2 * loadr_a;
+
+ const uint ib = idx / 8;
+ const uint iqs = (idx & 0x07) * 2;
+
+ const float d = e8m0_to_fp32(data_a[ib].e);
+ const uint vui = uint(data_a[ib].qs[iqs]);
+ const uint vui2 = uint(data_a[ib].qs[iqs+1]);
+
+ buf_a[buf_idx ] = FLOAT_TYPE(kvalues_mxfp4[vui & 0xF] * d);
+ buf_a[buf_idx + 16] = FLOAT_TYPE(kvalues_mxfp4[vui >> 4] * d);
+ buf_a[buf_idx + 1] = FLOAT_TYPE(kvalues_mxfp4[vui2 & 0xF] * d);
+ buf_a[buf_idx + 17] = FLOAT_TYPE(kvalues_mxfp4[vui2 >> 4] * d);
#endif
}
[[unroll]] for (uint l = 0; l < BN; l += loadstride_b) {
}
#endif
+#if defined(DATA_A_MXFP4)
+FLOAT_TYPE get_d(uint ib) {
+ return FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e));
+}
+#endif
+
#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
FLOAT_TYPE_VEC2 get_dm(uint ib) {
return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
float m1;
uint n_head_log2;
uint nrows_x;
+ uint has_sinks;
} p;
#include "types.comp"
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) readonly buffer Y {B_TYPE data_b[];};
-layout (binding = 2) buffer D {D_TYPE data_d[];};
+layout (binding = 2) readonly buffer Z {float data_c[];};
+layout (binding = 3) buffer D {D_TYPE data_d[];};
shared FLOAT_TYPE vals[BLOCK_SIZE];
const uint h = (rowx / p.ne01) % p.ne02; // head index
const float base = h < p.n_head_log2 ? p.m0 : p.m1;
- const uint exp = h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1;
+ const uint exp = h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1;
slope = pow(base, exp);
}
// Find max
- FLOAT_TYPE max_val = uintBitsToFloat(0xFF800000);
+ FLOAT_TYPE max_val = p.has_sinks == 0 ? uintBitsToFloat(0xFF800000) : data_c[i02];
// Cache values while we compute the max, so we don't need to read them
// again when we're ready to compute exp(x-max).
}
sum = vals[0];
+ if (p.has_sinks != 0) {
+ sum += FLOAT_TYPE(exp(FLOAT_TYPE(data_c[i02]) - max_val));
+ }
+
FLOAT_TYPE rcpdivisor = 1.0/sum;
[[unroll]] for (uint col0 = 0, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) {
--- /dev/null
+#version 450
+
+#include "glu_head.comp"
+
+float op(float a, float b) {
+ float xi = min(a, p.limit);
+ float gi = max(min(b, p.limit), -p.limit);
+
+ float out_glu = xi / (1.0f + exp(-xi * p.alpha));
+ out_glu = out_glu * (1.0f + gi);
+ return out_glu;
+}
+
+#include "glu_main.comp"
#define A_TYPE_PACKED16 block_iq4_nl_packed16
#endif
+#define QUANT_K_MXFP4 32
+#define QUANT_R_MXFP4 2
+
+struct block_mxfp4
+{
+ uint8_t e;
+ uint8_t qs[QUANT_K_MXFP4/2];
+};
+
+//struct block_mxfp4_packed16
+//{
+// uint8_t e;
+// uint16_t qs[QUANT_K_MXFP4/2/2];
+//};
+
+#if defined(DATA_A_MXFP4)
+#define QUANT_K QUANT_K_MXFP4
+#define QUANT_R QUANT_R_MXFP4
+#define QUANT_AUXF 1
+#define A_TYPE block_mxfp4
+//#define A_TYPE_PACKED16 block_mxfp4_packed16
+#endif
+
#if defined(DATA_A_IQ4_NL) || defined(DATA_A_IQ4_XS)
const int8_t kvalues_iq4nl_const[16] = {
int8_t(-127), int8_t(-104), int8_t(-83), int8_t(-65), int8_t(-49), int8_t(-35), int8_t(-22), int8_t(-10),
}
#endif
+#if defined(DATA_A_MXFP4)
+const FLOAT_TYPE kvalues_mxfp4_const[16] = {
+ FLOAT_TYPE(0.0f), FLOAT_TYPE(0.5f), FLOAT_TYPE(1.0f), FLOAT_TYPE(1.5f), FLOAT_TYPE(2.0f), FLOAT_TYPE(3.0f), FLOAT_TYPE(4.0f), FLOAT_TYPE(6.0f),
+ FLOAT_TYPE(-0.0f), FLOAT_TYPE(-0.5f), FLOAT_TYPE(-1.0f), FLOAT_TYPE(-1.5f), FLOAT_TYPE(-2.0f), FLOAT_TYPE(-3.0f), FLOAT_TYPE(-4.0f), FLOAT_TYPE(-6.0f)
+};
+
+shared FLOAT_TYPE kvalues_mxfp4[16];
+
+#define NEEDS_INIT_IQ_SHMEM
+void init_iq_shmem(uvec3 wgsize)
+{
+ // copy the table into shared memory and sync
+ for (uint i = gl_LocalInvocationIndex.x; i < kvalues_mxfp4.length(); i += wgsize.x) {
+ kvalues_mxfp4[i] = kvalues_mxfp4_const[i];
+ }
+ barrier();
+}
+#endif
+
// returns the bfloat value in the low 16b.
// See ggml_compute_fp32_to_bf16
uint32_t fp32_to_bf16(float f)
return uintBitsToFloat(u << 16);
}
+float e8m0_to_fp32(uint8_t x) {
+ uint32_t bits;
+
+ if (x == 0) {
+ bits = 0x00400000;
+ } else {
+ bits = x;
+ bits = bits << 23;
+ }
+
+ return uintBitsToFloat(bits);
+}
+
#endif // !defined(GGML_TYPES_COMP)
"iq3_s",
"iq4_xs",
"iq4_nl",
+ "mxfp4",
"bf16",
};
CloseHandle(pi.hProcess);
CloseHandle(pi.hThread);
#else
-int stdout_pipe[2];
+ int stdout_pipe[2];
int stderr_pipe[2];
if (pipe(stdout_pipe) != 0 || pipe(stderr_pipe) != 0) {
std::string load_vec_quant = "2";
if ((tname == "q4_0") || (tname == "q4_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s"))
load_vec_quant = "8";
- else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl"))
+ else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl") || (tname == "mxfp4"))
load_vec_quant = "4";
if (tname == "bf16") {
string_to_spv("reglu_f32" + suffix, "reglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
string_to_spv("swiglu_f16" + suffix, "swiglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
string_to_spv("swiglu_f32" + suffix, "swiglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
+ string_to_spv("swiglu_oai_f16" + suffix, "swiglu_oai.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
+ string_to_spv("swiglu_oai_f32" + suffix, "swiglu_oai.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
string_to_spv("geglu_erf_f16" + suffix, "geglu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
string_to_spv("geglu_erf_f32" + suffix, "geglu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
string_to_spv("geglu_quick_f16" + suffix,"geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}});
string_to_spv("roll_f32", "roll.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
+ string_to_spv("add_id_f32", "add_id.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
+
for (auto &c : compiles) {
c.wait();
}
#endif
}
-static void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * GGML_RESTRICT x, size_t bx, const float * GGML_RESTRICT y, size_t by, int nrc);
-static void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * GGML_RESTRICT x, size_t bx, ggml_fp16_t * GGML_RESTRICT y, size_t by, int nrc);
-static void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t * GGML_RESTRICT x, size_t bx, ggml_bf16_t * GGML_RESTRICT y, size_t by, int nrc);
static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = {
[GGML_TYPE_I8] = {
.is_quantized = true,
.from_float_ref = (ggml_from_float_t) quantize_row_q8_1_ref,
},
+ [GGML_TYPE_MXFP4] = {
+ .type_name = "mxfp4",
+ .blck_size = QK_MXFP4,
+ .type_size = sizeof(block_mxfp4),
+ .is_quantized = true,
+ .to_float = (ggml_to_float_t) dequantize_row_mxfp4,
+ .from_float_ref = (ggml_from_float_t)quantize_row_mxfp4_ref,
+ },
[GGML_TYPE_Q2_K] = {
.type_name = "q2_K",
.blck_size = QK_K,
"DUP",
"ADD",
+ "ADD_ID",
"ADD1",
"ACC",
"SUB",
"GLU",
};
-static_assert(GGML_OP_COUNT == 86, "GGML_OP_COUNT != 86");
+static_assert(GGML_OP_COUNT == 87, "GGML_OP_COUNT != 87");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
"x",
"x+y",
+ "x[i]+y",
"x+y",
"view(x,nb,offset)+=y->x",
"x-y",
"glu(x)",
};
-static_assert(GGML_OP_COUNT == 86, "GGML_OP_COUNT != 86");
+static_assert(GGML_OP_COUNT == 87, "GGML_OP_COUNT != 87");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
"REGLU",
"GEGLU",
"SWIGLU",
+ "SWIGLU_OAI",
"GEGLU_ERF",
"GEGLU_QUICK",
};
-static_assert(GGML_GLU_OP_COUNT == 5, "GGML_GLU_OP_COUNT != 5");
+static_assert(GGML_GLU_OP_COUNT == 6, "GGML_GLU_OP_COUNT != 6");
static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
case GGML_FTYPE_MOSTLY_Q5_0: wtype = GGML_TYPE_Q5_0; break;
case GGML_FTYPE_MOSTLY_Q5_1: wtype = GGML_TYPE_Q5_1; break;
case GGML_FTYPE_MOSTLY_Q8_0: wtype = GGML_TYPE_Q8_0; break;
+ case GGML_FTYPE_MOSTLY_MXFP4: wtype = GGML_TYPE_MXFP4; break;
case GGML_FTYPE_MOSTLY_Q2_K: wtype = GGML_TYPE_Q2_K; break;
case GGML_FTYPE_MOSTLY_Q3_K: wtype = GGML_TYPE_Q3_K; break;
case GGML_FTYPE_MOSTLY_Q4_K: wtype = GGML_TYPE_Q4_K; break;
return ggml_add_cast_impl(ctx, a, b, type);
}
+struct ggml_tensor * ggml_add_id(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ struct ggml_tensor * ids) {
+
+ GGML_ASSERT(a->ne[0] == b->ne[0]);
+ GGML_ASSERT(a->ne[1] == ids->ne[0]);
+ GGML_ASSERT(a->ne[2] == ids->ne[1]);
+ GGML_ASSERT(ids->type == GGML_TYPE_I32);
+
+ struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
+
+ result->op = GGML_OP_ADD_ID;
+ result->src[0] = a;
+ result->src[1] = b;
+ result->src[2] = ids;
+
+ return result;
+}
+
// ggml_add1
static struct ggml_tensor * ggml_add1_impl(
return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_GEGLU_QUICK, false);
}
+struct ggml_tensor * ggml_swiglu_oai(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ struct ggml_tensor * b,
+ float alpha,
+ float limit) {
+ struct ggml_tensor * result = ggml_glu_impl(ctx, a, b, GGML_GLU_OP_SWIGLU_OAI, false);
+ ggml_set_op_params_f32(result, 2, alpha);
+ ggml_set_op_params_f32(result, 3, limit);
+
+ return result;
+}
+
// ggml_norm
static struct ggml_tensor * ggml_norm_impl(
return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false);
}
+void ggml_soft_max_add_sinks(
+ struct ggml_tensor * a,
+ struct ggml_tensor * sinks) {
+ if (!sinks) {
+ a->src[2] = NULL;
+ return;
+ }
+
+ GGML_ASSERT(a->op == GGML_OP_SOFT_MAX);
+ GGML_ASSERT(a->src[2] == NULL);
+ GGML_ASSERT(a->src[0]->ne[2] == sinks->ne[0]);
+ GGML_ASSERT(sinks->type == GGML_TYPE_F32);
+
+ a->src[2] = sinks;
+}
+
// ggml_soft_max_ext_back
static struct ggml_tensor * ggml_soft_max_ext_back_impl(
return (enum ggml_prec) prec_i32;
}
+void ggml_flash_attn_ext_add_sinks(
+ struct ggml_tensor * a,
+ struct ggml_tensor * sinks) {
+ if (!sinks) {
+ a->src[4] = NULL;
+ return;
+ }
+
+ GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT);
+ GGML_ASSERT(a->src[4] == NULL);
+ GGML_ASSERT(a->src[0]->ne[2] == sinks->ne[0]);
+ GGML_ASSERT(sinks->type == GGML_TYPE_F32);
+
+ a->src[4] = sinks;
+}
+
// ggml_flash_attn_back
struct ggml_tensor * ggml_flash_attn_back(
case GGML_TYPE_Q5_0: result = quantize_q5_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q5_1: result = quantize_q5_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q8_0: result = quantize_q8_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
+ case GGML_TYPE_MXFP4: result = quantize_mxfp4(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q2_K: result = quantize_q2_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q3_K: result = quantize_q3_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q4_K: result = quantize_q4_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
HUNYUAN_MOE = auto()
HUNYUAN_DENSE = auto()
SMOLLM3 = auto()
+ GPT_OSS = auto()
LFM2 = auto()
DREAM = auto()
SMALLTHINKER = auto()
ATTN_OUT_NORM = auto()
ATTN_POST_NORM = auto()
ATTN_ROT_EMBD = auto()
+ ATTN_SINKS = auto()
FFN_GATE_INP = auto()
FFN_GATE_INP_SHEXP = auto()
FFN_NORM = auto()
MODEL_ARCH.HUNYUAN_MOE: "hunyuan-moe",
MODEL_ARCH.HUNYUAN_DENSE: "hunyuan-dense",
MODEL_ARCH.SMOLLM3: "smollm3",
+ MODEL_ARCH.GPT_OSS: "gpt-oss",
MODEL_ARCH.LFM2: "lfm2",
MODEL_ARCH.DREAM: "dream",
MODEL_ARCH.SMALLTHINKER: "smallthinker",
MODEL_TENSOR.ATTN_V: "blk.{bid}.attn_v",
MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output",
MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd",
+ MODEL_TENSOR.ATTN_SINKS: "blk.{bid}.attn_sinks",
MODEL_TENSOR.ATTN_Q_NORM: "blk.{bid}.attn_q_norm",
MODEL_TENSOR.ATTN_K_NORM: "blk.{bid}.attn_k_norm",
MODEL_TENSOR.ATTN_OUT_NORM: "blk.{bid}.attn_output_norm",
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
+ MODEL_ARCH.GPT_OSS: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_POST_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.ATTN_SINKS,
+ MODEL_TENSOR.FFN_GATE_INP,
+ MODEL_TENSOR.FFN_GATE_EXP,
+ MODEL_TENSOR.FFN_DOWN_EXP,
+ MODEL_TENSOR.FFN_UP_EXP,
+ ],
MODEL_ARCH.LFM2: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.TOKEN_EMBD_NORM,
BF16 = 30
TQ1_0 = 34
TQ2_0 = 35
+ MXFP4 = 39
class ExpertGatingFuncType(IntEnum):
GGMLQuantizationType.BF16: (1, 2),
GGMLQuantizationType.TQ1_0: (256, 2 + 4 * 13),
GGMLQuantizationType.TQ2_0: (256, 2 + 64),
+ GGMLQuantizationType.MXFP4: (32, 1 + 16),
}
size = prod(shape)
if "_exps." in name:
- expert_params += (size // shape[-3])
- expert_sum += shape[-3]
+ expert_count = shape[-2 if ".bias" in name else -3]
+ expert_params += (size // expert_count)
+ expert_sum += expert_count
n_expert_tensors += 1
else:
shared_params += size
"transformer.h.{bid}.attn.rotary_emb.inv_freq", # codeshell
),
+ MODEL_TENSOR.ATTN_SINKS: (
+ "model.layers.{bid}.self_attn.sinks", # openai-moe
+ ),
+
# Feed-forward norm
MODEL_TENSOR.FFN_NORM: (
"gpt_neox.layers.{bid}.post_attention_layernorm", # gptneox
"model.layers.{bid}.block_sparse_moe.router.layer", # granitemoe
"model.layers.{bid}.feed_forward.router", # llama4 jamba
"encoder.layers.{bid}.mlp.router.layer", # nomic-bert-moe
+ "model.layers.{bid}.mlp.router", # openai-moe
"model.layers.{bid}.mlp.gate.wg", # hunyuan
"model.layers.{bid}.block_sparse_moe.primary_router", # smallthinker
),
//LLAMA_FTYPE_MOSTLY_Q4_0_8_8 = 35, // removed from gguf files, use Q4_0 and runtime repack
LLAMA_FTYPE_MOSTLY_TQ1_0 = 36, // except 1d tensors
LLAMA_FTYPE_MOSTLY_TQ2_0 = 37, // except 1d tensors
+ LLAMA_FTYPE_MOSTLY_MXFP4_MOE = 38, // except 1d tensors
LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
};
{ LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" },
{ LLM_ARCH_HUNYUAN_DENSE, "hunyuan-dense" },
{ LLM_ARCH_SMOLLM3, "smollm3" },
+ { LLM_ARCH_OPENAI_MOE, "gpt-oss" },
{ LLM_ARCH_LFM2, "lfm2" },
{ LLM_ARCH_DREAM, "dream" },
{ LLM_ARCH_SMALLTHINKER, "smallthinker" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
+ {
+ LLM_ARCH_OPENAI_MOE,
+ {
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+ { LLM_TENSOR_OUTPUT, "output" },
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+ { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+ { LLM_TENSOR_ATTN_SINKS, "blk.%d.attn_sinks" },
+ { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
+ { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
+ { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
+ { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
+ },
+ },
{
LLM_ARCH_LFM2,
{
{LLM_TENSOR_ATTN_KV_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_K_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_V_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+ {LLM_TENSOR_ATTN_SINKS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SCALE}},
{LLM_TENSOR_DEC_ATTN_Q, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_DEC_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_DEC_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
LLM_ARCH_HUNYUAN_MOE,
LLM_ARCH_HUNYUAN_DENSE,
LLM_ARCH_SMOLLM3,
+ LLM_ARCH_OPENAI_MOE,
LLM_ARCH_LFM2,
LLM_ARCH_DREAM,
LLM_ARCH_SMALLTHINKER,
LLM_TENSOR_ATTN_OUT_NORM,
LLM_TENSOR_ATTN_POST_NORM,
LLM_TENSOR_ATTN_ROT_EMBD,
+ LLM_TENSOR_ATTN_SINKS,
LLM_TENSOR_FFN_GATE_INP,
LLM_TENSOR_FFN_GATE_INP_SHEXP,
LLM_TENSOR_FFN_NORM,
{ "llama4", LLM_CHAT_TEMPLATE_LLAMA4 },
{ "smolvlm", LLM_CHAT_TEMPLATE_SMOLVLM },
{ "hunyuan-moe", LLM_CHAT_TEMPLATE_HUNYUAN_MOE },
+ { "gpt-oss", LLM_CHAT_TEMPLATE_OPENAI_MOE },
{ "hunyuan-dense", LLM_CHAT_TEMPLATE_HUNYUAN_DENSE },
{ "kimi-k2", LLM_CHAT_TEMPLATE_KIMI_K2 },
};
return LLM_CHAT_TEMPLATE_DOTS1;
} else if (tmpl_contains("<|startoftext|>") && tmpl_contains("<|extra_4|>")) {
return LLM_CHAT_TEMPLATE_HUNYUAN_MOE;
+ } else if (tmpl_contains("<|start|>") && tmpl_contains("<|channel|>")) {
+ return LLM_CHAT_TEMPLATE_OPENAI_MOE;
} else if (tmpl_contains("<|hy_place▁holder▁no▁2|>") && tmpl_contains("<|hy_place▁holder▁no▁3|>")) {
return LLM_CHAT_TEMPLATE_HUNYUAN_DENSE;
} else if (tmpl_contains("<|im_assistant|>assistant<|im_middle|>")) {
ss << "<|startoftext|>" << message->content << "<|extra_0|>";
}
}
+ } else if (tmpl == LLM_CHAT_TEMPLATE_OPENAI_MOE) {
+ // OpenAI MoE (based on Harmony chat template)
+ for (auto message : chat) {
+ std::string role(message->role);
+ ss << "<|start|>" << role << "<|message|>" << message->content;
+ ss << (role == "assistant" ? "<|return|>" : "<|end|>");
+ }
+ if (add_ass) {
+ ss << "<|start|>assistant";
+ }
} else if (tmpl == LLM_CHAT_TEMPLATE_HUNYUAN_DENSE) {
// tencent/Hunyuan-4B-Instruct
for (size_t i = 0; i < chat.size(); i++) {
LLM_CHAT_TEMPLATE_SMOLVLM,
LLM_CHAT_TEMPLATE_DOTS1,
LLM_CHAT_TEMPLATE_HUNYUAN_MOE,
+ LLM_CHAT_TEMPLATE_OPENAI_MOE,
LLM_CHAT_TEMPLATE_HUNYUAN_DENSE,
LLM_CHAT_TEMPLATE_KIMI_K2,
LLM_CHAT_TEMPLATE_UNKNOWN,
cur = ggml_reglu(ctx0, cur);
cb(cur, "ffn_reglu", il);
} break;
+ default:
+ GGML_ABORT("fatal error");
}
if (gate && type_gate == LLM_FFN_PAR) {
llama_expert_gating_func_type gating_op,
int il,
ggml_tensor * probs_in) const {
+ return build_moe_ffn(
+ cur,
+ gate_inp, /* gate_inp_b */ nullptr,
+ up_exps, /* up_exps_b */ nullptr,
+ gate_exps, /* gate_exps_b */ nullptr,
+ down_exps, /* down_exps_b */ nullptr,
+ exp_probs_b,
+ n_expert,
+ n_expert_used,
+ type_op,
+ norm_w,
+ scale_w,
+ w_scale,
+ gating_op,
+ il,
+ probs_in
+ );
+}
+
+ggml_tensor * llm_graph_context::build_moe_ffn(
+ ggml_tensor * cur,
+ ggml_tensor * gate_inp,
+ ggml_tensor * gate_inp_b,
+ ggml_tensor * up_exps,
+ ggml_tensor * up_exps_b,
+ ggml_tensor * gate_exps,
+ ggml_tensor * gate_exps_b,
+ ggml_tensor * down_exps,
+ ggml_tensor * down_exps_b,
+ ggml_tensor * exp_probs_b,
+ int64_t n_expert,
+ int64_t n_expert_used,
+ llm_ffn_op_type type_op,
+ bool norm_w,
+ bool scale_w,
+ float w_scale,
+ llama_expert_gating_func_type gating_op,
+ int il,
+ ggml_tensor * probs_in) const {
const int64_t n_embd = cur->ne[0];
const int64_t n_tokens = cur->ne[1];
const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN
logits = probs_in;
}
+ if (gate_inp_b) {
+ logits = ggml_add(ctx0, logits, gate_inp_b);
+ cb(logits, "ffn_moe_logits_biased", il);
+ }
+
ggml_tensor * probs = nullptr;
switch (gating_op) {
case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX:
{
probs = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens]
} break;
+ case LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT:
+ {
+ probs = logits; // [n_expert, n_tokens]
+ } break;
default:
GGML_ABORT("fatal error");
}
ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
cb(weights, "ffn_moe_weights", il);
+ if (gating_op == LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT) {
+ weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
+ weights = ggml_soft_max(ctx0, weights); // [n_expert_used, n_tokens]
+ weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens);
+ cb(weights, "ffn_moe_weights_softmax", il);
+ }
+
if (norm_w) {
weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
cb(up, "ffn_moe_up", il);
+ if (up_exps_b) {
+ up = ggml_add_id(ctx0, up, up_exps_b, selected_experts);
+ cb(up, "ffn_moe_up_biased", il);
+ }
+
ggml_tensor * experts = nullptr;
if (gate_exps) {
cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
cur = up;
}
+ if (gate_exps_b) {
+ cur = ggml_add_id(ctx0, cur, gate_exps_b, selected_experts);
+ cb(cur, "ffn_moe_gate_biased", il);
+ }
+
switch (type_op) {
case LLM_FFN_SILU:
if (gate_exps) {
cur = ggml_gelu(ctx0, cur);
cb(cur, "ffn_moe_gelu", il);
} break;
+ case LLM_FFN_SWIGLU_OAI_MOE:
+ {
+ // TODO: move to hparams?
+ constexpr float alpha = 1.702f;
+ constexpr float limit = 7.0f;
+ cur = ggml_swiglu_oai(ctx0, cur, up, alpha, limit);
+ cb(cur, "ffn_moe_swiglu_oai", il);
+ } break;
case LLM_FFN_RELU:
if (gate_exps) {
cur = ggml_reglu_split(ctx0, cur, up);
experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
cb(experts, "ffn_moe_down", il);
+ if (down_exps_b) {
+ experts = ggml_add_id(ctx0, experts, down_exps_b, selected_experts);
+ cb(experts, "ffn_moe_down_biased", il);
+ }
+
if (!weight_before_ffn) {
experts = ggml_mul(ctx0, experts, weights);
cb(cur, "ffn_moe_weighted", il);
ggml_tensor * kq_b,
ggml_tensor * kq_mask,
ggml_tensor * v_mla,
+ ggml_tensor * sinks,
float kq_scale) const {
const bool v_trans = v->nb[1] > v->nb[2];
cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
- ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
+ ggml_flash_attn_ext_add_sinks(cur, sinks);
+ ggml_flash_attn_ext_set_prec (cur, GGML_PREC_F32);
if (v_mla) {
#if 0
}
kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
+ ggml_soft_max_add_sinks(kq, sinks);
if (!v_trans) {
// note: avoid this branch
ggml_tensor * k = k_cur;
ggml_tensor * v = v_cur;
- ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+ ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, nullptr, kq_scale);
cb(cur, "kqv_out", il);
if (wo) {
ggml_tensor * k = mctx_cur->get_k(ctx0, il);
ggml_tensor * v = mctx_cur->get_v(ctx0, il);
- ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+ ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, nullptr, kq_scale);
cb(cur, "kqv_out", il);
if (wo) {
ggml_tensor * v_mla,
float kq_scale,
int il) const {
+ return build_attn_with_sinks(
+ inp,
+ wo,
+ wo_b,
+ q_cur,
+ k_cur,
+ v_cur,
+ kq_b,
+ v_mla,
+ nullptr,
+ kq_scale,
+ il);
+}
+
+ggml_tensor * llm_graph_context::build_attn_with_sinks(
+ llm_graph_input_attn_kv_unified_iswa * inp,
+ ggml_tensor * wo,
+ ggml_tensor * wo_b,
+ ggml_tensor * q_cur,
+ ggml_tensor * k_cur,
+ ggml_tensor * v_cur,
+ ggml_tensor * kq_b,
+ ggml_tensor * v_mla,
+ ggml_tensor * sinks,
+ float kq_scale,
+ int il) const {
// these nodes are added to the graph together so that they are not reordered
// by doing so, the number of splits in the graph is reduced
ggml_build_forward_expand(gf, q_cur);
ggml_tensor * k = mctx_cur->get_k(ctx0, il);
ggml_tensor * v = mctx_cur->get_v(ctx0, il);
- ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+ ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, sinks, kq_scale);
cb(cur, "kqv_out", il);
if (wo) {
ggml_tensor * k = k_cur;
ggml_tensor * v = v_cur;
- ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale);
+ ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, nullptr, kq_scale);
cb(cur, "kqv_out", il);
if (wo) {
LLM_FFN_SWIGLU,
LLM_FFN_GEGLU,
LLM_FFN_REGLU,
+ LLM_FFN_SWIGLU_OAI_MOE,
};
enum llm_ffn_gate_type {
llm_ffn_gate_type type_gate,
int il) const;
+ // build MoE FFN without bias tensors
ggml_tensor * build_moe_ffn(
ggml_tensor * cur,
ggml_tensor * gate_inp,
int il,
ggml_tensor * probs_in = nullptr) const;
+ ggml_tensor * build_moe_ffn(
+ ggml_tensor * cur,
+ ggml_tensor * gate_inp,
+ ggml_tensor * gate_inp_b,
+ ggml_tensor * up_exps,
+ ggml_tensor * up_exps_b,
+ ggml_tensor * gate_exps,
+ ggml_tensor * gate_exps_b,
+ ggml_tensor * down_exps,
+ ggml_tensor * down_exps_b,
+ ggml_tensor * exp_probs_b,
+ int64_t n_expert,
+ int64_t n_expert_used,
+ llm_ffn_op_type type_op,
+ bool norm_w,
+ bool scale_w,
+ float w_scale,
+ llama_expert_gating_func_type gating_op,
+ int il,
+ ggml_tensor * probs_in = nullptr) const;
+
//
// inputs
//
ggml_tensor * v, // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false)
ggml_tensor * kq_b,
ggml_tensor * kq_mask,
+ ggml_tensor * sinks,
ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
float kq_scale) const;
float kq_scale,
int il) const;
+ // TODO: temporary to keep the diff small. after the code is public will refactor to simplify this
+ ggml_tensor * build_attn_with_sinks(
+ llm_graph_input_attn_kv_unified_iswa * inp,
+ ggml_tensor * wo,
+ ggml_tensor * wo_b,
+ ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
+ ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] optional
+ ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] optional
+ ggml_tensor * kq_b,
+ ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v]
+ ggml_tensor * sinks, // [n_head_q]
+ float kq_scale,
+ int il) const;
+
llm_graph_input_attn_cross * build_attn_inp_cross() const;
ggml_tensor * build_attn(
#define LLAMA_MAX_EXPERTS 384 // Kimi-K2
enum llama_expert_gating_func_type {
- LLAMA_EXPERT_GATING_FUNC_TYPE_NONE = 0,
- LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX = 1,
- LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID = 2,
+ LLAMA_EXPERT_GATING_FUNC_TYPE_NONE = 0,
+ LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX = 1,
+ LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID = 2,
+ LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT = 3, // applied to the router weights instead of the logits
};
enum llama_swa_type {
case LLAMA_FTYPE_MOSTLY_Q5_0: return "Q5_0";
case LLAMA_FTYPE_MOSTLY_Q5_1: return "Q5_1";
case LLAMA_FTYPE_MOSTLY_Q8_0: return "Q8_0";
+ case LLAMA_FTYPE_MOSTLY_MXFP4_MOE: return "MXFP4 MoE";
case LLAMA_FTYPE_MOSTLY_Q2_K: return "Q2_K - Medium";
case LLAMA_FTYPE_MOSTLY_Q2_K_S: return "Q2_K - Small";
case LLAMA_FTYPE_MOSTLY_Q3_K_S: return "Q3_K - Small";
ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], w->ne[1], w->ne[2], w->ne[3]);
op_tensor = ggml_add(ctx, a, w);
} break;
+ case GGML_OP_ADD_ID:
+ {
+ int n_expert_used = hparams.n_expert_used;
+ ggml_tensor * a = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, w->ne[0], n_expert_used, 512);
+ ggml_tensor * c = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_expert_used, 512);
+ op_tensor = ggml_add_id(ctx, a, w, c);
+ } break;
case GGML_OP_MUL:
{
ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, w->ne[0], w->ne[1], w->ne[2], w->ne[3]);
ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_embd, w->ne[1], 1, 1);
op_tensor = ggml_im2col(ctx, w, b, 1, 0, 0, 0, 1, 0, false, GGML_TYPE_F16);
} break;
+ case GGML_OP_SCALE:
+ {
+ op_tensor = ggml_scale(ctx, w, 1.0f);
+ } break;
default:
GGML_ABORT("%s: missing test for op %s for tensor %s", __func__, ggml_op_name(op), w->name);
}
default: type = LLM_TYPE_UNKNOWN;
}
} break;
+ case LLM_ARCH_OPENAI_MOE:
+ {
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+ ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
+ ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa);
+
+ hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
+ hparams.set_swa_pattern(2);
+
+ // TODO: switch (hparams.n_layer)
+ } break;
case LLM_ARCH_LFM2:
{
ml.get_key(LLM_KV_SHORTCONV_L_CACHE, hparams.n_shortconv_l_cache);
return nullptr;
}
- // tensors with "bias" suffix are always used with GGML_OP_ADD
+ // tensors with "bias" suffix are always used with GGML_OP_ADD or GGML_OP_ADD_ID
ggml_op op;
bool bias = tn.suffix != nullptr && strcmp(tn.suffix, "bias") == 0;
if (bias) {
- op = GGML_OP_ADD;
+ if (info.op == GGML_OP_MUL_MAT_ID) {
+ op = GGML_OP_ADD_ID;
+ } else {
+ op = GGML_OP_ADD;
+ }
} else {
op = info.op;
}
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
}
} break;
+ case LLM_ARCH_OPENAI_MOE:
+ {
+ const int64_t n_ff_exp = hparams.n_ff_exp;
+
+ tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+ // output
+ output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+ output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
+
+ for (int i = 0; i < n_layer; ++i) {
+ auto & layer = layers[i];
+
+ layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+ layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
+
+ layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_head * n_rot}, 0);
+ layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_head_kv * n_rot}, 0);
+ layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_head_kv * n_rot}, 0);
+ layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_rot, n_embd}, 0);
+
+ layer.attn_sinks = create_tensor(tn(LLM_TENSOR_ATTN_SINKS, "weight", i), {n_head}, 0);
+
+ layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert}, 0);
+ layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
+ layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0);
+ layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
+
+ // bias
+ layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_head * n_rot}, 0);
+ layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_head_kv * n_rot}, 0);
+ layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_head_kv * n_rot}, 0);
+ layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0);
+
+ layer.ffn_gate_inp_b = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "bias", i), {n_expert}, 0);
+ layer.ffn_gate_exps_b = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "bias", i), {n_ff_exp, n_expert}, 0);
+ layer.ffn_down_exps_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "bias", i), { n_embd, n_expert}, 0);
+ layer.ffn_up_exps_b = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "bias", i), {n_ff_exp, n_expert}, 0);
+ }
+ } break;
case LLM_ARCH_LFM2:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp);
}
- if (arch == LLM_ARCH_QWEN3MOE) {
+ if (arch == LLM_ARCH_QWEN3MOE || arch == LLM_ARCH_OPENAI_MOE) {
LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
}
}
};
+struct llm_build_openai_moe_iswa : public llm_graph_context {
+ llm_build_openai_moe_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+ ggml_tensor * cur;
+ ggml_tensor * inpL;
+
+ inpL = build_inp_embd(model.tok_embd);
+
+ // inp_pos - contains the positions
+ ggml_tensor * inp_pos = build_inp_pos();
+
+ auto * inp_attn = build_attn_inp_kv_unified_iswa();
+
+ for (int il = 0; il < n_layer; ++il) {
+ ggml_tensor * inpSA = inpL;
+
+ // norm
+ cur = build_norm(inpL,
+ model.layers[il].attn_norm, nullptr,
+ LLM_NORM_RMS, il);
+ cb(cur, "attn_norm", il);
+
+ // self-attention
+ {
+ // compute Q and K and RoPE them
+ ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+ cb(Qcur, "Qcur", il);
+ if (model.layers[il].bq) {
+ Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+ cb(Qcur, "Qcur", il);
+ }
+
+ ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+ cb(Kcur, "Kcur", il);
+ if (model.layers[il].bk) {
+ Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+ cb(Kcur, "Kcur", il);
+ }
+
+ ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+ cb(Vcur, "Vcur", il);
+ if (model.layers[il].bv) {
+ Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+ cb(Vcur, "Vcur", il);
+ }
+
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_rot, n_head, n_tokens);
+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_rot, n_head_kv, n_tokens);
+ Vcur = ggml_reshape_3d(ctx0, Vcur, n_rot, n_head_kv, n_tokens);
+
+ Qcur = ggml_rope_ext(
+ ctx0, Qcur, inp_pos, nullptr,
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow
+ );
+
+ Kcur = ggml_rope_ext(
+ ctx0, Kcur, inp_pos, nullptr,
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow
+ );
+
+ cb(Qcur, "Qcur", il);
+ cb(Kcur, "Kcur", il);
+ cb(Vcur, "Vcur", il);
+
+ cur = build_attn_with_sinks(inp_attn,
+ model.layers[il].wo, model.layers[il].bo,
+ Qcur, Kcur, Vcur, nullptr, nullptr, model.layers[il].attn_sinks, 1.0f/sqrtf(float(n_rot)), il);
+
+ cb(cur, "attn_out", il);
+ }
+
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+ }
+
+ ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+ cb(ffn_inp, "ffn_inp", il);
+
+ cur = ffn_inp;
+ cur = build_norm(cur,
+ model.layers[il].attn_post_norm, nullptr,
+ LLM_NORM_RMS, il);
+ cb(cur, "attn_post_norm", il);
+
+ // MoE branch
+ cur = build_moe_ffn(cur,
+ model.layers[il].ffn_gate_inp, model.layers[il].ffn_gate_inp_b,
+ model.layers[il].ffn_up_exps, model.layers[il].ffn_up_exps_b,
+ model.layers[il].ffn_gate_exps, model.layers[il].ffn_gate_exps_b,
+ model.layers[il].ffn_down_exps, model.layers[il].ffn_down_exps_b,
+ nullptr,
+ n_expert, n_expert_used,
+ LLM_FFN_SWIGLU_OAI_MOE, false,
+ false, 0.0,
+ LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT,
+ il);
+ cb(cur, "ffn_moe_out", il);
+
+ cur = ggml_add(ctx0, cur, ffn_inp);
+
+ cur = build_cvec(cur, il);
+ cb(cur, "l_out", il);
+
+ // input for next layer
+ inpL = cur;
+ }
+
+ cur = inpL;
+
+ cur = build_norm(cur,
+ model.output_norm, NULL,
+ LLM_NORM_RMS, -1);
+
+ cb(cur, "result_norm", -1);
+ res->t_embd = cur;
+
+ // lm_head
+ cur = build_lora_mm(model.output, cur);
+
+ cb(cur, "result_output", -1);
+ res->t_logits = cur;
+
+ ggml_build_forward_expand(gf, cur);
+ }
+};
+
struct llm_build_lfm2 : public llm_graph_context {
const llama_model & model;
{
llm = std::make_unique<llm_build_smollm3>(*this, params);
} break;
+ case LLM_ARCH_OPENAI_MOE:
+ {
+ llm = std::make_unique<llm_build_openai_moe_iswa>(*this, params);
+ } break;
case LLM_ARCH_FALCON_H1:
{
llm = std::make_unique<llm_build_falcon_h1>(*this, params);
case LLM_ARCH_MINICPM3:
case LLM_ARCH_DOTS1:
case LLM_ARCH_HUNYUAN_MOE:
+ case LLM_ARCH_OPENAI_MOE:
case LLM_ARCH_HUNYUAN_DENSE:
case LLM_ARCH_LFM2:
case LLM_ARCH_SMALLTHINKER:
struct ggml_tensor * ffn_up_enc = nullptr;
// ff MoE
- struct ggml_tensor * ffn_gate_inp = nullptr;
- struct ggml_tensor * ffn_gate_exps = nullptr;
- struct ggml_tensor * ffn_down_exps = nullptr;
- struct ggml_tensor * ffn_up_exps = nullptr;
+ struct ggml_tensor * ffn_gate_inp = nullptr;
+ struct ggml_tensor * ffn_gate_exps = nullptr;
+ struct ggml_tensor * ffn_down_exps = nullptr;
+ struct ggml_tensor * ffn_up_exps = nullptr;
+ struct ggml_tensor * ffn_gate_inp_b = nullptr;
+ struct ggml_tensor * ffn_gate_exps_b = nullptr;
+ struct ggml_tensor * ffn_down_exps_b = nullptr;
+ struct ggml_tensor * ffn_up_exps_b = nullptr;
// ff shared expert (shexp)
struct ggml_tensor * ffn_gate_inp_shexp = nullptr;
struct ggml_tensor * laurel_r = nullptr;
struct ggml_tensor * laurel_post_norm = nullptr;
+ // openai-moe
+ struct ggml_tensor * attn_sinks = nullptr;
+
struct llama_layer_posnet posnet;
struct llama_layer_convnext convnext;
const int64_t nx = tensor->ne[0];
const int64_t qk_k = ggml_blck_size(new_type);
- if (arch == LLM_ARCH_FALCON || nx % qk_k != 0) {
+ if (ftype == LLAMA_FTYPE_MOSTLY_MXFP4_MOE) {
+ new_type = GGML_TYPE_Q8_0;
+ }
+ else if (arch == LLM_ARCH_FALCON || nx % qk_k != 0) {
new_type = GGML_TYPE_Q8_0;
}
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS ||
new_type = GGML_TYPE_Q6_K;
}
}
+ } else if (ftype == LLAMA_FTYPE_MOSTLY_MXFP4_MOE) {
+ // MoE tensors -> MXFP4
+ // other tensors -> Q8_0
+ if (tensor->ne[2] > 1) {
+ new_type = GGML_TYPE_MXFP4;
+ } else {
+ new_type = GGML_TYPE_Q8_0;
+ }
} else if (name == "token_embd.weight" || name == "per_layer_token_embd.weight") {
if (qs.params->token_embedding_type < GGML_TYPE_COUNT) {
new_type = qs.params->token_embedding_type;
case LLAMA_FTYPE_MOSTLY_BF16: default_type = GGML_TYPE_BF16; break;
case LLAMA_FTYPE_ALL_F32: default_type = GGML_TYPE_F32; break;
+ case LLAMA_FTYPE_MOSTLY_MXFP4_MOE: default_type = GGML_TYPE_MXFP4; break;
+
// K-quants
case LLAMA_FTYPE_MOSTLY_Q2_K_S:
case LLAMA_FTYPE_MOSTLY_Q2_K: default_type = GGML_TYPE_Q2_K; break;
const float * imatrix_03 = imatrix ? imatrix + i03 * n_per_row : nullptr;
new_size += llama_tensor_quantize_impl(new_type, f32_data_03, new_data_03, chunk_size, nrows, n_per_row, imatrix_03, workers, nthread_use);
+
+ // TODO: temporary sanity check that the F16 -> MXFP4 is lossless
+#if 1
+ if (new_type == GGML_TYPE_MXFP4) {
+ auto * x = f32_data_03;
+
+ //LLAMA_LOG_INFO("nrows = %d, n_per_row = %d\n", nrows, n_per_row);
+ std::vector<float> deq(nrows*n_per_row);
+ const ggml_type_traits * qtype = ggml_get_type_traits(new_type);
+ qtype->to_float(new_data_03, deq.data(), deq.size());
+
+ double err = 0.0f;
+ for (int i = 0; i < (int) deq.size(); ++i) {
+ err += fabsf(deq[i] - x[i]);
+ //if (fabsf(deq[i] - x[i]) > 0.00001 && i < 256) {
+ if (deq[i] != x[i]) {
+ LLAMA_LOG_INFO("deq[%d] = %f, x[%d] = %f\n", i, deq[i], i, x[i]);
+ }
+ }
+ //LLAMA_LOG_INFO("err = %f\n", err);
+ GGML_ASSERT(err == 0.00000);
+ }
+#endif
}
LLAMA_LOG_INFO("size = %8.2f MiB -> %8.2f MiB\n", ggml_nbytes(tensor)/1024.0/1024.0, new_size/1024.0/1024.0);
}
|| t.first == "<|eot_id|>"
|| t.first == "<|im_end|>"
|| t.first == "<|end|>"
+ || t.first == "<|return|>" // o200k_harmony
+ || t.first == "<|call|>" // o200k_harmony
|| t.first == "<end_of_turn>"
|| t.first == "<|endoftext|>"
|| t.first == "<|eom_id|>"
}
}
+ // @ngxson : quick hack for gpt-oss, always render these tokens
+ for (const auto & t : token_to_id) {
+ if (t.first == "<|channel|>" || t.first == "<|message|>" || t.first == "<|start|>") {
+ id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_USER_DEFINED;
+ }
+ }
+
// sanity checks
if (special_eos_id != LLAMA_TOKEN_NULL && special_eog_ids.count(special_eos_id) == 0) {
special_eog_ids.insert(special_eos_id);
special_eog_ids.insert(special_eom_id);
LLAMA_LOG_WARN("%s: special_eom_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
}
+
+ // TODO: workaround for o200k_harmony tokenizer: the "<|end|>" token should not be EOG
+ // we don't have a good way to detect this, so for now, if we have "<|return|>" and "<|call|>" tokens,
+ // we remove the "<|end|>" token from the EOG list
+ {
+ bool has_return = false;
+ bool has_call = false;
+ bool has_end = false;
+
+ llama_token end_id = LLAMA_TOKEN_NULL;
+
+ LLAMA_LOG_INFO("%s: printing all EOG tokens:\n", __func__);
+ for (auto tid : special_eog_ids) {
+ LLAMA_LOG_INFO("%s: - %d ('%s')\n", __func__, tid, id_to_token[tid].text.c_str());
+
+ if (id_to_token[tid].text == "<|return|>") {
+ has_return = true;
+ } else if (id_to_token[tid].text == "<|call|>") {
+ has_call = true;
+ } else if (id_to_token[tid].text == "<|end|>") {
+ has_end = true;
+ end_id = tid;
+ }
+ }
+
+ if (has_return && has_call && has_end) {
+ special_eog_ids.erase(end_id);
+ LLAMA_LOG_WARN("%s: special_eog_ids contains both '<|return|>' and '<|call|>' tokens, removing '<|end|>' token from EOG list\n", __func__);
+ }
+ }
}
// build special tokens cache
#define VARS_TO_STR10(a, b, c, d, e, f, g, h, i, j) VAR_TO_STR(a) + "," + VARS_TO_STR9(b, c, d, e, f, g, h, i, j)
#define VARS_TO_STR11(a, b, c, d, e, f, g, h, i, j, k) VAR_TO_STR(a) + "," + VARS_TO_STR10(b, c, d, e, f, g, h, i, j, k)
#define VARS_TO_STR12(a, b, c, d, e, f, g, h, i, j, k, l) VAR_TO_STR(a) + "," + VARS_TO_STR11(b, c, d, e, f, g, h, i, j, k, l)
+#define VARS_TO_STR13(a, b, c, d, e, f, g, h, i, j, k, l, m) VAR_TO_STR(a) + "," + VARS_TO_STR12(b, c, d, e, f, g, h, i, j, k, l, m)
#ifdef GGML_USE_SYCL
static bool inline _isinf(float f) {
}
};
+struct test_swiglu_oai : public test_case {
+ const ggml_type type;
+ const std::array<int64_t, 4> ne_a;
+ int v; // view (1 : non-contiguous a)
+ float alpha;
+ float limit;
+
+ std::string vars() override {
+ return VARS_TO_STR5(type, ne_a, v, alpha, limit);
+ }
+
+ test_swiglu_oai(ggml_type type = GGML_TYPE_F32,
+ std::array<int64_t, 4> ne_a = {128, 2, 2, 2},
+ int v = 0,
+ float alpha = 1.702f,
+ float limit = 7.0f)
+ : type(type), ne_a(ne_a), v(v), alpha(alpha), limit(limit) {}
+
+ ggml_tensor * build_graph(ggml_context * ctx) override {
+ ggml_tensor * a;
+ ggml_tensor * b;
+ if (v & 1) {
+ auto ne = ne_a; ne[0] *= 3;
+ a = ggml_new_tensor(ctx, type, 4, ne.data());
+ ggml_set_param(a);
+ ggml_set_name(a, "a");
+
+ a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0);
+ ggml_set_name(a, "view_of_a");
+
+ b = ggml_new_tensor(ctx, type, 4, ne.data());
+ ggml_set_param(b);
+ ggml_set_name(b, "b");
+
+ b = ggml_view_4d(ctx, b, ne_a[0], ne_a[1], ne_a[2], ne_a[3], b->nb[1], b->nb[2], b->nb[3], 0);
+ ggml_set_name(a, "view_of_b");
+ } else {
+ a = ggml_new_tensor(ctx, type, 4, ne_a.data());
+ ggml_set_param(a);
+ ggml_set_name(a, "a");
+
+ b = ggml_new_tensor(ctx, type, 4, ne_a.data());
+ ggml_set_param(b);
+ ggml_set_name(b, "b");
+ }
+
+ ggml_tensor * out = ggml_swiglu_oai(ctx, a, b, alpha, limit);
+ ggml_set_name(out, "out");
+
+ return out;
+ }
+
+ void initialize_tensors(ggml_context * ctx) override {
+ for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+ // test extended range of values to check for NaNs in GELU
+ init_tensor_uniform(t, -150.f, 150.f);
+ }
+ }
+};
+
// GGML_OP_GET_ROWS
struct test_get_rows : public test_case {
const ggml_type type;
}
};
+// GGML_OP_ADD_ID
+struct test_add_id : public test_case {
+ const ggml_type type_a;
+ const ggml_type type_b;
+ const int64_t n_embd;
+ const int64_t n_experts;
+ const int64_t n_experts_used;
+ const int64_t n_token;
+
+ std::string vars() override {
+ return VARS_TO_STR6(type_a, type_b, n_embd, n_experts, n_experts_used, n_token);
+ }
+
+ size_t op_size(ggml_tensor * t) override {
+ return ggml_nbytes(t) + ggml_nbytes(t->src[0]) + ggml_nbytes(t->src[2]);
+ }
+
+ test_add_id(ggml_type type_a = GGML_TYPE_F32,
+ ggml_type type_b = GGML_TYPE_F32,
+ int64_t n_embd = 128,
+ int64_t n_experts = 16,
+ int64_t n_experts_used = 8,
+ int64_t n_token = 10)
+ : type_a(type_a), type_b(type_b), n_embd(n_embd),
+ n_experts(n_experts), n_experts_used(n_experts_used), n_token(n_token) {}
+
+ ggml_tensor * build_graph(ggml_context * ctx) override {
+ ggml_tensor * a = ggml_new_tensor_3d(ctx, type_a, n_embd, n_experts_used, n_token);
+ ggml_tensor * b = ggml_new_tensor_2d(ctx, type_b, n_embd, n_experts);
+ ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_experts, n_token);
+ if (n_experts_used != n_experts) {
+ ids = ggml_view_2d(ctx, ids, n_experts_used, n_token, ids->nb[1], 0);
+ ggml_set_name(ids, "view_of_ids");
+ }
+
+ ggml_tensor * out = ggml_add_id(ctx, a, b, ids);
+ ggml_set_name(out, "out");
+ return out;
+ }
+
+ void initialize_tensors(ggml_context * ctx) override {
+ for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+ if (t->type == GGML_TYPE_I32) {
+ if (ggml_is_view_op(t->op)) { continue; }
+ std::random_device rd;
+ std::default_random_engine rng(rd());
+ // ids
+ for (int64_t r = 0; r < ggml_nrows(t); r++) {
+ std::vector<int32_t> data(t->ne[0]);
+ for (int i = 0; i < t->ne[0]; i++) {
+ data[i] = i % n_experts;
+ }
+ std::shuffle(data.begin(), data.end(), rng);
+ ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(int32_t));
+ }
+ } else {
+ init_tensor_uniform(t);
+ }
+ }
+ }
+};
+
// GGML_OP_ADD1
struct test_add1 : public test_case {
const ggml_type type;
}
void initialize_tensors(ggml_context * ctx) override {
- std::random_device rd;
- std::default_random_engine rng(rd());
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
if (t->type == GGML_TYPE_I32) {
if (ggml_is_view_op(t->op)) { continue; }
+ std::random_device rd;
+ std::default_random_engine rng(rd());
// ids
for (int64_t r = 0; r < ggml_nrows(t); r++) {
std::vector<int32_t> data(t->ne[0]);
const ggml_type type;
const std::array<int64_t, 4> ne;
const bool mask;
+ const bool sinks;
const ggml_type m_prec;
const std::array<int64_t, 2> nr23; // broadcast only dims 2 and 3
const float scale;
const float max_bias;
std::string vars() override {
- return VARS_TO_STR7(type, ne, mask, m_prec, nr23, scale, max_bias);
+ return VARS_TO_STR8(type, ne, mask, sinks, m_prec, nr23, scale, max_bias);
}
// the 1024 test with bias occasionally fails:
test_soft_max(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {10, 5, 4, 3},
bool mask = false,
+ bool sinks = false,
ggml_type m_prec = GGML_TYPE_F32,
std::array<int64_t, 2> nr23 = {1, 1},
float scale = 1.0f,
float max_bias = 0.0f)
- : type(type), ne(ne), mask(mask), m_prec(m_prec), nr23(nr23), scale(scale), max_bias(max_bias) {}
+ : type(type), ne(ne), mask(mask), sinks(sinks), m_prec(m_prec), nr23(nr23), scale(scale), max_bias(max_bias) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2]*nr23[0], ne[3]*nr23[1]);
ggml_set_name(mask, "mask");
}
+ ggml_tensor * sinks = nullptr;
+ if (this->sinks) {
+ sinks = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ne[2]*nr23[0]);
+ ggml_set_name(sinks, "sinks");
+ }
+
ggml_tensor * out = ggml_soft_max_ext(ctx, a, mask, scale, max_bias);
+ ggml_soft_max_add_sinks(out, sinks);
ggml_set_name(out, "out");
return out;
const int64_t nb; // batch size
const bool mask; // use mask
+ const bool sinks; // use sinks
const float max_bias; // ALiBi
const float logit_softcap; // Gemma 2
std::array<int32_t, 4> permute;
std::string vars() override {
- return VARS_TO_STR12(hsk, hsv, nh, nr23, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, permute);
+ return VARS_TO_STR13(hsk, hsv, nh, nr23, kv, nb, mask, sinks, max_bias, logit_softcap, prec, type_KV, permute);
}
double max_nmse_err() override {
}
test_flash_attn_ext(int64_t hsk = 128, int64_t hsv = 128, int64_t nh = 32, std::array<int64_t, 2> nr23 = {1, 1}, int64_t kv = 96, int64_t nb = 8,
- bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_prec prec = GGML_PREC_F32,
+ bool mask = true, bool sinks = false, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_prec prec = GGML_PREC_F32,
ggml_type type_KV = GGML_TYPE_F16, std::array<int32_t, 4> permute = {0, 1, 2, 3})
- : hsk(hsk), hsv(hsv), nh(nh), nr23(nr23), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), prec(prec), type_KV(type_KV), permute(permute) {}
+ : hsk(hsk), hsv(hsv), nh(nh), nr23(nr23), kv(kv), nb(nb), mask(mask), sinks(sinks), max_bias(max_bias), logit_softcap(logit_softcap), prec(prec), type_KV(type_KV), permute(permute) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
const int64_t hsk_padded = GGML_PAD(hsk, ggml_blck_size(type_KV));
ggml_set_name(m, "m");
}
+ ggml_tensor * s = nullptr;
+ if (sinks) {
+ s = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, q->ne[2]);
+ ggml_set_name(s, "s");
+ }
+
ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hsk), max_bias, logit_softcap);
- ggml_flash_attn_ext_set_prec(out, prec);
+ ggml_flash_attn_ext_add_sinks(out, s);
+ ggml_flash_attn_ext_set_prec (out, prec);
ggml_set_name(out, "out");
return out;
}
+ void initialize_tensors(ggml_context * ctx) override {
+ for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+ if (strcmp(t->name, "s") == 0) {
+ // make the sink values more noticable in order to trigger a test failure when the implementation is wrong
+ init_tensor_uniform(t, -10.0f, 10.0f);
+ } else {
+ init_tensor_uniform(t);
+ }
+ }
+ }
+
bool grad_precise() override {
return true;
}
GGML_TYPE_Q4_0, GGML_TYPE_Q4_1,
GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
GGML_TYPE_Q8_0,
+ GGML_TYPE_MXFP4,
GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
GGML_TYPE_Q4_K, GGML_TYPE_Q5_K,
GGML_TYPE_Q6_K,
GGML_TYPE_Q4_0,
GGML_TYPE_Q4_1, // for I8MM tests
GGML_TYPE_Q4_K,
+ GGML_TYPE_MXFP4, // TODO: or "other"
GGML_TYPE_IQ2_XXS
};
for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
for (int v : {0, 1}) {
for (int op = 0; op < GGML_GLU_OP_COUNT; op++) {
+ if (op == GGML_GLU_OP_SWIGLU_OAI) {
+ // SWIGLU_OAI is handled separately
+ continue;
+ }
+
for (bool swapped : {false, true}) {
test_cases.emplace_back(new test_glu((ggml_glu_op) op, type, { 128, 2, 2, 2 }, v, swapped));
test_cases.emplace_back(new test_glu((ggml_glu_op) op, type, { 5, 7, 11, 13 }, v, swapped));
}
}
+ for (int v : {0, 1}) {
+ for (float alpha : {.5f, 1.702f}) {
+ for (float limit : {2.0f, 7.0f}) {
+ test_cases.emplace_back(new test_swiglu_oai(GGML_TYPE_F32, { 128, 2, 2, 2 }, v, alpha, limit));
+ }
+ }
+ }
+
test_cases.emplace_back(new test_get_rows(GGML_TYPE_F32, 1, 8, 2, 1, false));
for (ggml_type type : all_types) {
for (int b : {1, 7}) {
}
}
+ // add_id
+ for (ggml_type type_a : {GGML_TYPE_F32}) {
+ for (ggml_type type_b : {GGML_TYPE_F32}) {
+ for (int n_mats : {4, 8}) {
+ for (int n_used : {1, 2, 4}) {
+ for (int n_embd : {32, 129}) {
+ for (int n_token : {1, 32, 129}) {
+ test_cases.emplace_back(new test_add_id(type_a, type_b, n_embd, n_mats, n_used, n_token));
+ }
+ }
+ }
+ }
+ }
+ }
+
for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
test_cases.emplace_back(new test_sqr(type));
test_cases.emplace_back(new test_sqrt(type));
}
#endif
for (bool mask : {false, true}) {
- for (float max_bias : {0.0f, 8.0f}) {
- if (!mask && max_bias > 0.0f) continue;
- for (float scale : {1.0f, 0.1f}) {
- for (int64_t ne0 : {16, 1024}) {
- for (int64_t ne1 : {16, 1024}) {
- if (mask) {
- for (ggml_type m_prec : {GGML_TYPE_F32, GGML_TYPE_F16}) {
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, m_prec, {1, 1}, scale, max_bias));
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, m_prec, {1, 1}, scale, max_bias));
-
- if (ne0 <= 32 && ne1 <= 32) {
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 3}, mask, m_prec, {3, 1}, scale, max_bias));
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, m_prec, {2, 3}, scale, max_bias));
+ for (bool sinks : {false, true}) {
+ for (float max_bias : {0.0f, 8.0f}) {
+ if (!mask && max_bias > 0.0f) continue;
+ for (float scale : {1.0f, 0.1f}) {
+ for (int64_t ne0 : {16, 1024}) {
+ for (int64_t ne1 : {16, 1024}) {
+ if (mask) {
+ for (ggml_type m_prec : {GGML_TYPE_F32, GGML_TYPE_F16}) {
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, sinks, m_prec, {1, 1}, scale, max_bias));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, sinks, m_prec, {1, 1}, scale, max_bias));
+
+ if (ne0 <= 32 && ne1 <= 32) {
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 3}, mask, sinks, m_prec, {3, 1}, scale, max_bias));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, sinks, m_prec, {2, 3}, scale, max_bias));
+ }
}
+ } else {
+ /* The precision of mask here doesn't matter as boolean mask is false */
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, sinks, GGML_TYPE_F32, {1, 1}, scale, max_bias));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, sinks, GGML_TYPE_F32, {1, 1}, scale, max_bias));
}
- } else {
- /* The precision of mask here doesn't matter as boolean mask is false */
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, GGML_TYPE_F32, {1, 1}, scale, max_bias));
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, GGML_TYPE_F32, {1, 1}, scale, max_bias));
}
}
}
}
}
}
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true, GGML_TYPE_F32, {1, 1}, 0.1f, 0.0f));
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true, GGML_TYPE_F16, {1, 1}, 0.1f, 0.0f));
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, GGML_TYPE_F32, {1, 1}, 0.1f, 0.0f));
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F32, {1, 1}, 0.1f, 0.0f));
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F16, {1, 1}, 0.1f, 0.0f));
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F32, {1, 1}, 0.1f, 8.0f));
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F16, {1, 1}, 0.1f, 8.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true, true, GGML_TYPE_F32, {1, 1}, 0.1f, 0.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true, false, GGML_TYPE_F16, {1, 1}, 0.1f, 0.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, true, GGML_TYPE_F32, {1, 1}, 0.1f, 0.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, true, GGML_TYPE_F32, {1, 1}, 0.1f, 0.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, false, GGML_TYPE_F16, {1, 1}, 0.1f, 0.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, true, GGML_TYPE_F32, {1, 1}, 0.1f, 8.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, true, GGML_TYPE_F16, {1, 1}, 0.1f, 8.0f));
for (float max_bias : {0.0f, 8.0f}) {
for (float scale : {1.0f, 0.1f}) {
if (hsk == 576 && hsv != 512) continue; // DeepSeek MLA
for (bool mask : { true, false } ) {
- for (float max_bias : { 0.0f, 8.0f }) {
- if (!mask && max_bias > 0.0f) continue;
- for (float logit_softcap : {0.0f, 10.0f}) {
- if (hsk != 128 && logit_softcap != 0.0f) continue;
- for (int nh : { 4, }) {
- for (int nr3 : { 1, 3, }) {
- if (hsk > 64 && nr3 > 1) continue; // skip broadcast for large head sizes
- for (int nr2 : { 1, 4, 16 }) {
- if (nr2 == 16 && hsk != 128) continue;
- for (int kv : { 512, 1024, }) {
- if (nr2 != 1 && kv != 512) continue;
- for (int nb : { 1, 3, 32, 35, }) {
- for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {
- if (hsk != 128 && prec == GGML_PREC_DEFAULT) continue;
- for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
- test_cases.emplace_back(new test_flash_attn_ext(
- hsk, hsv, nh, {nr2, nr3}, kv, nb, mask, max_bias, logit_softcap, prec, type_KV));
- // run fewer test cases permuted
- if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
+ for (bool sinks : { true, false } ) {
+ for (float max_bias : { 0.0f, 8.0f }) {
+ if (!mask && max_bias > 0.0f) continue;
+ for (float logit_softcap : {0.0f, 10.0f}) {
+ if (hsk != 128 && logit_softcap != 0.0f) continue;
+ for (int nh : { 4, }) {
+ for (int nr3 : { 1, 3, }) {
+ if (hsk > 64 && nr3 > 1) continue; // skip broadcast for large head sizes
+ for (int nr2 : { 1, 4, 16 }) {
+ if (nr2 == 16 && hsk != 128) continue;
+ for (int kv : { 512, 1024, }) {
+ if (nr2 != 1 && kv != 512) continue;
+ for (int nb : { 1, 3, 32, 35, }) {
+ for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {
+ if (hsk != 128 && prec == GGML_PREC_DEFAULT) continue;
+ for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
test_cases.emplace_back(new test_flash_attn_ext(
- hsk, hsv, nh, {nr2, nr3}, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, {0, 2, 1, 3}));
+ hsk, hsv, nh, {nr2, nr3}, kv, nb, mask, sinks, max_bias, logit_softcap, prec, type_KV));
+ // run fewer test cases permuted
+ if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
+ test_cases.emplace_back(new test_flash_attn_ext(
+ hsk, hsv, nh, {nr2, nr3}, kv, nb, mask, sinks, max_bias, logit_softcap, prec, type_KV, {0, 2, 1, 3}));
+ }
}
}
}
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {8192, 512, 2, 1}, {0, 2, 1, 3}));
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {3072, 512, 2, 1}, {0, 2, 1, 3}));
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {12888, 256, 5, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 4096, 5, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {1024, 1024, 10, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 1024, 10, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {256, 256, 20, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {64, 64, 20, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
- test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 64, 20, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {12888, 256, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 4096, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {1024, 1024, 10, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 1024, 10, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {256, 256, 20, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {64, 64, 20, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+ test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 64, 20, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 10, 1, 1}));
test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1}));
for (int kv : { 4096, 8192, 16384, }) {
for (int hs : { 64, 128, }) {
for (int nr : { 1, 4, }) {
- test_cases.emplace_back(new test_flash_attn_ext(hs, hs, 8, {nr, 1}, kv, 1, true, 0, 0, GGML_PREC_F32, GGML_TYPE_F16));
+ test_cases.emplace_back(new test_flash_attn_ext(hs, hs, 8, {nr, 1}, kv, 1, true, false, 0, 0, GGML_PREC_F32, GGML_TYPE_F16));
}
}
}
test_cases.emplace_back(new test_mean(GGML_TYPE_F32, {256, 256, 3, 1}));
+
+ for (int n_token : {1, 512}) {
+ test_cases.emplace_back(new test_add_id(GGML_TYPE_F32, GGML_TYPE_F32, 2880, 128, 4, n_token));
+ test_cases.emplace_back(new test_add_id(GGML_TYPE_F32, GGML_TYPE_F32, 2880, 32, 4, n_token));
+ }
+
return test_cases;
}
static const std::vector<quant_option> QUANT_OPTIONS = {
{ "Q4_0", LLAMA_FTYPE_MOSTLY_Q4_0, " 4.34G, +0.4685 ppl @ Llama-3-8B", },
{ "Q4_1", LLAMA_FTYPE_MOSTLY_Q4_1, " 4.78G, +0.4511 ppl @ Llama-3-8B", },
+ { "MXFP4_MOE",LLAMA_FTYPE_MOSTLY_MXFP4_MOE," MXFP4 MoE", },
{ "Q5_0", LLAMA_FTYPE_MOSTLY_Q5_0, " 5.21G, +0.1316 ppl @ Llama-3-8B", },
{ "Q5_1", LLAMA_FTYPE_MOSTLY_Q5_1, " 5.65G, +0.1062 ppl @ Llama-3-8B", },
{ "IQ2_XXS", LLAMA_FTYPE_MOSTLY_IQ2_XXS, " 2.06 bpw quantization", },
if (msg.content === null || msg.role !== 'assistant') {
return { content: msg.content };
}
+ const REGEX_THINK_OPEN = /<think>|<\|channel\|>analysis<\|message\|>/;
+ const REGEX_THINK_CLOSE =
+ /<\/think>|<\|start\|>assistant<\|channel\|>final<\|message\|>/;
let actualContent = '';
let thought = '';
let isThinking = false;
- let thinkSplit = msg.content.split('<think>', 2);
+ let thinkSplit = msg.content.split(REGEX_THINK_OPEN, 2);
actualContent += thinkSplit[0];
while (thinkSplit[1] !== undefined) {
// <think> tag found
- thinkSplit = thinkSplit[1].split('</think>', 2);
+ thinkSplit = thinkSplit[1].split(REGEX_THINK_CLOSE, 2);
thought += thinkSplit[0];
isThinking = true;
if (thinkSplit[1] !== undefined) {
// </think> closing tag found
isThinking = false;
- thinkSplit = thinkSplit[1].split('<think>', 2);
+ thinkSplit = thinkSplit[1].split(REGEX_THINK_OPEN, 2);
actualContent += thinkSplit[0];
}
}