]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
model : add NeoBERT (#14164)
authorĐinh Trọng Huy <redacted>
Mon, 16 Jun 2025 12:53:41 +0000 (21:53 +0900)
committerGitHub <redacted>
Mon, 16 Jun 2025 12:53:41 +0000 (14:53 +0200)
* convert neobert model to gguf

* add inference graph

* fix flake8 lint

* followed reviewer suggestions

Co-authored-by: Georgi Gerganov <redacted>
* follow reviewers suggestions

Co-authored-by: Georgi Gerganov <redacted>
* override NeoBERT feed-forward length

---------

Co-authored-by: dinhhuy <redacted>
Co-authored-by: Georgi Gerganov <redacted>
convert_hf_to_gguf.py
gguf-py/gguf/constants.py
gguf-py/gguf/tensor_mapping.py
src/llama-arch.cpp
src/llama-arch.h
src/llama-model.cpp

index 2232a7d82349ea9a334e3199deab22479989330d..58e455ae645ede217f88b915c05e7931ed1ffa79 100755 (executable)
@@ -519,7 +519,7 @@ class TextModel(ModelBase):
     def set_gguf_parameters(self):
         self.gguf_writer.add_block_count(self.block_count)
 
-        if (n_ctx := self.find_hparam(["max_position_embeddings", "n_ctx", "n_positions"], optional=True)) is not None:
+        if (n_ctx := self.find_hparam(["max_position_embeddings", "n_ctx", "n_positions", "max_length"], optional=True)) is not None:
             self.gguf_writer.add_context_length(n_ctx)
             logger.info(f"gguf: context length = {n_ctx}")
 
@@ -4076,6 +4076,34 @@ class NomicBertModel(BertModel):
         raise ValueError(f"unknown tokenizer: {toktyp}")
 
 
+@ModelBase.register("NeoBERT", "NeoBERTLMHead", "NeoBERTForSequenceClassification")
+class NeoBert(BertModel):
+    model_arch = gguf.MODEL_ARCH.NEO_BERT
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+
+        # NeoBERT uses 2/3 of the intermediate size as feed forward length
+        self.gguf_writer.add_feed_forward_length(int(2 * self.hparams["intermediate_size"] / 3))
+        self.gguf_writer.add_rope_freq_base(10000.0)  # default value for NeoBERT
+        self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
+
+        f_rms_eps = self.hparams.get("norm_eps", 1e-6)  # default value for NeoBERT
+        self.gguf_writer.add_layer_norm_rms_eps(f_rms_eps)
+        logger.info(f"gguf: rms norm epsilon = {f_rms_eps}")
+
+        self.gguf_writer.add_pooling_type(gguf.PoolingType.CLS) # https://huggingface.co/chandar-lab/NeoBERT#how-to-use
+
+    def modify_tensors(self, data_torch, name, bid):
+        if name.startswith("decoder."):
+            return []
+
+        if name.startswith("model."):
+            name = name[6:]
+
+        return super().modify_tensors(data_torch, name, bid)
+
+
 @ModelBase.register("XLMRobertaModel", "XLMRobertaForSequenceClassification")
 class XLMRobertaModel(BertModel):
     model_arch = gguf.MODEL_ARCH.BERT
index 9b2143c7c2eaa40e1ed3232056fbce57c0b1ae15..834a1d5e1a97ed98f01196e0a0542afa95739b8c 100644 (file)
@@ -291,6 +291,7 @@ class MODEL_ARCH(IntEnum):
     BERT             = auto()
     NOMIC_BERT       = auto()
     NOMIC_BERT_MOE   = auto()
+    NEO_BERT         = auto()
     JINA_BERT_V2     = auto()
     BLOOM            = auto()
     STABLELM         = auto()
@@ -573,6 +574,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.BERT:             "bert",
     MODEL_ARCH.NOMIC_BERT:       "nomic-bert",
     MODEL_ARCH.NOMIC_BERT_MOE:   "nomic-bert-moe",
+    MODEL_ARCH.NEO_BERT:         "neo-bert",
     MODEL_ARCH.JINA_BERT_V2:     "jina-bert-v2",
     MODEL_ARCH.BLOOM:            "bloom",
     MODEL_ARCH.STABLELM:         "stablelm",
@@ -1081,6 +1083,18 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.FFN_UP_EXP,
         MODEL_TENSOR.LAYER_OUT_NORM,
     ],
+    MODEL_ARCH.NEO_BERT: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_QKV,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.ENC_OUTPUT_NORM,
+        MODEL_TENSOR.CLS,
+        MODEL_TENSOR.CLS_OUT,
+    ],
     MODEL_ARCH.JINA_BERT_V2: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.TOKEN_EMBD_NORM,
index 5e3f01754bf0794f9b7b7f79b127aa066ef95d43..79f044d2a5945236b613ac3733ccc2b60ebd9ba4 100644 (file)
@@ -31,6 +31,7 @@ class TensorNameMap:
             "model.embeddings",                          # rwkv7
             "model.word_embeddings",                     # bailingmoe
             "language_model.model.embed_tokens",         # llama4
+            "encoder",                                   # neobert
         ),
 
         # Token type embeddings
@@ -134,6 +135,7 @@ class TensorNameMap:
             "rwkv.blocks.{bid}.ln1",                                # rwkv6
             "model.layers.{bid}.ln1",                               # rwkv7
             "model.layers.{bid}.input_layernorm",                   # llama4
+            "transformer_encoder.{bid}.attention_norm",             # neobert
         ),
 
         # Attention norm 2
@@ -161,6 +163,7 @@ class TensorNameMap:
             "model.layers.{bid}.self_attn.qkv_proj",                               # phi3
             "encoder.layers.{bid}.self_attention.query_key_value",                 # chatglm
             "transformer.layers.{bid}.attn.qkv_proj",                              # openelm
+            "transformer_encoder.{bid}.qkv",                                       # neobert
         ),
 
         # Attention query
@@ -236,6 +239,7 @@ class TensorNameMap:
             "transformer.layers.{bid}.attn.out_proj",                       # openelm
             "transformer.h.{bid}.attn.attention.out_proj",                  # exaone
             "model.layers.{bid}.self_attn.o_proj",                          # llama4
+            "transformer_encoder.{bid}.wo",                                 # neobert
         ),
 
         # Attention output norm
@@ -276,6 +280,7 @@ class TensorNameMap:
             "encoder.layers.{bid}.post_attention_layernorm",                 # chatglm
             "transformer.layers.{bid}.ffn_norm",                             # openelm
             "model.layers.{bid}.post_attention_layernorm",                   # llama4
+            "transformer_encoder.{bid}.ffn_norm",                            # neobert
         ),
 
         # Post feed-forward norm
@@ -340,6 +345,7 @@ class TensorNameMap:
             "encoder.layers.{bid}.mlp.dense_h_to_4h",                 # chatglm
             "transformer.h.{bid}.mlp.c_fc_1",                         # exaone
             "model.layers.{bid}.feed_forward.up_proj",                # llama4
+            "transformer_encoder.{bid}.ffn.w12",                      # neobert
         ),
 
         MODEL_TENSOR.FFN_UP_EXP: (
@@ -422,6 +428,7 @@ class TensorNameMap:
             "encoder.layers.{bid}.mlp.dense_4h_to_h",                 # chatglm
             "model.layers.h.{bid}.mlp.c_proj",                        # exaone
             "model.layers.{bid}.feed_forward.down_proj",              # llama4
+            "transformer_encoder.{bid}.ffn.w3",                       # neobert
         ),
 
         MODEL_TENSOR.FFN_DOWN_EXP: (
@@ -832,12 +839,14 @@ class TensorNameMap:
         # TODO: these do not belong to block_mappings_cfg - move them to mappings_cfg
         MODEL_TENSOR.ENC_OUTPUT_NORM: (
             "encoder.final_layer_norm", # t5
+            "layer_norm",               # neobert
         ),
 
         MODEL_TENSOR.CLS: (
             "classifier",       # jina
             "classifier.dense", # roberta
             "pre_classifier",   # distillbert
+            "dense",            # neobert
         ),
 
         MODEL_TENSOR.CLS_OUT: (
index a3e7c861ca02f45e36195389f9bbb5fe58c13927..de8d289cf967e989cbd3fcca85639b1b11c4ba32 100644 (file)
@@ -20,6 +20,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_BERT,             "bert"             },
     { LLM_ARCH_NOMIC_BERT,       "nomic-bert"       },
     { LLM_ARCH_NOMIC_BERT_MOE,   "nomic-bert-moe"   },
+    { LLM_ARCH_NEO_BERT,         "neo-bert"         },
     { LLM_ARCH_JINA_BERT_V2,     "jina-bert-v2"     },
     { LLM_ARCH_BLOOM,            "bloom"            },
     { LLM_ARCH_STABLELM,         "stablelm"         },
@@ -514,6 +515,21 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP_EXPS,     "blk.%d.ffn_up_exps" },
         },
     },
+    {
+        LLM_ARCH_NEO_BERT,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_QKV,        "blk.%d.attn_qkv" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+            { LLM_TENSOR_ENC_OUTPUT_NORM, "enc.output_norm" },
+            { LLM_TENSOR_CLS,             "cls" },
+            { LLM_TENSOR_CLS_OUT,         "cls.output" },
+        },
+    },
     {
         LLM_ARCH_JINA_BERT_V2,
         {
index 168fdcb401cfda294d6a497c06ea2557f4d530d7..3e8a61da3c13e38fc1711e003d2f2b6ba3f59393 100644 (file)
@@ -24,6 +24,7 @@ enum llm_arch {
     LLM_ARCH_BERT,
     LLM_ARCH_NOMIC_BERT,
     LLM_ARCH_NOMIC_BERT_MOE,
+    LLM_ARCH_NEO_BERT,
     LLM_ARCH_JINA_BERT_V2,
     LLM_ARCH_BLOOM,
     LLM_ARCH_STABLELM,
index dcc8b0be725633f740e3f58d8dbd8b4dc616b605..a5eb122f998d85cbcd309953dd71518271578ceb 100644 (file)
@@ -749,6 +749,16 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     }
                 }
             } break;
+        case LLM_ARCH_NEO_BERT:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_ATTENTION_CAUSAL,            hparams.causal_attn);
+                ml.get_key(LLM_KV_POOLING_TYPE,                hparams.pooling_type);
+
+                if (hparams.n_layer == 28) {
+                    type = LLM_TYPE_250M;
+                }
+            } break;
         case LLM_ARCH_BLOOM:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@@ -2212,6 +2222,32 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i),   {n_embd}, 0);
                     }
                 } break;
+            case LLM_ARCH_NEO_BERT:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), {n_embd, n_vocab}, 0);
+
+                    cls   = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED);
+                    cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"),   {n_embd},         TENSOR_NOT_REQUIRED);
+
+                    cls_out   = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED);
+                    cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"),   {hparams.n_cls_out},         TENSOR_NOT_REQUIRED);
+
+                    output_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
+                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff*2}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
+                    }
+                } break;
             case LLM_ARCH_JINA_BERT_V2:
                 {
                     tok_embd  = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), {n_embd, n_vocab}, 0); // word_embeddings
@@ -6182,6 +6218,117 @@ struct llm_build_bert : public llm_graph_context {
     }
 };
 
+struct llm_build_neo_bert : public llm_graph_context {
+    llm_build_neo_bert(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        // construct input embeddings (token, type, position)
+        inpL = build_inp_embd(model.tok_embd);
+        cb(inpL, "inp_embd", -1);
+
+        auto * inp_attn = build_attn_inp_no_cache();
+
+        // iterate layers
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * cur = inpL;
+
+            ggml_tensor * Qcur;
+            ggml_tensor * Kcur;
+            ggml_tensor * Vcur;
+
+            // pre-norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+
+            // self-attention
+            cur = build_lora_mm(model.layers[il].wqkv, cur);
+            cb(cur, "wqkv", il);
+
+            Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+            Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+            Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+
+            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+            // RoPE
+            Qcur = ggml_rope_ext(
+                    ctx0, Qcur, inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                    );
+
+            Kcur = ggml_rope_ext(
+                    ctx0, Kcur, inp_pos, nullptr,
+                    n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                    );
+
+            cb(Qcur, "Qcur", il);
+            cb(Kcur, "Kcur", il);
+            cb(Vcur, "Vcur", il);
+
+            cur = build_attn(inp_attn, gf,
+                    model.layers[il].wo, nullptr,
+                    Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            cb(cur, "kqv_out", il);
+
+            if (il == n_layer - 1 && pooling_type == LLAMA_POOLING_TYPE_NONE) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
+                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+            }
+
+            // re-add the layer input
+            cur = ggml_add(ctx0, cur, inpL);
+
+            ggml_tensor * ffn_inp = cur;
+            cb(ffn_inp, "ffn_inp", il);
+
+            // pre-norm
+            cur = build_norm(ffn_inp,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            // feed-forward network
+            cur = build_ffn(cur,
+                    model.layers[il].ffn_up,
+                    NULL, NULL, NULL, NULL, NULL,
+                    model.layers[il].ffn_down,
+                    NULL, NULL, NULL,
+                    LLM_FFN_SWIGLU, LLM_FFN_SEQ, il);
+
+            // attentions bypass the intermediate layer
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm_enc, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_embd", -1);
+        res->t_embd = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
 struct llm_build_bloom : public llm_graph_context {
     llm_build_bloom(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
@@ -13595,6 +13742,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
         case LLM_ARCH_JINA_BERT_V2:
         case LLM_ARCH_NOMIC_BERT:
         case LLM_ARCH_NOMIC_BERT_MOE:
+        case LLM_ARCH_NEO_BERT:
         case LLM_ARCH_WAVTOKENIZER_DEC:
             {
                 res = nullptr;
@@ -13703,6 +13851,10 @@ llm_graph_result_ptr llama_model::build_graph(
             {
                 llm = std::make_unique<llm_build_bert>(*this, params, gf);
             } break;
+        case LLM_ARCH_NEO_BERT:
+            {
+                llm = std::make_unique<llm_build_neo_bert>(*this, params, gf);
+            } break;
         case LLM_ARCH_BLOOM:
             {
                 llm = std::make_unique<llm_build_bloom>(*this, params, gf);
@@ -14082,6 +14234,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_GRANITE_MOE:
         case LLM_ARCH_CHAMELEON:
         case LLM_ARCH_BAILINGMOE:
+        case LLM_ARCH_NEO_BERT:
         case LLM_ARCH_ARCEE:
             return LLAMA_ROPE_TYPE_NORM;