if chkhsh == "a1e163ecab2e718a4c829d1148b6e86824ec36163bb71941c3dca9cd5ac25756":
# ref: https://huggingface.co/JetBrains/Mellum-4b-base
res = "mellum"
+ if chkhsh == "9b1be57e70d20d9501b2b3186e792d81181ae36ada3903c26f9fea418cf87206":
+ # ref: https://huggingface.co/inclusionAI/LLaDA-MoE-7B-A1B-Base
+ res = "llada-moe"
if res is None:
logger.warning("\n")
raise ValueError(f"Unprocessed experts: {experts}")
+@ModelBase.register("LLaDAMoEModel", "LLaDAMoEModelLM")
+class LLaDAMoEModel(TextModel):
+ model_arch = gguf.MODEL_ARCH.LLADA_MOE
+
+ def set_gguf_parameters(self):
+ super().set_gguf_parameters()
+ if (n_experts := self.hparams.get("num_experts")) is not None:
+ self.gguf_writer.add_expert_count(n_experts)
+
+ if (expert_intermediate_size := self.hparams.get("expert_intermediate_size")) is not None:
+ self.gguf_writer.add_expert_feed_forward_length(expert_intermediate_size)
+
+ # number of experts used per token (top-k)
+ if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None:
+ self.gguf_writer.add_expert_used_count(n_experts_used)
+
+ self.gguf_writer.add_mask_token_id(156895)
+ self.gguf_writer.add_causal_attention(False)
+ self.gguf_writer.add_diffusion_shift_logits(False)
+
+ _experts: list[dict[str, Tensor]] | None = None
+
+ # Copied from: Qwen2MoeModel
+ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+ # process the experts separately
+ if name.find("experts") != -1:
+ n_experts = self.hparams["num_experts"]
+ assert bid is not None
+
+ if self._experts is None:
+ self._experts = [{} for _ in range(self.block_count)]
+
+ self._experts[bid][name] = data_torch
+
+ if len(self._experts[bid]) >= n_experts * 3:
+ tensors: list[tuple[str, Tensor]] = []
+
+ # merge the experts into a single 3d tensor
+ for w_name in ["down_proj", "gate_proj", "up_proj"]:
+ datas: list[Tensor] = []
+
+ for xid in range(n_experts):
+ ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
+ datas.append(self._experts[bid][ename])
+ del self._experts[bid][ename]
+
+ data_torch = torch.stack(datas, dim=0)
+
+ merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
+
+ new_name = self.map_tensor_name(merged_name)
+
+ tensors.append((new_name, data_torch))
+ return tensors
+ else:
+ return []
+
+ return [(self.map_tensor_name(name), data_torch)]
+
+ # Copied from: Qwen2MoeModel
+ def prepare_tensors(self):
+ super().prepare_tensors()
+
+ if self._experts is not None:
+ # flatten `list[dict[str, Tensor]]` into `list[str]`
+ experts = [k for d in self._experts for k in d.keys()]
+ if len(experts) > 0:
+ raise ValueError(f"Unprocessed experts: {experts}")
+
+
@ModelBase.register("HunYuanDenseV1ForCausalLM")
class HunYuanModel(TextModel):
model_arch = gguf.MODEL_ARCH.HUNYUAN_DENSE
n_generated = params.max_length;
}
-static std::string format_input_text(const std::string & prompt, bool use_chat_template, llama_model * model) {
+static std::string format_input_text(const std::string & prompt, const std::string & system_prompt, bool use_chat_template, llama_model * model) {
if (!use_chat_template) {
return prompt;
}
auto chat_templates = common_chat_templates_init(model, "");
-
common_chat_templates_inputs inputs;
- common_chat_msg user_msg;
- user_msg.role = "user";
- user_msg.content = prompt;
- inputs.add_generation_prompt = true;
+ common_chat_msg system_msg;
+
+ if (!system_prompt.empty()) {
+ system_msg.role = "system";
+ system_msg.content = system_prompt;
+ inputs.messages.push_back(system_msg);
+ }
+
+ common_chat_msg user_msg;
+ user_msg.role = "user";
+ user_msg.content = prompt;
+
inputs.messages.push_back(user_msg);
+ inputs.add_generation_prompt = true;
auto result = common_chat_templates_apply(chat_templates.get(), inputs);
llama_set_n_threads(ctx, params.cpuparams.n_threads, params.cpuparams_batch.n_threads);
const llama_vocab * vocab = llama_model_get_vocab(model);
- std::string formatted_prompt = format_input_text(params.prompt, params.enable_chat_template, model);
+
+ std::string formatted_prompt = format_input_text(params.prompt, params.system_prompt, params.enable_chat_template, model);
std::vector<llama_token> input_tokens = common_tokenize(vocab,
formatted_prompt,
}
llama_token mask_token_id = llama_vocab_mask(vocab);
+
GGML_ASSERT(mask_token_id != LLAMA_TOKEN_NULL);
bool visual_mode = params.diffusion.visual_mode;
{ LLM_ARCH_DREAM, "dream" },
{ LLM_ARCH_SMALLTHINKER, "smallthinker" },
{ LLM_ARCH_LLADA, "llada" },
+ { LLM_ARCH_LLADA_MOE, "llada-moe" },
{ LLM_ARCH_SEED_OSS, "seed_oss" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
+ {
+ LLM_ARCH_LLADA_MOE,
+ {
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+ { LLM_TENSOR_OUTPUT, "output" },
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+ { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+ { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+ { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
+ { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
+ { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
+ { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
+ },
+ },
{
LLM_ARCH_SEED_OSS,
{
switch (arch) {
case LLM_ARCH_DREAM:
case LLM_ARCH_LLADA:
+ case LLM_ARCH_LLADA_MOE:
return true;
default:
return false;
hparams.causal_attn = false;
}
break;
+ case LLM_ARCH_LLADA_MOE:
+ {
+ ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false);
+
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+ // diffusion language model uses non-causal attention
+ hparams.causal_attn = false;
+ switch (hparams.n_layer) {
+ case 16: type = LLM_TYPE_A1_7B; break;
+ default: type = LLM_TYPE_UNKNOWN;
+ }
+ } break;
case LLM_ARCH_QWEN2MOE:
{
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false);
}
}
break;
+ case LLM_ARCH_LLADA_MOE:
+ {
+ tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+ // output
+ output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+ output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
+
+ GGML_ASSERT(n_expert > 0 && "n_expert must be > 0 for llada-moe");
+ GGML_ASSERT(n_expert_used > 0 && "n_expert_used must be > 0 for llada-moe");
+
+ for (int i = 0; i < n_layer; ++i) {
+ auto & layer = layers[i];
+
+ layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+ layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
+ layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
+ layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
+ layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+ layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
+ layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
+
+ layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+ layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
+
+ const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used;
+
+ layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
+ layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0);
+ layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
+ }
+ } break;
case LLM_ARCH_LLAMA4:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
}
};
+struct llm_build_llada_moe : public llm_graph_context {
+ llm_build_llada_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+ const int64_t n_embd_head = hparams.n_embd_head_v;
+
+ GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+ GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+ ggml_tensor * cur;
+ ggml_tensor * inpL;
+
+ inpL = build_inp_embd(model.tok_embd);
+
+ // inp_pos - contains the positions
+ ggml_tensor * inp_pos = build_inp_pos();
+
+ auto * inp_attn = build_attn_inp_no_cache();
+
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+ for (int il = 0; il < n_layer; ++il) {
+ ggml_tensor * inpSA = inpL;
+
+ // norm
+ cur = build_norm(inpL,
+ model.layers[il].attn_norm, NULL,
+ LLM_NORM_RMS, il);
+ cb(cur, "attn_norm", il);
+
+ // self_attention
+ {
+ // compute Q and K and RoPE them
+ ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+ cb(Qcur, "Qcur", il);
+
+ ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+ cb(Kcur, "Kcur", il);
+
+ ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+ cb(Vcur, "Vcur", il);
+
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+ Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+ Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
+ cb(Qcur, "Qcur_normed", il);
+
+ Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
+ cb(Kcur, "Kcur_normed", il);
+
+ Qcur = ggml_rope_ext(
+ ctx0, Qcur, inp_pos, nullptr,
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow
+ );
+
+ Kcur = ggml_rope_ext(
+ ctx0, Kcur, inp_pos, nullptr,
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow
+ );
+
+ cb(Qcur, "Qcur", il);
+ cb(Kcur, "Kcur", il);
+ cb(Vcur, "Vcur", il);
+
+ cur = build_attn(inp_attn,
+ model.layers[il].wo, NULL,
+ Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+ }
+
+ if (il == n_layer - 1 && inp_out_ids) {
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+ }
+
+ ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+ cb(ffn_inp, "ffn_inp", il);
+
+ // MoE branch
+ cur = build_norm(ffn_inp,
+ model.layers[il].ffn_norm, NULL,
+ LLM_NORM_RMS, il);
+ cb(cur, "ffn_norm", il);
+
+ cur = build_moe_ffn(cur,
+ model.layers[il].ffn_gate_inp,
+ model.layers[il].ffn_up_exps,
+ model.layers[il].ffn_gate_exps,
+ model.layers[il].ffn_down_exps,
+ nullptr,
+ n_expert, n_expert_used,
+ LLM_FFN_SILU, false,
+ false, 0.0,
+ LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
+ il);
+ cb(cur, "ffn_moe_out", il);
+
+ cur = ggml_add(ctx0, cur, ffn_inp);
+
+ cur = build_cvec(cur, il);
+ cb(cur, "l_out", il);
+
+ // input for next layer
+ inpL = cur;
+ }
+
+ cur = inpL;
+
+ cur = build_norm(cur,
+ model.output_norm, NULL,
+ LLM_NORM_RMS, -1);
+
+ cb(cur, "result_norm", -1);
+ res->t_embd = cur;
+
+ // lm_head
+ cur = build_lora_mm(model.output, cur);
+
+ cb(cur, "result_output", -1);
+ res->t_logits = cur;
+
+ ggml_build_forward_expand(gf, cur);
+ }
+};
+
struct llm_build_openelm : public llm_graph_context {
llm_build_openelm(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
//case LLM_ARCH_GEMMA_EMBEDDING: // TODO: disabled until the cacheless SWA logic is fixed [TAG_NO_CACHE_ISWA]
case LLM_ARCH_DREAM:
case LLM_ARCH_LLADA:
+ case LLM_ARCH_LLADA_MOE:
{
res = nullptr;
} break;
llm = std::make_unique<llm_build_llada>(*this, params);
}
break;
+ case LLM_ARCH_LLADA_MOE:
+ {
+ llm = std::make_unique<llm_build_llada_moe>(*this, params);
+ }
+ break;
case LLM_ARCH_QWEN2VL:
{
llm = std::make_unique<llm_build_qwen2vl>(*this, params);
case LLM_ARCH_QWEN2MOE:
case LLM_ARCH_QWEN3:
case LLM_ARCH_QWEN3MOE:
+ case LLM_ARCH_LLADA_MOE:
case LLM_ARCH_OLMO2:
case LLM_ARCH_OLMOE:
case LLM_ARCH_PHI2: