]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
model : add openPangu-Embedded (#16941)
authorLi Pengzhan <redacted>
Wed, 5 Nov 2025 09:28:58 +0000 (17:28 +0800)
committerGitHub <redacted>
Wed, 5 Nov 2025 09:28:58 +0000 (10:28 +0100)
* Model: add openPangu-Embedded

* fixed according to reviewer's comments

* fixed the chat template check condition

* Apply suggestions from code review

change the chat-template check condition and some formatting issue

Co-authored-by: Sigbjørn Skjæret <redacted>
* whitespace cleanup

---------

Co-authored-by: Sigbjørn Skjæret <redacted>
convert_hf_to_gguf.py
gguf-py/gguf/constants.py
src/CMakeLists.txt
src/llama-arch.cpp
src/llama-arch.h
src/llama-chat.cpp
src/llama-chat.h
src/llama-model.cpp
src/models/models.h
src/models/pangu-embedded.cpp [new file with mode: 0644]

index c6f5ba6a04c54d4f5be6bd362ee724d9b9c7622e..222f6ed6dc40f92f954dc512185865ccb93c5e53 100755 (executable)
@@ -7187,6 +7187,42 @@ class MiniMaxM2Model(TextModel):
         return super().modify_tensors(data_torch, name, bid)
 
 
+@ModelBase.register("PanguEmbeddedForCausalLM")
+class PanguEmbeddedModel(TextModel):
+    model_arch = gguf.MODEL_ARCH.PANGU_EMBED
+
+    def set_vocab(self):
+        self._set_vocab_sentencepiece()
+
+        tokenizer_config_file = self.dir_model / 'tokenizer_config.json'
+        if tokenizer_config_file.is_file():
+            with open(tokenizer_config_file, "r", encoding="utf-8") as f:
+                tokenizer_config_json = json.load(f)
+                if "add_prefix_space" in tokenizer_config_json:
+                    self.gguf_writer.add_add_space_prefix(tokenizer_config_json["add_prefix_space"])
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        hparams = self.hparams
+        self.gguf_writer.add_vocab_size(hparams["vocab_size"])
+
+        # PanguEmbedded's hparam loaded from config.json without head_dim
+        if (rope_dim := hparams.get("head_dim")) is None:
+            rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"]
+        self.gguf_writer.add_rope_dimension_count(rope_dim)
+
+        if hparams.get("head_dim") is None:
+            self.gguf_writer.add_key_length(rope_dim)
+            self.gguf_writer.add_value_length(rope_dim)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        if name == "lm_head.weight":
+            if self.hparams.get("tie_word_embeddings", False):
+                logger.info("Skipping tied output layer 'lm_head.weight'")
+                return []
+        return [(self.map_tensor_name(name), data_torch)]
+
+
 @ModelBase.register("Dots1ForCausalLM")
 class Dots1Model(Qwen2MoeModel):
     model_arch = gguf.MODEL_ARCH.DOTS1
index 77e3b0650ff0b2e888407d6297e3ad43b9d86f14..6b4b6c5ab075d71a64a287b7453ec8a546d1fbb4 100644 (file)
@@ -426,6 +426,7 @@ class MODEL_ARCH(IntEnum):
     APERTUS          = auto()
     COGVLM           = auto()
     MINIMAXM2        = auto()
+    PANGU_EMBED      = auto()
 
 
 class VISION_PROJECTOR_TYPE(IntEnum):
@@ -793,6 +794,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.APERTUS:          "apertus",
     MODEL_ARCH.MINIMAXM2:        "minimax-m2",
     MODEL_ARCH.COGVLM:           "cogvlm",
+    MODEL_ARCH.PANGU_EMBED:      "pangu-embedded",
 }
 
 VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
@@ -2958,6 +2960,20 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.VISEXP_UP,
         MODEL_TENSOR.VISEXP_DOWN,
     ],
+    MODEL_ARCH.PANGU_EMBED: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
     # TODO
 }
 
@@ -3013,6 +3029,10 @@ MODEL_TENSOR_SKIP: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
     MODEL_ARCH.BAILINGMOE: [
         MODEL_TENSOR.ROPE_FREQS,
     ],
+    MODEL_ARCH.PANGU_EMBED: [
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_ROT_EMBD,
+    ],
 }
 
 #
index 832b58e315d095d5211977e59d8cc5ed0d97e2c6..630b2cddf67e8a534cb593f6f579be7c2d7f01e9 100644 (file)
@@ -99,6 +99,7 @@ add_library(llama
             models/openai-moe-iswa.cpp
             models/openelm.cpp
             models/orion.cpp
+            models/pangu-embedded.cpp
             models/phi2.cpp
             models/phi3.cpp
             models/plamo.cpp
index 7c7953b83dda8d20b8b38962b0eda87338288ad6..b7642b568dffb2bb9e4cd745ab5627a0a5068ba0 100644 (file)
@@ -107,6 +107,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_APERTUS,          "apertus"          },
     { LLM_ARCH_MINIMAX_M2,       "minimax-m2"       },
     { LLM_ARCH_COGVLM,           "cogvlm"           },
+    { LLM_ARCH_PANGU_EMBED,      "pangu-embedded"   },
     { LLM_ARCH_UNKNOWN,          "(unknown)"        },
 };
 
@@ -2377,6 +2378,23 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_EXP_PROBS_B,    "blk.%d.exp_probs_b" },
         },
     },
+    {
+        LLM_ARCH_PANGU_EMBED,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_OUTPUT,          "output" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+        },
+    },
     {
         LLM_ARCH_COGVLM,
         {
index 3f893a2dc6916b97db9269b55ce520b680553113..a769dd1e85741ba442772a264e15a89785565f32 100644 (file)
@@ -111,6 +111,7 @@ enum llm_arch {
     LLM_ARCH_APERTUS,
     LLM_ARCH_MINIMAX_M2,
     LLM_ARCH_COGVLM,
+    LLM_ARCH_PANGU_EMBED,
     LLM_ARCH_UNKNOWN,
 };
 
index 0285006d73caa874527ffebf7e1c020ff0e9c635..fc6a6223cfe2f86b8d4bf7ff2b596e2fb6c945fb 100644 (file)
@@ -73,6 +73,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
     { "kimi-k2",           LLM_CHAT_TEMPLATE_KIMI_K2           },
     { "seed_oss",          LLM_CHAT_TEMPLATE_SEED_OSS          },
     { "grok-2",            LLM_CHAT_TEMPLATE_GROK_2            },
+    { "pangu-embedded",    LLM_CHAT_TEMPLATE_PANGU_EMBED       },
 };
 
 llm_chat_template llm_chat_template_from_str(const std::string & name) {
@@ -213,6 +214,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
         return LLM_CHAT_TEMPLATE_SEED_OSS;
     } else if (tmpl_contains("'Assistant: '  + message['content'] + '<|separator|>")) {
         return LLM_CHAT_TEMPLATE_GROK_2;
+    } else if (tmpl_contains(LU8("[unused9]系统:[unused10]"))) {
+        return LLM_CHAT_TEMPLATE_PANGU_EMBED;
     }
     return LLM_CHAT_TEMPLATE_UNKNOWN;
 }
@@ -813,6 +816,35 @@ int32_t llm_chat_apply_template(
         if (add_ass) {
             ss << "Assistant:";
         }
+    }else if (tmpl == LLM_CHAT_TEMPLATE_PANGU_EMBED) {
+        // [unused9]系统:xxx[unused10]
+        // [unused9]用户:xxx[unused10]
+        // [unused9]助手:xxx[unused10]
+        // ...
+        for (size_t i = 0; i < chat.size(); ++i) {
+            const auto & msg = chat[i];
+            const std::string & role = msg->role;
+            const std::string & content = msg->content;
+
+            if (i == 0 && role != "system") {
+                ss << "[unused9]系统:[unused10]";
+            }
+
+            if (role == "system") {
+                ss << "[unused9]系统:" << content << "[unused10]";
+            } else if (role == "user") {
+                ss << "[unused9]用户:" << content << "[unused10]";
+            } else if (role == "assistant") {
+                ss << "[unused9]助手:" << content << "[unused10]";
+            } else if (role == "tool") {
+                ss << "[unused9]工具:" << content << "[unused10]";
+            } else if (role == "function") {
+                ss << "[unused9]方法:" << content << "[unused10]";
+            }
+        }
+        if (add_ass) {
+            ss << "[unused9]助手:";
+        }
     } else {
         // template not supported
         return -1;
index da1b7c47997ca5af878de1a10b5214971419752e..684efb4d67f45b84f3cf6d96ac5909e22af29395 100644 (file)
@@ -53,6 +53,7 @@ enum llm_chat_template {
     LLM_CHAT_TEMPLATE_KIMI_K2,
     LLM_CHAT_TEMPLATE_SEED_OSS,
     LLM_CHAT_TEMPLATE_GROK_2,
+    LLM_CHAT_TEMPLATE_PANGU_EMBED,
     LLM_CHAT_TEMPLATE_UNKNOWN,
 };
 
index 896725466ce24649e024e3136715a3bf8affa119..1987135ca6a2e590ff4496a4e809e7c7579966d2 100644 (file)
@@ -2177,6 +2177,15 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_PANGU_EMBED:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                switch (hparams.n_layer) {
+                    case 26: type = LLM_TYPE_1B; break; // openPangu-Embedded-1B-V1.1
+                    case 34: type = LLM_TYPE_7B; break; // openPangu-Embedded-7B-V1.1
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+            } break;
         default: throw std::runtime_error("unsupported model architecture");
     }
 
@@ -6263,6 +6272,50 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         layer.visexp_ffn_up   = create_tensor(tn(LLM_TENSOR_VISEXP_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
                     }
                 } break;
+            case LLM_ARCH_PANGU_EMBED:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+
+                    // if output is NULL, init from the input tok embed
+                    if (output == NULL) {
+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        // weight tensors
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
+
+                        // bias tensors
+                        layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd_head_k * n_head}, 0);
+                        layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa}, 0);
+                        layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, 0);
+                        layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+                        if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) {
+                            layer.rope_long  = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG,  "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
+                            layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
+                        } else {
+                            layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
+                        }
+
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                    }
+                } break;
             default:
                 throw std::runtime_error("unknown architecture");
         }
@@ -7260,6 +7313,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
             {
                 llm = std::make_unique<llm_build_cogvlm>(*this, params);
             } break;
+        case LLM_ARCH_PANGU_EMBED:
+            {
+                llm = std::make_unique<llm_build_pangu_embedded>(*this, params);
+            }break;
         default:
             GGML_ABORT("fatal error");
     }
@@ -7479,6 +7536,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_APERTUS:
         case LLM_ARCH_MINIMAX_M2:
         case LLM_ARCH_COGVLM:
+        case LLM_ARCH_PANGU_EMBED:
             return LLAMA_ROPE_TYPE_NEOX;
 
         case LLM_ARCH_QWEN2VL:
index af203343a4d71ddd1f0d0706bb2254e59922ff59..2fffb382df2e5274e68518b2f6c50bcb4eae9b80 100644 (file)
@@ -361,6 +361,10 @@ struct llm_build_orion : public llm_graph_context {
     llm_build_orion(const llama_model & model, const llm_graph_params & params);
 };
 
+struct llm_build_pangu_embedded : public llm_graph_context {
+    llm_build_pangu_embedded(const llama_model & model, const llm_graph_params & params);
+};
+
 struct llm_build_phi2 : public llm_graph_context {
     llm_build_phi2(const llama_model & model, const llm_graph_params & params);
 };
diff --git a/src/models/pangu-embedded.cpp b/src/models/pangu-embedded.cpp
new file mode 100644 (file)
index 0000000..664572a
--- /dev/null
@@ -0,0 +1,121 @@
+#include "models.h"
+
+
+llm_build_pangu_embedded::llm_build_pangu_embedded(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+    const int64_t n_embd_head = hparams.n_embd_head_v;
+
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+    GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+    ggml_tensor * cur;
+    ggml_tensor * inpL;
+
+    inpL = build_inp_embd(model.tok_embd);
+
+    // inp_pos - contains the positions
+    ggml_tensor * inp_pos = build_inp_pos();
+
+    auto * inp_attn = build_attn_inp_kv();
+
+    ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+    for (int il = 0; il < n_layer; ++il) {
+        ggml_tensor * inpSA = inpL;
+
+        // norm
+        cur = build_norm(inpL,
+                model.layers[il].attn_norm, NULL,
+                LLM_NORM_RMS, il);
+        cb(cur, "attn_norm", il);
+
+        // self attention
+        {
+            // compute Q and K and RoPE them
+            ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+            Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+            cb(Qcur, "Qcur", il);
+
+            ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+            Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+            cb(Kcur, "Kcur", il);
+
+            ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+            Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+            cb(Vcur, "Vcur", il);
+
+            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+            Qcur = ggml_rope_ext(
+                    ctx0, Qcur, inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                    );
+
+            Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                    );
+
+            cb(Qcur, "Qcur", il);
+            cb(Kcur, "Kcur", il);
+            cb(Vcur, "Vcur", il);
+
+            cur = build_attn(inp_attn,
+                    model.layers[il].wo, model.layers[il].bo,
+                    Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+        }
+
+        if (il == n_layer - 1 && inp_out_ids) {
+            cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+            inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+        }
+
+        ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+        cb(ffn_inp, "ffn_inp", il);
+
+        // feed-forward network
+        cur = build_norm(ffn_inp,
+                model.layers[il].ffn_norm, NULL,
+                LLM_NORM_RMS, il);
+        cb(cur, "ffn_norm", il);
+
+        cur = build_ffn(cur,
+                model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
+                model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                NULL,
+                LLM_FFN_SILU, LLM_FFN_PAR, il);
+
+        cur = ggml_add(ctx0, cur, ffn_inp);
+        cb(cur, "ffn_out", il);
+
+        cur = build_cvec(cur, il);
+        cb(cur, "l_out", il);
+
+        // input for next layer
+        inpL = cur;
+    }
+
+    cur = inpL;
+
+    cur = build_norm(cur,
+            model.output_norm, NULL,
+            LLM_NORM_RMS, -1);
+
+    cb(cur, "result_norm", -1);
+    res->t_embd = cur;
+
+    // lm_head
+    cur = build_lora_mm(model.output, cur);
+
+    if (model.output_b != nullptr) {
+        cur = ggml_add(ctx0, cur, model.output_b);
+    }
+
+    cb(cur, "result_output", -1);
+    res->t_logits = cur;
+
+    ggml_build_forward_expand(gf, cur);
+}