]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
model : add support for ERNIE 4.5 0.3B model (#14408)
authorWeizhao Ouyang <redacted>
Sat, 28 Jun 2025 14:08:21 +0000 (22:08 +0800)
committerGitHub <redacted>
Sat, 28 Jun 2025 14:08:21 +0000 (16:08 +0200)
Add Day-0 support for Baidu ERNIE 4.5 0.3B model.

Signed-off-by: Weizhao Ouyang <redacted>
convert_hf_to_gguf.py
gguf-py/gguf/constants.py
src/llama-arch.cpp
src/llama-arch.h
src/llama-model.cpp
src/llama-model.h

index aed595e259ed514bc49c6c0663aefc0cb4abd36b..c2c55166e7641bea7c0c321f61c4223a736081b0 100755 (executable)
@@ -2743,6 +2743,52 @@ class Qwen2Model(TextModel):
         yield from super().modify_tensors(data_torch, name, bid)
 
 
+@ModelBase.register("Ernie4_5_ForCausalLM")
+class Ernie4_5Model(TextModel):
+    model_arch = gguf.MODEL_ARCH.ERNIE4_5
+
+    def set_vocab(self):
+        self._set_vocab_sentencepiece()
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        num_heads = self.hparams["num_attention_heads"]
+        num_kv_heads = self.hparams["num_key_value_heads"]
+        head_dim = self.hparams["head_dim"]
+
+        if "ernie." in name:
+            name = name.replace("ernie.", "model.")
+        # split the qkv weights
+        # qkv_proj shape: [(num_heads + 2 * num_kv_heads) * head_dim, hidden_size]
+        if "qkv_proj" in name:
+            name_q = name.replace("qkv_proj.weight", "q_proj.weight")
+            name_k = name.replace("qkv_proj.weight", "k_proj.weight")
+            name_v = name.replace("qkv_proj.weight", "v_proj.weight")
+            total_q_dim = num_heads * head_dim
+            total_k_dim = num_kv_heads * head_dim
+            total_v_dim = num_kv_heads * head_dim
+            q_proj_weight, k_proj_weight, v_proj_weight = data_torch.split([total_q_dim, total_k_dim, total_v_dim], dim=0)
+            return [
+                (self.map_tensor_name(name_q), q_proj_weight),
+                (self.map_tensor_name(name_k), k_proj_weight),
+                (self.map_tensor_name(name_v), v_proj_weight)
+            ]
+        # split the up_gate_proj into gate and up
+        # up_gate_proj shape: [2 * intermediate_size, hidden_size]
+        if "up_gate_proj" in name:
+            name_up = name.replace("up_gate_proj.weight", "up_proj.weight")
+            name_gate = name.replace("up_gate_proj.weight", "gate_proj.weight")
+            dim_half = data_torch.shape[0] // 2
+            gate_proj_weight, up_proj_weight = data_torch.split(dim_half, dim=0)
+            return [
+                (self.map_tensor_name(name_gate), gate_proj_weight),
+                (self.map_tensor_name(name_up), up_proj_weight)
+            ]
+        return [(self.map_tensor_name(name), data_torch)]
+
+
 @ModelBase.register(
     "Qwen2VLModel",
     "Qwen2VLForConditionalGeneration",
index fb75143b0b54586efb100af1add9568986a5dd49..b5ba933cb0c61c4c677c504c5833ae3adf6743a4 100644 (file)
@@ -354,6 +354,7 @@ class MODEL_ARCH(IntEnum):
     BAILINGMOE       = auto()
     DOTS1            = auto()
     ARCEE            = auto()
+    ERNIE4_5         = auto()
 
 
 class VISION_PROJECTOR_TYPE(IntEnum):
@@ -654,6 +655,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.BAILINGMOE:       "bailingmoe",
     MODEL_ARCH.DOTS1:            "dots1",
     MODEL_ARCH.ARCEE:            "arcee",
+    MODEL_ARCH.ERNIE4_5:         "ernie4_5",
 }
 
 VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
@@ -2177,6 +2179,20 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.ERNIE4_5: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
     # TODO
 }
 
index 435e3b9ba3db8b50521cea4a384b137b014d3e7c..aa21108a4bd79009723842c43d8324b4b0ff07f7 100644 (file)
@@ -76,6 +76,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_BAILINGMOE,       "bailingmoe"       },
     { LLM_ARCH_DOTS1,            "dots1"            },
     { LLM_ARCH_ARCEE,            "arcee"            },
+    { LLM_ARCH_ERNIE4_5,         "ernie4_5"         },
     { LLM_ARCH_UNKNOWN,          "(unknown)"        },
 };
 
@@ -1658,6 +1659,23 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_EXP_PROBS_B,    "blk.%d.exp_probs_b" },
         }
     },
+    {
+        LLM_ARCH_ERNIE4_5,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,         "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,        "output_norm" },
+            { LLM_TENSOR_OUTPUT,             "output" },
+            { LLM_TENSOR_ATTN_NORM,          "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,             "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,             "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,             "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,           "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM,           "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE,           "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,           "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,             "blk.%d.ffn_up" },
+        },
+    },
     {
         LLM_ARCH_UNKNOWN,
         {
index 9181ad053f6b3e61307a096aace4990b870f1699..0771ec3ebadcdc970f2ff3e1c08afbe83d9f1319 100644 (file)
@@ -80,6 +80,7 @@ enum llm_arch {
     LLM_ARCH_BAILINGMOE,
     LLM_ARCH_DOTS1,
     LLM_ARCH_ARCEE,
+    LLM_ARCH_ERNIE4_5,
     LLM_ARCH_UNKNOWN,
 };
 
index fc39195ed5177a14e7d7a67fedaa1f0c110e712e..b15bf73c2a29ad4160a777f818c2f044f698770e 100644 (file)
@@ -47,6 +47,7 @@ const char * llm_type_name(llm_type type) {
         case LLM_TYPE_475M:          return "475M";
         case LLM_TYPE_770M:          return "770M";
         case LLM_TYPE_780M:          return "780M";
+        case LLM_TYPE_0_3B:          return "0.3B";
         case LLM_TYPE_0_5B:          return "0.5B";
         case LLM_TYPE_0_6B:          return "0.6B";
         case LLM_TYPE_1B:            return "1B";
@@ -1504,6 +1505,14 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_ERNIE4_5:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                switch (hparams.n_layer) {
+                    case 18: type = LLM_TYPE_0_3B; break;
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+            } break;
         default: throw std::runtime_error("unsupported model architecture");
     }
 
@@ -4344,6 +4353,40 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
 
                         layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
 
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                    }
+                } break;
+            case LLM_ARCH_ERNIE4_5:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+                    // if output is NULL, init from the input tok embed
+                    if (output == NULL) {
+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
+
+                        // optional bias tensors
+                        layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd},     TENSOR_NOT_REQUIRED);
+                        layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);
+                        layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED);
+                        layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd},     TENSOR_NOT_REQUIRED);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
                         layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
                         layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
                     }
@@ -14125,6 +14168,136 @@ struct llm_build_dots1 : public llm_graph_context {
     }
 };
 
+struct llm_build_ernie4_5 : public llm_graph_context {
+    llm_build_ernie4_5(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified();
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            {
+                cur = build_norm(inpL,
+                        model.layers[il].attn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "attn_norm", il);
+            }
+
+            // self-attention
+            {
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, Qcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+
+                Kcur = ggml_rope_ext(
+                        ctx0, Kcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, NULL,
+                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            {
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
 struct llm_build_arcee : public llm_graph_context {
     llm_build_arcee(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
@@ -14635,6 +14808,10 @@ llm_graph_result_ptr llama_model::build_graph(
             {
                 llm = std::make_unique<llm_build_arcee>(*this, params, gf);
             } break;
+        case LLM_ARCH_ERNIE4_5:
+            {
+                llm = std::make_unique<llm_build_ernie4_5>(*this, params, gf);
+            } break;
         default:
             GGML_ABORT("fatal error");
     }
@@ -14786,6 +14963,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_BAILINGMOE:
         case LLM_ARCH_NEO_BERT:
         case LLM_ARCH_ARCEE:
+        case LLM_ARCH_ERNIE4_5:
             return LLAMA_ROPE_TYPE_NORM;
 
         // the pairs of head values are offset by n_rot/2
index 40063b790d434d19425a207ede1b9bde129bd214..a958c5997a11b8e3447635ab635c31af50149fd7 100644 (file)
@@ -39,6 +39,7 @@ enum llm_type {
     LLM_TYPE_475M,
     LLM_TYPE_770M,
     LLM_TYPE_780M,
+    LLM_TYPE_0_3B,
     LLM_TYPE_0_5B,
     LLM_TYPE_0_6B,
     LLM_TYPE_1B,