if chkhsh == "1994ffd01900cfb37395608534236ecd63f2bd5995d6cb1004dda1af50240f15":
# ref: https://huggingface.co/trillionlabs/Trillion-7B-preview
res = "trillion"
+ if chkhsh == "96a5f08be6259352137b512d4157e333e21df7edd3fcd152990608735a65b224":
+ # ref: https://huggingface.co/inclusionAI/Ling-lite
+ res = "bailingmoe"
if res is None:
logger.warning("\n")
return super().modify_tensors(data_torch, name, bid)
+@Model.register("BailingMoeForCausalLM")
+class BailingMoeModel(Model):
+ model_arch = gguf.MODEL_ARCH.BAILINGMOE
+
+ def set_vocab(self):
+ self._set_vocab_gpt2()
+
+ def set_gguf_parameters(self):
+ super().set_gguf_parameters()
+ hparams = self.hparams
+ if "head_dim" in hparams:
+ rope_dim = hparams["head_dim"]
+ else:
+ rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"]
+
+ self.gguf_writer.add_rope_dimension_count(rope_dim)
+ self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
+ self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"])
+ self.gguf_writer.add_vocab_size(hparams["vocab_size"])
+ self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
+ self.gguf_writer.add_expert_weights_scale(1.0)
+ self.gguf_writer.add_expert_count(hparams["num_experts"])
+ self.gguf_writer.add_expert_shared_count(hparams["num_shared_experts"])
+ self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
+
+ _experts: list[dict[str, Tensor]] | None = None
+
+ @staticmethod
+ def permute(weights: Tensor, n_head: int, n_head_kv: int | None):
+ if n_head_kv is not None and n_head != n_head_kv:
+ n_head = n_head_kv
+ return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
+ .swapaxes(1, 2)
+ .reshape(weights.shape))
+
+ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+ n_head = self.hparams["num_attention_heads"]
+ n_kv_head = self.hparams.get("num_key_value_heads")
+ n_embd = self.hparams["hidden_size"]
+ head_dim = self.hparams.get("head_dim", n_embd // n_head)
+
+ output_name = self.format_tensor_name(gguf.MODEL_TENSOR.OUTPUT)
+
+ if name.endswith("attention.dense.weight"):
+ return [(self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_OUT, bid), data_torch)]
+ elif name.endswith("query_key_value.weight"):
+ q, k, v = data_torch.split([n_head * head_dim, n_kv_head * head_dim, n_kv_head * head_dim], dim=-2)
+
+ return [
+ (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), BailingMoeModel.permute(q, n_head, n_head)),
+ (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), BailingMoeModel.permute(k, n_head, n_kv_head)),
+ (self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, bid), v)
+ ]
+ elif name.find("mlp.experts") != -1:
+ n_experts = self.hparams["num_experts"]
+ assert bid is not None
+
+ tensors: list[tuple[str, Tensor]] = []
+
+ if self._experts is None:
+ self._experts = [{} for _ in range(self.block_count)]
+
+ self._experts[bid][name] = data_torch
+
+ if len(self._experts[bid]) >= n_experts * 3:
+ # merge the experts into a single 3d tensor
+ for w_name in ["down_proj", "gate_proj", "up_proj"]:
+ datas: list[Tensor] = []
+
+ for xid in range(n_experts):
+ ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
+ datas.append(self._experts[bid][ename])
+ del self._experts[bid][ename]
+
+ data_torch = torch.stack(datas, dim=0)
+
+ merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
+
+ new_name = self.map_tensor_name(merged_name)
+
+ tensors.append((new_name, data_torch))
+
+ return tensors
+
+ new_name = self.map_tensor_name(name)
+
+ if new_name == output_name and self.hparams.get("norm_head"):
+ data_torch = data_torch.float()
+ data_torch /= torch.norm(data_torch, p=2, dim=0, keepdim=True) + 1e-7
+
+ return [(new_name, data_torch)]
+
+ def prepare_tensors(self):
+ super().prepare_tensors()
+
+ if self._experts is not None:
+ # flatten `list[dict[str, Tensor]]` into `list[str]`
+ experts = [k for d in self._experts for k in d.keys()]
+ if len(experts) > 0:
+ raise ValueError(f"Unprocessed experts: {experts}")
+
+
@Model.register("ChameleonForConditionalGeneration")
@Model.register("ChameleonForCausalLM") # obsolete
class ChameleonModel(Model):
case LLM_TYPE_10B_128x3_66B: return "10B+128x3.66B";
case LLM_TYPE_57B_A14B: return "57B.A14B";
case LLM_TYPE_27B: return "27B";
+ case LLM_TYPE_290B: return "290B";
default: return "?B";
}
}
ml.get_key(LLM_KV_ATTENTION_GROUPNORM_GROUPS, hparams.n_norm_groups);
ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
} break;
+ case LLM_ARCH_BAILINGMOE:
+ {
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+ ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead);
+ ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
+ ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared);
+ ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale);
+ ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false);
+
+ switch (hparams.n_layer) {
+ case 28: type = LLM_TYPE_16B; break;
+ case 88: type = LLM_TYPE_290B; break;
+ default: type = LLM_TYPE_UNKNOWN;
+ }
+ } break;
default: throw std::runtime_error("unsupported model architecture");
}
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hparams.convnext.n_embd, n_embd}, 0);
output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {n_embd}, 0);
} break;
+ case LLM_ARCH_BAILINGMOE:
+ {
+ const int64_t n_ff_exp = hparams.n_ff_exp;
+ const int64_t n_expert_shared = hparams.n_expert_shared;
+
+ tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+ // output
+ output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+ output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0);
+
+ for (int i = 0; i < n_layer; ++i) {
+ auto & layer = layers[i];
+
+ layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+ layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_head * n_rot}, 0);
+ layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_head_kv * n_rot}, 0);
+ layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_head_kv * n_rot}, 0);
+ layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_rot, n_embd}, 0);
+ layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+ layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
+
+ if (n_expert == 0) {
+ throw std::runtime_error("n_expert must be > 0");
+ }
+ if (n_expert_used == 0) {
+ throw std::runtime_error("n_expert_used must be > 0");
+ }
+
+ layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
+ layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0);
+ layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
+
+ layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
+ layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, 0);
+ layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
+ }
+ } break;
default:
throw std::runtime_error("unknown architecture");
}
LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale);
}
+ if (arch == LLM_ARCH_BAILINGMOE) {
+ LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead);
+ LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
+ LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared);
+ LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale);
+ LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm);
+ }
+
vocab.print_info();
}
}
};
+struct llm_build_bailingmoe : public llm_graph_context {
+ llm_build_bailingmoe(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+ ggml_tensor * cur;
+ ggml_tensor * inpL;
+
+ inpL = build_inp_embd(model.tok_embd);
+
+ // inp_pos - contains the positions
+ ggml_tensor * inp_pos = build_inp_pos();
+
+ auto * inp_attn = build_attn_inp_kv_unified();
+
+ for (int il = 0; il < n_layer; ++il) {
+ ggml_tensor * inpSA = inpL;
+
+ // norm
+ cur = build_norm(inpL,
+ model.layers[il].attn_norm, NULL,
+ LLM_NORM_RMS, il);
+ cb(cur, "attn_norm", il);
+
+ // self-attention
+ {
+ // rope freq factors for llama3; may return nullptr for llama2 and other models
+ ggml_tensor * rope_factors = static_cast<const llama_kv_cache_unified *>(memory)->cbs.get_rope_factors(n_ctx_per_seq, il);
+
+ // compute Q and K and RoPE them
+ ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+ cb(Qcur, "Qcur", il);
+ if (model.layers[il].bq) {
+ Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+ cb(Qcur, "Qcur", il);
+ }
+
+ ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+ cb(Kcur, "Kcur", il);
+ if (model.layers[il].bk) {
+ Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+ cb(Kcur, "Kcur", il);
+ }
+
+ ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+ cb(Vcur, "Vcur", il);
+ if (model.layers[il].bv) {
+ Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+ cb(Vcur, "Vcur", il);
+ }
+
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_rot, n_head, n_tokens);
+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_rot, n_head_kv, n_tokens);
+ Vcur = ggml_reshape_3d(ctx0, Vcur, n_rot, n_head_kv, n_tokens);
+
+ Qcur = ggml_rope_ext(
+ ctx0, Qcur, inp_pos, rope_factors,
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow
+ );
+
+ Kcur = ggml_rope_ext(
+ ctx0, Kcur, inp_pos, rope_factors,
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow
+ );
+
+ cb(Qcur, "Qcur", il);
+ cb(Kcur, "Kcur", il);
+ cb(Vcur, "Vcur", il);
+
+ cur = build_attn(inp_attn, gf,
+ model.layers[il].wo, model.layers[il].bo,
+ Qcur, Kcur, Vcur, nullptr, 1.0f/sqrtf(float(n_rot)), il);
+ }
+
+ if (il == n_layer - 1) {
+ // skip computing output for unused tokens
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+ }
+
+ ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+ cb(ffn_inp, "ffn_inp", il);
+
+ cur = build_norm(ffn_inp,
+ model.layers[il].ffn_norm, NULL,
+ LLM_NORM_RMS, il);
+ cb(cur, "ffn_norm", il);
+
+ ggml_tensor * moe_out =
+ build_moe_ffn(cur,
+ model.layers[il].ffn_gate_inp,
+ model.layers[il].ffn_up_exps,
+ model.layers[il].ffn_gate_exps,
+ model.layers[il].ffn_down_exps,
+ nullptr,
+ n_expert, n_expert_used,
+ LLM_FFN_SILU, hparams.expert_weights_norm,
+ false, hparams.expert_weights_scale,
+ LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
+ il);
+ cb(moe_out, "ffn_moe_out", il);
+
+ // FFN shared expert
+ {
+ ggml_tensor * ffn_shexp = build_ffn(cur,
+ model.layers[il].ffn_up_shexp, NULL, NULL,
+ model.layers[il].ffn_gate_shexp, NULL, NULL,
+ model.layers[il].ffn_down_shexp, NULL, NULL,
+ NULL,
+ LLM_FFN_SILU, LLM_FFN_PAR, il);
+ cb(ffn_shexp, "ffn_shexp", il);
+
+ cur = ggml_add(ctx0, moe_out, ffn_shexp);
+ cb(cur, "ffn_out", il);
+ }
+
+ cur = ggml_add(ctx0, cur, ffn_inp);
+
+ cur = build_cvec(cur, il);
+ cb(cur, "l_out", il);
+
+ // input for next layer
+ inpL = cur;
+ }
+
+ cur = inpL;
+
+ cur = build_norm(cur,
+ model.output_norm, NULL,
+ LLM_NORM_RMS, -1);
+
+ cb(cur, "result_norm", -1);
+ res->t_embd = cur;
+
+ // lm_head
+ cur = build_lora_mm(model.output, cur);
+
+ cb(cur, "result_output", -1);
+ res->t_logits = cur;
+
+ ggml_build_forward_expand(gf, cur);
+ }
+};
+
llama_memory_i * llama_model::create_memory() const {
llama_memory_i * res;
{
llm = std::make_unique<llm_build_plm>(*this, params, gf);
} break;
+ case LLM_ARCH_BAILINGMOE:
+ {
+ llm = std::make_unique<llm_build_bailingmoe>(*this, params, gf);
+ } break;
default:
GGML_ABORT("fatal error");
}
case LLM_ARCH_GRANITE:
case LLM_ARCH_GRANITE_MOE:
case LLM_ARCH_CHAMELEON:
+ case LLM_ARCH_BAILINGMOE:
return LLAMA_ROPE_TYPE_NORM;
// the pairs of head values are offset by n_rot/2