return super().modify_tensors(data_torch, name, bid)
+@ModelBase.register("PanguEmbeddedForCausalLM")
+class PanguEmbeddedModel(TextModel):
+ model_arch = gguf.MODEL_ARCH.PANGU_EMBED
+
+ def set_vocab(self):
+ self._set_vocab_sentencepiece()
+
+ tokenizer_config_file = self.dir_model / 'tokenizer_config.json'
+ if tokenizer_config_file.is_file():
+ with open(tokenizer_config_file, "r", encoding="utf-8") as f:
+ tokenizer_config_json = json.load(f)
+ if "add_prefix_space" in tokenizer_config_json:
+ self.gguf_writer.add_add_space_prefix(tokenizer_config_json["add_prefix_space"])
+
+ def set_gguf_parameters(self):
+ super().set_gguf_parameters()
+ hparams = self.hparams
+ self.gguf_writer.add_vocab_size(hparams["vocab_size"])
+
+ # PanguEmbedded's hparam loaded from config.json without head_dim
+ if (rope_dim := hparams.get("head_dim")) is None:
+ rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"]
+ self.gguf_writer.add_rope_dimension_count(rope_dim)
+
+ if hparams.get("head_dim") is None:
+ self.gguf_writer.add_key_length(rope_dim)
+ self.gguf_writer.add_value_length(rope_dim)
+
+ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+ if name == "lm_head.weight":
+ if self.hparams.get("tie_word_embeddings", False):
+ logger.info("Skipping tied output layer 'lm_head.weight'")
+ return []
+ return [(self.map_tensor_name(name), data_torch)]
+
+
@ModelBase.register("Dots1ForCausalLM")
class Dots1Model(Qwen2MoeModel):
model_arch = gguf.MODEL_ARCH.DOTS1
APERTUS = auto()
COGVLM = auto()
MINIMAXM2 = auto()
+ PANGU_EMBED = auto()
class VISION_PROJECTOR_TYPE(IntEnum):
MODEL_ARCH.APERTUS: "apertus",
MODEL_ARCH.MINIMAXM2: "minimax-m2",
MODEL_ARCH.COGVLM: "cogvlm",
+ MODEL_ARCH.PANGU_EMBED: "pangu-embedded",
}
VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
MODEL_TENSOR.VISEXP_UP,
MODEL_TENSOR.VISEXP_DOWN,
],
+ MODEL_ARCH.PANGU_EMBED: [
+ MODEL_TENSOR.TOKEN_EMBD,
+ MODEL_TENSOR.OUTPUT_NORM,
+ MODEL_TENSOR.OUTPUT,
+ MODEL_TENSOR.ATTN_NORM,
+ MODEL_TENSOR.ATTN_Q,
+ MODEL_TENSOR.ATTN_K,
+ MODEL_TENSOR.ATTN_V,
+ MODEL_TENSOR.ATTN_OUT,
+ MODEL_TENSOR.FFN_NORM,
+ MODEL_TENSOR.FFN_GATE,
+ MODEL_TENSOR.FFN_DOWN,
+ MODEL_TENSOR.FFN_UP,
+ ],
# TODO
}
MODEL_ARCH.BAILINGMOE: [
MODEL_TENSOR.ROPE_FREQS,
],
+ MODEL_ARCH.PANGU_EMBED: [
+ MODEL_TENSOR.ROPE_FREQS,
+ MODEL_TENSOR.ATTN_ROT_EMBD,
+ ],
}
#
models/openai-moe-iswa.cpp
models/openelm.cpp
models/orion.cpp
+ models/pangu-embedded.cpp
models/phi2.cpp
models/phi3.cpp
models/plamo.cpp
{ LLM_ARCH_APERTUS, "apertus" },
{ LLM_ARCH_MINIMAX_M2, "minimax-m2" },
{ LLM_ARCH_COGVLM, "cogvlm" },
+ { LLM_ARCH_PANGU_EMBED, "pangu-embedded" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};
{ LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
},
},
+ {
+ LLM_ARCH_PANGU_EMBED,
+ {
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+ { LLM_TENSOR_OUTPUT, "output" },
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+ { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+ { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+ { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+ },
+ },
{
LLM_ARCH_COGVLM,
{
LLM_ARCH_APERTUS,
LLM_ARCH_MINIMAX_M2,
LLM_ARCH_COGVLM,
+ LLM_ARCH_PANGU_EMBED,
LLM_ARCH_UNKNOWN,
};
{ "kimi-k2", LLM_CHAT_TEMPLATE_KIMI_K2 },
{ "seed_oss", LLM_CHAT_TEMPLATE_SEED_OSS },
{ "grok-2", LLM_CHAT_TEMPLATE_GROK_2 },
+ { "pangu-embedded", LLM_CHAT_TEMPLATE_PANGU_EMBED },
};
llm_chat_template llm_chat_template_from_str(const std::string & name) {
return LLM_CHAT_TEMPLATE_SEED_OSS;
} else if (tmpl_contains("'Assistant: ' + message['content'] + '<|separator|>")) {
return LLM_CHAT_TEMPLATE_GROK_2;
+ } else if (tmpl_contains(LU8("[unused9]系统:[unused10]"))) {
+ return LLM_CHAT_TEMPLATE_PANGU_EMBED;
}
return LLM_CHAT_TEMPLATE_UNKNOWN;
}
if (add_ass) {
ss << "Assistant:";
}
+ }else if (tmpl == LLM_CHAT_TEMPLATE_PANGU_EMBED) {
+ // [unused9]系统:xxx[unused10]
+ // [unused9]用户:xxx[unused10]
+ // [unused9]助手:xxx[unused10]
+ // ...
+ for (size_t i = 0; i < chat.size(); ++i) {
+ const auto & msg = chat[i];
+ const std::string & role = msg->role;
+ const std::string & content = msg->content;
+
+ if (i == 0 && role != "system") {
+ ss << "[unused9]系统:[unused10]";
+ }
+
+ if (role == "system") {
+ ss << "[unused9]系统:" << content << "[unused10]";
+ } else if (role == "user") {
+ ss << "[unused9]用户:" << content << "[unused10]";
+ } else if (role == "assistant") {
+ ss << "[unused9]助手:" << content << "[unused10]";
+ } else if (role == "tool") {
+ ss << "[unused9]工具:" << content << "[unused10]";
+ } else if (role == "function") {
+ ss << "[unused9]方法:" << content << "[unused10]";
+ }
+ }
+ if (add_ass) {
+ ss << "[unused9]助手:";
+ }
} else {
// template not supported
return -1;
LLM_CHAT_TEMPLATE_KIMI_K2,
LLM_CHAT_TEMPLATE_SEED_OSS,
LLM_CHAT_TEMPLATE_GROK_2,
+ LLM_CHAT_TEMPLATE_PANGU_EMBED,
LLM_CHAT_TEMPLATE_UNKNOWN,
};
default: type = LLM_TYPE_UNKNOWN;
}
} break;
+ case LLM_ARCH_PANGU_EMBED:
+ {
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+ switch (hparams.n_layer) {
+ case 26: type = LLM_TYPE_1B; break; // openPangu-Embedded-1B-V1.1
+ case 34: type = LLM_TYPE_7B; break; // openPangu-Embedded-7B-V1.1
+ default: type = LLM_TYPE_UNKNOWN;
+ }
+ } break;
default: throw std::runtime_error("unsupported model architecture");
}
layer.visexp_ffn_up = create_tensor(tn(LLM_TENSOR_VISEXP_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
}
} break;
+ case LLM_ARCH_PANGU_EMBED:
+ {
+ tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+ // output
+ output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+ output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+
+ // if output is NULL, init from the input tok embed
+ if (output == NULL) {
+ output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+ }
+
+ for (int i = 0; i < n_layer; ++i) {
+ auto & layer = layers[i];
+
+ layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+ // weight tensors
+ layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+ layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0);
+ layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0);
+ layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
+
+ // bias tensors
+ layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd_head_k * n_head}, 0);
+ layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, 0);
+ layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, 0);
+ layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0);
+
+ layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+ if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) {
+ layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
+ layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
+ } else {
+ layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
+ }
+
+ layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
+ layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
+ layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
+ }
+ } break;
default:
throw std::runtime_error("unknown architecture");
}
{
llm = std::make_unique<llm_build_cogvlm>(*this, params);
} break;
+ case LLM_ARCH_PANGU_EMBED:
+ {
+ llm = std::make_unique<llm_build_pangu_embedded>(*this, params);
+ }break;
default:
GGML_ABORT("fatal error");
}
case LLM_ARCH_APERTUS:
case LLM_ARCH_MINIMAX_M2:
case LLM_ARCH_COGVLM:
+ case LLM_ARCH_PANGU_EMBED:
return LLAMA_ROPE_TYPE_NEOX;
case LLM_ARCH_QWEN2VL:
llm_build_orion(const llama_model & model, const llm_graph_params & params);
};
+struct llm_build_pangu_embedded : public llm_graph_context {
+ llm_build_pangu_embedded(const llama_model & model, const llm_graph_params & params);
+};
+
struct llm_build_phi2 : public llm_graph_context {
llm_build_phi2(const llama_model & model, const llm_graph_params & params);
};
--- /dev/null
+#include "models.h"
+
+
+llm_build_pangu_embedded::llm_build_pangu_embedded(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+ const int64_t n_embd_head = hparams.n_embd_head_v;
+
+ GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+ GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+ ggml_tensor * cur;
+ ggml_tensor * inpL;
+
+ inpL = build_inp_embd(model.tok_embd);
+
+ // inp_pos - contains the positions
+ ggml_tensor * inp_pos = build_inp_pos();
+
+ auto * inp_attn = build_attn_inp_kv();
+
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+ for (int il = 0; il < n_layer; ++il) {
+ ggml_tensor * inpSA = inpL;
+
+ // norm
+ cur = build_norm(inpL,
+ model.layers[il].attn_norm, NULL,
+ LLM_NORM_RMS, il);
+ cb(cur, "attn_norm", il);
+
+ // self attention
+ {
+ // compute Q and K and RoPE them
+ ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+ Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+ cb(Qcur, "Qcur", il);
+
+ ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+ Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+ cb(Kcur, "Kcur", il);
+
+ ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+ Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+ cb(Vcur, "Vcur", il);
+
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+ Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+ Qcur = ggml_rope_ext(
+ ctx0, Qcur, inp_pos, nullptr,
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow
+ );
+
+ Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr,
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow
+ );
+
+ cb(Qcur, "Qcur", il);
+ cb(Kcur, "Kcur", il);
+ cb(Vcur, "Vcur", il);
+
+ cur = build_attn(inp_attn,
+ model.layers[il].wo, model.layers[il].bo,
+ Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+ }
+
+ if (il == n_layer - 1 && inp_out_ids) {
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+ }
+
+ ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+ cb(ffn_inp, "ffn_inp", il);
+
+ // feed-forward network
+ cur = build_norm(ffn_inp,
+ model.layers[il].ffn_norm, NULL,
+ LLM_NORM_RMS, il);
+ cb(cur, "ffn_norm", il);
+
+ cur = build_ffn(cur,
+ model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
+ model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
+ model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+ NULL,
+ LLM_FFN_SILU, LLM_FFN_PAR, il);
+
+ cur = ggml_add(ctx0, cur, ffn_inp);
+ cb(cur, "ffn_out", il);
+
+ cur = build_cvec(cur, il);
+ cb(cur, "l_out", il);
+
+ // input for next layer
+ inpL = cur;
+ }
+
+ cur = inpL;
+
+ cur = build_norm(cur,
+ model.output_norm, NULL,
+ LLM_NORM_RMS, -1);
+
+ cb(cur, "result_norm", -1);
+ res->t_embd = cur;
+
+ // lm_head
+ cur = build_lora_mm(model.output, cur);
+
+ if (model.output_b != nullptr) {
+ cur = ggml_add(ctx0, cur, model.output_b);
+ }
+
+ cb(cur, "result_output", -1);
+ res->t_logits = cur;
+
+ ggml_build_forward_expand(gf, cur);
+}