if chkhsh == "169bf0296a13c4d9b7672313f749eb36501d931022de052aad6e36f2bf34dd51":
# ref: https://huggingface.co/LiquidAI/LFM2-Tokenizer
res = "lfm2"
+ if chkhsh == "2085e1638f6c377a0aa4ead21b27bb4cb941bf800df86ed391011769c1758dfb":
+ # ref: https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B
+ res = "exaone4"
if res is None:
logger.warning("\n")
yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), torch.tensor(rope_factors, dtype=torch.float32))
+@ModelBase.register("Exaone4ForCausalLM")
+class Exaone4Model(TextModel):
+ model_arch = gguf.MODEL_ARCH.EXAONE4
+
+ def set_vocab(self):
+ tokens, toktypes, tokpre = self.get_vocab_base()
+ self.gguf_writer.add_tokenizer_model("gpt2")
+ self.gguf_writer.add_tokenizer_pre(tokpre)
+ self.gguf_writer.add_token_list(tokens)
+ self.gguf_writer.add_token_types(toktypes)
+
+ special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
+ special_vocab.add_to_gguf(self.gguf_writer)
+
+ def set_gguf_parameters(self):
+ super().set_gguf_parameters()
+ hparams = self.hparams
+ self.gguf_writer.add_vocab_size(hparams["vocab_size"])
+
+ if hparams.get("sliding_window") is not None:
+ self.gguf_writer.add_sliding_window(hparams["sliding_window"])
+ if "layer_types" in hparams:
+ self.gguf_writer.add_sliding_window_pattern([t == "sliding_attention" for t in hparams["layer_types"]])
+ elif "sliding_window_pattern" in hparams:
+ sliding_window_pattern = []
+ if isinstance(hparams["sliding_window_pattern"], str): # e.g. LLLG
+ for i in range(hparams["num_hidden_layers"]):
+ sliding_window_pattern.append(hparams["sliding_window_pattern"][i % len(hparams["sliding_window_pattern"])] == "L")
+ if isinstance(hparams["sliding_window_pattern"], int): # e.g. 4
+ for i in range(hparams["num_hidden_layers"]):
+ sliding_window_pattern.append((i + 1) % hparams["sliding_window_pattern"] != 0)
+ if len(sliding_window_pattern) == hparams["num_hidden_layers"]:
+ self.gguf_writer.add_sliding_window_pattern(sliding_window_pattern)
+
+ rope_scaling = self.hparams.get("rope_scaling") or {}
+ if rope_scaling.get("rope_type", rope_scaling.get("type")) == "linear" and "factor" in rope_scaling:
+ self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
+ self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
+
+ def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
+ if rope_scaling := self.find_hparam(["rope_scaling"], optional=True):
+ if rope_scaling.get("rope_type", '').lower() == "llama3":
+ base = self.hparams.get("rope_theta", 10_000.0)
+ if (dim := self.hparams.get("head_dim")) is None:
+ dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
+ freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
+
+ factor = rope_scaling.get("factor", 16.0)
+ low_freq_factor = rope_scaling.get("low_freq_factor", 1.0)
+ high_freq_factor = rope_scaling.get("high_freq_factor", 4.0)
+ old_context_len = self.hparams.get("original_max_position_embeddings", 8192)
+
+ low_freq_wavelen = old_context_len / low_freq_factor
+ high_freq_wavelen = old_context_len / high_freq_factor
+
+ rope_factors = []
+ for freq in freqs:
+ wavelen = 2 * math.pi / freq
+ if wavelen < high_freq_wavelen:
+ rope_factors.append(1)
+ elif wavelen > low_freq_wavelen:
+ rope_factors.append(factor)
+ else:
+ smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
+ rope_factors.append(1 / ((1 - smooth) / factor + smooth))
+
+ yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), torch.tensor(rope_factors, dtype=torch.float32))
+
+
@ModelBase.register("GraniteForCausalLM")
class GraniteModel(LlamaModel):
"""Conversion for IBM's GraniteForCausalLM"""
{"name": "a.x-4.0", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/skt/A.X-4.0", },
{"name": "midm-2.0", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/K-intelligence/Midm-2.0-Base-Instruct", },
{"name": "lfm2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LiquidAI/LFM2-Tokenizer"},
+ {"name": "exaone4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B", },
]
# some models are known to be broken upstream, so we will skip them as exceptions
JAIS = auto()
NEMOTRON = auto()
EXAONE = auto()
+ EXAONE4 = auto()
GRANITE = auto()
GRANITE_MOE = auto()
GRANITE_HYBRID = auto()
MODEL_ARCH.JAIS: "jais",
MODEL_ARCH.NEMOTRON: "nemotron",
MODEL_ARCH.EXAONE: "exaone",
+ MODEL_ARCH.EXAONE4: "exaone4",
MODEL_ARCH.GRANITE: "granite",
MODEL_ARCH.GRANITE_MOE: "granitemoe",
MODEL_ARCH.GRANITE_HYBRID: "granitehybrid",
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
+ MODEL_ARCH.EXAONE4: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ROPE_FREQS,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_Q_NORM,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_K_NORM,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.ATTN_POST_NORM,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ MODEL_TENSOR.FFN_POST_NORM,
+ ],
MODEL_ARCH.GRANITE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
{ LLM_ARCH_JAIS, "jais" },
{ LLM_ARCH_NEMOTRON, "nemotron" },
{ LLM_ARCH_EXAONE, "exaone" },
+ { LLM_ARCH_EXAONE4, "exaone4" },
{ LLM_ARCH_RWKV6, "rwkv6" },
{ LLM_ARCH_RWKV6QWEN2, "rwkv6qwen2" },
{ LLM_ARCH_RWKV7, "rwkv7" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
+ {
+ LLM_ARCH_EXAONE4,
+ {
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+ { LLM_TENSOR_OUTPUT, "output" },
+ { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+ { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+ { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+ { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+ { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
+ }
+ },
{
LLM_ARCH_RWKV6,
{
LLM_ARCH_JAIS,
LLM_ARCH_NEMOTRON,
LLM_ARCH_EXAONE,
+ LLM_ARCH_EXAONE4,
LLM_ARCH_RWKV6,
LLM_ARCH_RWKV6QWEN2,
LLM_ARCH_RWKV7,
{ "glmedge", LLM_CHAT_TEMPLATE_GLMEDGE },
{ "minicpm", LLM_CHAT_TEMPLATE_MINICPM },
{ "exaone3", LLM_CHAT_TEMPLATE_EXAONE_3 },
+ { "exaone4", LLM_CHAT_TEMPLATE_EXAONE_4 },
{ "rwkv-world", LLM_CHAT_TEMPLATE_RWKV_WORLD },
{ "granite", LLM_CHAT_TEMPLATE_GRANITE },
{ "gigachat", LLM_CHAT_TEMPLATE_GIGACHAT },
} else if (tmpl_contains(LU8("<|Assistant|>")) && tmpl_contains(LU8("<|User|>")) && tmpl_contains(LU8("<|end▁of▁sentence|>"))) {
return LLM_CHAT_TEMPLATE_DEEPSEEK_3;
} else if (tmpl_contains("[|system|]") && tmpl_contains("[|assistant|]") && tmpl_contains("[|endofturn|]")) {
+ if (tmpl_contains("[|tool|]")) {
+ return LLM_CHAT_TEMPLATE_EXAONE_4;
+ }
// ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb
// EXAONE-3.0-7.8B-Instruct
return LLM_CHAT_TEMPLATE_EXAONE_3;
if (add_ass) {
ss << "[|assistant|]";
}
+ } else if (tmpl == LLM_CHAT_TEMPLATE_EXAONE_4) {
+ for (auto message : chat) {
+ std::string role(message->role);
+ if (role == "system") {
+ ss << "[|system|]" << trim(message->content) << "[|endofturn|]\n";
+ } else if (role == "user") {
+ ss << "[|user|]" << trim(message->content) << "\n";
+ } else if (role == "assistant") {
+ ss << "[|assistant|]" << trim(message->content) << "[|endofturn|]\n";
+ } else if (role == "tool") {
+ ss << "[|tool|]" << trim(message->content) << "[|endofturn|]\n";
+ }
+ }
+ if (add_ass) {
+ ss << "[|assistant|]";
+ }
} else if (tmpl == LLM_CHAT_TEMPLATE_RWKV_WORLD) {
// this template requires the model to have "\n\n" as EOT token
for (size_t i = 0; i < chat.size(); i++) {
LLM_CHAT_TEMPLATE_GLMEDGE,
LLM_CHAT_TEMPLATE_MINICPM,
LLM_CHAT_TEMPLATE_EXAONE_3,
+ LLM_CHAT_TEMPLATE_EXAONE_4,
LLM_CHAT_TEMPLATE_RWKV_WORLD,
LLM_CHAT_TEMPLATE_GRANITE,
LLM_CHAT_TEMPLATE_GIGACHAT,
default: type = LLM_TYPE_UNKNOWN;
}
} break;
+ case LLM_ARCH_EXAONE4:
+ {
+ if (hparams.n_layer == 64) { // 32B
+ hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
+ hparams.n_swa = 4096;
+ hparams.set_swa_pattern(4);
+ }
+
+ ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+
+ switch (hparams.n_layer) {
+ case 30: type = LLM_TYPE_1_2B; break;
+ case 64: type = LLM_TYPE_32B; break;
+ default: type = LLM_TYPE_UNKNOWN;
+ }
+ } break;
case LLM_ARCH_RWKV6:
case LLM_ARCH_RWKV6QWEN2:
{
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
}
} break;
+ case LLM_ARCH_EXAONE4:
+ {
+ tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+ // output
+ output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+ output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+
+ // if output is NULL, init from the input tok embed
+ if (output == NULL) {
+ output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+ }
+
+ for (int i = 0; i < n_layer; ++i) {
+ auto & layer = layers[i];
+
+ layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+ layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
+ layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
+ layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+
+ layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
+
+ layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0);
+ layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
+ layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
+
+ layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
+ layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
+ layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
+ layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
+ }
+ } break;
case LLM_ARCH_RWKV6:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
}
};
+template <bool iswa>
+struct llm_build_exaone4 : public llm_graph_context {
+ llm_build_exaone4(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+ const int64_t n_embd_head = hparams.n_embd_head_k;
+
+ GGML_ASSERT(n_embd_head == hparams.n_embd_head_v);
+ GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+ ggml_tensor * cur;
+ ggml_tensor * inpL;
+
+ inpL = build_inp_embd(model.tok_embd);
+
+ // inp_pos - contains the positions
+ ggml_tensor * inp_pos = build_inp_pos();
+
+ using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_unified_iswa, llm_graph_input_attn_kv_unified>;
+ inp_attn_type * inp_attn = nullptr;
+
+ if constexpr (iswa) {
+ inp_attn = build_attn_inp_kv_unified_iswa();
+ } else {
+ inp_attn = build_attn_inp_kv_unified();
+ }
+
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+ for (int il = 0; il < n_layer; ++il) {
+ ggml_tensor * inpSA = inpL;
+
+ // use RoPE for SWA layers or non-SWA models
+ const bool use_rope = hparams.is_swa(il) || hparams.swa_type == LLAMA_SWA_TYPE_NONE;
+
+ cur = inpL;
+
+ // self-attention
+ {
+ ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
+
+ ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+ cb(Qcur, "Qcur", il);
+
+ ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+ cb(Kcur, "Kcur", il);
+
+ ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+ cb(Vcur, "Vcur", il);
+
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+ Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+ Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
+ Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
+ cb(Qcur, "Qcur_normed", il);
+ cb(Kcur, "Kcur_normed", il);
+
+ if (use_rope) {
+ Qcur = ggml_rope_ext(
+ ctx0, Qcur, inp_pos, rope_factors,
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow
+ );
+
+ Kcur = ggml_rope_ext(
+ ctx0, Kcur, inp_pos, rope_factors,
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow
+ );
+ }
+
+ cb(Qcur, "Qcur", il);
+ cb(Kcur, "Kcur", il);
+ cb(Vcur, "Vcur", il);
+
+ cur = build_attn(inp_attn, gf,
+ model.layers[il].wo, NULL,
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+ cb(cur, "attn_out", il);
+ }
+
+ if (il == n_layer - 1 && inp_out_ids) {
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+ }
+
+ cur = build_norm(cur,
+ model.layers[il].attn_post_norm, NULL,
+ LLM_NORM_RMS, il);
+ cb(cur, "attn_post_norm", il);
+
+ ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+ cb(ffn_inp, "ffn_inp", il);
+
+ // feed-forward network
+ cur = build_ffn(ffn_inp,
+ model.layers[il].ffn_up, NULL, NULL,
+ model.layers[il].ffn_gate, NULL, NULL,
+ model.layers[il].ffn_down, NULL, NULL,
+ NULL,
+ LLM_FFN_SILU, LLM_FFN_PAR, il);
+ cb(cur, "ffn_out", il);
+
+ cur = build_norm(cur,
+ model.layers[il].ffn_post_norm, NULL,
+ LLM_NORM_RMS, -1);
+ cb(cur, "ffn_post_norm", -1);
+
+ cur = ggml_add(ctx0, cur, ffn_inp);
+
+ cur = build_cvec(cur, il);
+ cb(cur, "l_out", il);
+
+ // input for next layer
+ inpL = cur;
+ }
+
+ cur = inpL;
+
+ cur = build_norm(cur,
+ model.output_norm, NULL,
+ LLM_NORM_RMS, -1);
+
+ cb(cur, "result_norm", -1);
+ res->t_embd = cur;
+
+ // lm_head
+ cur = build_lora_mm(model.output, cur);
+
+ cb(cur, "result_output", -1);
+ res->t_logits = cur;
+
+ ggml_build_forward_expand(gf, cur);
+ }
+};
+
struct llm_build_rwkv6_base : public llm_graph_context {
const llama_model & model;
{
llm = std::make_unique<llm_build_exaone>(*this, params);
} break;
+ case LLM_ARCH_EXAONE4:
+ {
+ if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) {
+ llm = std::make_unique<llm_build_exaone4<true>>(*this, params, gf);
+ } else {
+ llm = std::make_unique<llm_build_exaone4<false>>(*this, params, gf);
+ }
+ } break;
case LLM_ARCH_RWKV6:
{
llm = std::make_unique<llm_build_rwkv6>(*this, params);
case LLM_ARCH_ORION:
case LLM_ARCH_NEMOTRON:
case LLM_ARCH_EXAONE:
+ case LLM_ARCH_EXAONE4:
case LLM_ARCH_MINICPM3:
case LLM_ARCH_DOTS1:
case LLM_ARCH_HUNYUAN_MOE:
} else if (
tokenizer_pre == "exaone") {
pre_type = LLAMA_VOCAB_PRE_TYPE_EXAONE;
+ } else if (
+ tokenizer_pre == "exaone4") {
+ pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2;
} else if (
tokenizer_pre == "chameleon") {
pre_type = LLAMA_VOCAB_PRE_TYPE_CHAMELEON;