def set_gguf_parameters(self):
self.gguf_writer.add_block_count(self.block_count)
- if (n_ctx := self.find_hparam(["max_position_embeddings", "n_ctx", "n_positions"], optional=True)) is not None:
+ if (n_ctx := self.find_hparam(["max_position_embeddings", "n_ctx", "n_positions", "max_length"], optional=True)) is not None:
self.gguf_writer.add_context_length(n_ctx)
logger.info(f"gguf: context length = {n_ctx}")
raise ValueError(f"unknown tokenizer: {toktyp}")
+@ModelBase.register("NeoBERT", "NeoBERTLMHead", "NeoBERTForSequenceClassification")
+class NeoBert(BertModel):
+ model_arch = gguf.MODEL_ARCH.NEO_BERT
+
+ def set_gguf_parameters(self):
+ super().set_gguf_parameters()
+
+ # NeoBERT uses 2/3 of the intermediate size as feed forward length
+ self.gguf_writer.add_feed_forward_length(int(2 * self.hparams["intermediate_size"] / 3))
+ self.gguf_writer.add_rope_freq_base(10000.0) # default value for NeoBERT
+ self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
+
+ f_rms_eps = self.hparams.get("norm_eps", 1e-6) # default value for NeoBERT
+ self.gguf_writer.add_layer_norm_rms_eps(f_rms_eps)
+ logger.info(f"gguf: rms norm epsilon = {f_rms_eps}")
+
+ self.gguf_writer.add_pooling_type(gguf.PoolingType.CLS) # https://huggingface.co/chandar-lab/NeoBERT#how-to-use
+
+ def modify_tensors(self, data_torch, name, bid):
+ if name.startswith("decoder."):
+ return []
+
+ if name.startswith("model."):
+ name = name[6:]
+
+ return super().modify_tensors(data_torch, name, bid)
+
+
@ModelBase.register("XLMRobertaModel", "XLMRobertaForSequenceClassification")
class XLMRobertaModel(BertModel):
model_arch = gguf.MODEL_ARCH.BERT
"model.embeddings", # rwkv7
"model.word_embeddings", # bailingmoe
"language_model.model.embed_tokens", # llama4
+ "encoder", # neobert
),
# Token type embeddings
"rwkv.blocks.{bid}.ln1", # rwkv6
"model.layers.{bid}.ln1", # rwkv7
"model.layers.{bid}.input_layernorm", # llama4
+ "transformer_encoder.{bid}.attention_norm", # neobert
),
# Attention norm 2
"model.layers.{bid}.self_attn.qkv_proj", # phi3
"encoder.layers.{bid}.self_attention.query_key_value", # chatglm
"transformer.layers.{bid}.attn.qkv_proj", # openelm
+ "transformer_encoder.{bid}.qkv", # neobert
),
# Attention query
"transformer.layers.{bid}.attn.out_proj", # openelm
"transformer.h.{bid}.attn.attention.out_proj", # exaone
"model.layers.{bid}.self_attn.o_proj", # llama4
+ "transformer_encoder.{bid}.wo", # neobert
),
# Attention output norm
"encoder.layers.{bid}.post_attention_layernorm", # chatglm
"transformer.layers.{bid}.ffn_norm", # openelm
"model.layers.{bid}.post_attention_layernorm", # llama4
+ "transformer_encoder.{bid}.ffn_norm", # neobert
),
# Post feed-forward norm
"encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm
"transformer.h.{bid}.mlp.c_fc_1", # exaone
"model.layers.{bid}.feed_forward.up_proj", # llama4
+ "transformer_encoder.{bid}.ffn.w12", # neobert
),
MODEL_TENSOR.FFN_UP_EXP: (
"encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm
"model.layers.h.{bid}.mlp.c_proj", # exaone
"model.layers.{bid}.feed_forward.down_proj", # llama4
+ "transformer_encoder.{bid}.ffn.w3", # neobert
),
MODEL_TENSOR.FFN_DOWN_EXP: (
# TODO: these do not belong to block_mappings_cfg - move them to mappings_cfg
MODEL_TENSOR.ENC_OUTPUT_NORM: (
"encoder.final_layer_norm", # t5
+ "layer_norm", # neobert
),
MODEL_TENSOR.CLS: (
"classifier", # jina
"classifier.dense", # roberta
"pre_classifier", # distillbert
+ "dense", # neobert
),
MODEL_TENSOR.CLS_OUT: (
}
}
} break;
+ case LLM_ARCH_NEO_BERT:
+ {
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+ ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn);
+ ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type);
+
+ if (hparams.n_layer == 28) {
+ type = LLM_TYPE_250M;
+ }
+ } break;
case LLM_ARCH_BLOOM:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i), {n_embd}, 0);
}
} break;
+ case LLM_ARCH_NEO_BERT:
+ {
+ tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+ cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED);
+ cls_b = create_tensor(tn(LLM_TENSOR_CLS, "bias"), {n_embd}, TENSOR_NOT_REQUIRED);
+
+ cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED);
+ cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED);
+
+ output_norm_enc = create_tensor(tn(LLM_TENSOR_ENC_OUTPUT_NORM, "weight"), {n_embd}, 0);
+
+ for (int i = 0; i < n_layer; ++i) {
+ auto & layer = layers[i];
+
+ layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+ layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
+ layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0);
+
+ layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+ layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff*2}, 0);
+ layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
+ }
+ } break;
case LLM_ARCH_JINA_BERT_V2:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); // word_embeddings
}
};
+struct llm_build_neo_bert : public llm_graph_context {
+ llm_build_neo_bert(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+ const int64_t n_embd_head = hparams.n_embd_head_v;
+ const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
+
+ GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+ ggml_tensor * cur;
+ ggml_tensor * inpL;
+ ggml_tensor * inp_pos = build_inp_pos();
+
+ // construct input embeddings (token, type, position)
+ inpL = build_inp_embd(model.tok_embd);
+ cb(inpL, "inp_embd", -1);
+
+ auto * inp_attn = build_attn_inp_no_cache();
+
+ // iterate layers
+ for (int il = 0; il < n_layer; ++il) {
+ ggml_tensor * cur = inpL;
+
+ ggml_tensor * Qcur;
+ ggml_tensor * Kcur;
+ ggml_tensor * Vcur;
+
+ // pre-norm
+ cur = build_norm(inpL,
+ model.layers[il].attn_norm, NULL,
+ LLM_NORM_RMS, il);
+
+ // self-attention
+ cur = build_lora_mm(model.layers[il].wqkv, cur);
+ cb(cur, "wqkv", il);
+
+ Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+ Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+ Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+ Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+ // RoPE
+ Qcur = ggml_rope_ext(
+ ctx0, Qcur, inp_pos, nullptr,
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow
+ );
+
+ Kcur = ggml_rope_ext(
+ ctx0, Kcur, inp_pos, nullptr,
+ n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow
+ );
+
+ cb(Qcur, "Qcur", il);
+ cb(Kcur, "Kcur", il);
+ cb(Vcur, "Vcur", il);
+
+ cur = build_attn(inp_attn, gf,
+ model.layers[il].wo, nullptr,
+ Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+ cb(cur, "kqv_out", il);
+
+ if (il == n_layer - 1 && pooling_type == LLAMA_POOLING_TYPE_NONE) {
+ // skip computing output for unused tokens
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+ }
+
+ // re-add the layer input
+ cur = ggml_add(ctx0, cur, inpL);
+
+ ggml_tensor * ffn_inp = cur;
+ cb(ffn_inp, "ffn_inp", il);
+
+ // pre-norm
+ cur = build_norm(ffn_inp,
+ model.layers[il].ffn_norm, NULL,
+ LLM_NORM_RMS, il);
+ cb(cur, "ffn_norm", il);
+
+ // feed-forward network
+ cur = build_ffn(cur,
+ model.layers[il].ffn_up,
+ NULL, NULL, NULL, NULL, NULL,
+ model.layers[il].ffn_down,
+ NULL, NULL, NULL,
+ LLM_FFN_SWIGLU, LLM_FFN_SEQ, il);
+
+ // attentions bypass the intermediate layer
+ cur = ggml_add(ctx0, cur, ffn_inp);
+
+ // input for next layer
+ inpL = cur;
+ }
+
+ cur = inpL;
+
+ cur = build_norm(cur,
+ model.output_norm_enc, NULL,
+ LLM_NORM_RMS, -1);
+
+ cb(cur, "result_embd", -1);
+ res->t_embd = cur;
+
+ ggml_build_forward_expand(gf, cur);
+ }
+};
+
struct llm_build_bloom : public llm_graph_context {
llm_build_bloom(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
case LLM_ARCH_JINA_BERT_V2:
case LLM_ARCH_NOMIC_BERT:
case LLM_ARCH_NOMIC_BERT_MOE:
+ case LLM_ARCH_NEO_BERT:
case LLM_ARCH_WAVTOKENIZER_DEC:
{
res = nullptr;
{
llm = std::make_unique<llm_build_bert>(*this, params, gf);
} break;
+ case LLM_ARCH_NEO_BERT:
+ {
+ llm = std::make_unique<llm_build_neo_bert>(*this, params, gf);
+ } break;
case LLM_ARCH_BLOOM:
{
llm = std::make_unique<llm_build_bloom>(*this, params, gf);
case LLM_ARCH_GRANITE_MOE:
case LLM_ARCH_CHAMELEON:
case LLM_ARCH_BAILINGMOE:
+ case LLM_ARCH_NEO_BERT:
case LLM_ARCH_ARCEE:
return LLAMA_ROPE_TYPE_NORM;