model : Granite Embedding support (#15641)
author Ryan Mangeno <redacted>
Mon, 22 Dec 2025 23:28:19 +0000 (18:28 -0500)
committer GitHub <redacted>
Mon, 22 Dec 2025 23:28:19 +0000 (00:28 +0100)
ModernBERT, but without `head.norm`, so any other ModernBERT models will currently fail to convert and run. PRs adding `head.norm` support are welcome!
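
For context, a quick (hypothetical, not part of this PR) pre-flight check for whether a Hugging Face ModernBERT checkpoint contains `head.norm` tensors, which this conversion does not handle yet:

    # hypothetical pre-flight check: list any `head.norm` tensors in a checkpoint
    # before attempting conversion (this downloads the model weights)
    from transformers import AutoModelForMaskedLM

    model = AutoModelForMaskedLM.from_pretrained("answerdotai/ModernBERT-base")
    head_norm = [k for k in model.state_dict() if k.startswith("head.norm")]
    print("head.norm tensors:", head_norm or "none (should convert)")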

* constants and tensor mappings for modern bert support, model not supported yet but working on getting conversion to work for encoder only

* conversion now working, hf -> gguf

* working on support, now working on building graph

* some cleanup

* cleanup

* continuing

* correct tensor shape for qkv

* fixed tensor mappings and working on building the graph

* tensor debugging now works (llama-eval-callback); instead of simulating the gate split with views, GEGLU is now used, which does exactly that

* cleanup

* cleanup

* cleanup

* more cleanup

* ubatch issues: the assert checking for equal seqs in llama-graph.cpp keeps failing when building attention; setting the ubatch size to 1 (running llama-embedding with --ubatch-size 1) makes it work, but this needs to be looked into more

* added cls token per previous modern bert attempt, still working on checking out the rest

* fixed pre tokenizer and still working through previous pr

* working through the previous attempt; implemented more accurate conversion per that attempt, and added local sliding-window attention that alternates every third layer

* fixed pre tokenizer

* working on swa with local and global alternating attention

* some cleanup and now fails on build attn

* starting to work, and some cleanup, currently failing on last layer construction in graph build

* alternating rope implemented and modern bert graph build succeeds

* fixed assert for equal ubatch seqs

* cleanup

* added mask check in vocab

* fixed alternating rope: hparams.rope_freq_base_train and hparams.rope_freq_base_train_swa were the same, so I set them to the correct values

* reuse variable

* removed repeat

* the standard swa method can be used instead of adding a new LLAMA_SWA_TYPE_LOCAL enum value

* correct swa layer indexing; it is supposed to be 0, 3, 6, ... instead of 1, 4, 7, ...

* more modular hparam setting

* replaced attn out norm with ffn_norm; cosine similarity between HF embeddings and llama.cpp embeddings went way up, from 0.05 to 0.24 (see the similarity-check sketch after the commit trailers below); replaced the cacheless kv with the swa todo per the previous conversion

* Update gguf-py/gguf/tensor_mapping.py

Co-authored-by: Sigbjørn Skjæret <redacted>
* Update convert_hf_to_gguf_update.py

Co-authored-by: Sigbjørn Skjæret <redacted>
* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <redacted>
* Update src/llama-vocab.cpp

Co-authored-by: Sigbjørn Skjæret <redacted>
* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <redacted>
* Update gguf-py/gguf/tensor_mapping.py

Co-authored-by: Sigbjørn Skjæret <redacted>
* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <redacted>
* Update gguf-py/gguf/tensor_mapping.py

Co-authored-by: Sigbjørn Skjæret <redacted>
* Update gguf-py/gguf/tensor_mapping.py

Co-authored-by: Sigbjørn Skjæret <redacted>
* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <redacted>
* Update gguf-py/gguf/tensor_mapping.py

Co-authored-by: Sigbjørn Skjæret <redacted>
* Update gguf-py/gguf/tensor_mapping.py

Co-authored-by: Sigbjørn Skjæret <redacted>
* Update gguf-py/gguf/tensor_mapping.py

Co-authored-by: Sigbjørn Skjæret <redacted>
* Update gguf-py/gguf/tensor_mapping.py

Co-authored-by: Sigbjørn Skjæret <redacted>
* Update gguf-py/gguf/tensor_mapping.py

Co-authored-by: Sigbjørn Skjæret <redacted>
* Update gguf-py/gguf/tensor_mapping.py

Co-authored-by: Sigbjørn Skjæret <redacted>
* Update src/llama-graph.cpp

Co-authored-by: Sigbjørn Skjæret <redacted>
* Update src/llama-arch.cpp

Co-authored-by: Sigbjørn Skjæret <redacted>
* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <redacted>
* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <redacted>
* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <redacted>
* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <redacted>
* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <redacted>
* removed redundant hparam set

* enums for model sizes

* conversion for modern-bert models now supported rather than just granite-small

* Update src/llama-model.cpp

Co-authored-by: Gabe Goodhart <redacted>
* Update src/llama-model.cpp

Co-authored-by: Gabe Goodhart <redacted>
* fixed ordering of enum for freq_base_swa

* fixed where I added the residual, now gives much, much better embeddings

* re-added cacheless logic

* removing whitespace

* conversion now working for swa pattern - dense every n layers

* modern bert put into separate src file

* removing whitespace

* fixed whitespace and newline errors in editorconfig job

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <redacted>
* better naming convention, n_swa_pattern -> swa_period

* reusing sliding_window_pattern key rather than making new dense_every_n_layers key, and adding writing and reading support

* fixing pyright type-check fail

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <redacted>
* Update gguf-py/gguf/gguf_writer.py

Co-authored-by: Sigbjørn Skjæret <redacted>
* Update src/llama-hparams.h

Co-authored-by: Sigbjørn Skjæret <redacted>
* Update src/llama-model-saver.cpp

Co-authored-by: Sigbjørn Skjæret <redacted>
* Update src/models/modern-bert.cpp

Co-authored-by: Sigbjørn Skjæret <redacted>
* Update src/models/modern-bert.cpp

Co-authored-by: Sigbjørn Skjæret <redacted>
* Update src/models/modern-bert.cpp

Co-authored-by: Sigbjørn Skjæret <redacted>
* Update gguf-py/gguf/gguf_writer.py

Co-authored-by: Sigbjørn Skjæret <redacted>
* Update src/models/modern-bert.cpp

Co-authored-by: Sigbjørn Skjæret <redacted>
* Update src/models/modern-bert.cpp

Co-authored-by: Sigbjørn Skjæret <redacted>
* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <redacted>
* Update src/llama-model-loader.cpp

Co-authored-by: Sigbjørn Skjæret <redacted>
* Update src/llama-model-loader.cpp

Co-authored-by: Sigbjørn Skjæret <redacted>
* Update src/llama-model-loader.cpp

Co-authored-by: Sigbjørn Skjæret <redacted>
* added descriptions in llama-model

* fixed tensor mappings for conversion

* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <redacted>
* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <redacted>
* mapping name for size

* nits

* unused

---------

Co-authored-by: Sigbjørn Skjæret <redacted>
Co-authored-by: Gabe Goodhart <redacted>
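
For the cosine-similarity check mentioned in the commit notes above (HF embeddings vs. llama.cpp embeddings), a minimal sketch of how such a comparison can be done; the vectors here are illustrative stand-ins:

    import numpy as np

    def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
        # standard cosine similarity between two embedding vectors
        return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

    # toy vectors standing in for a HF embedding and a llama.cpp embedding
    hf_embd   = np.array([0.10, 0.30, -0.20, 0.70])
    gguf_embd = np.array([0.10, 0.28, -0.25, 0.72])
    print(f"cosine similarity: {cosine_similarity(hf_embd, gguf_embd):.4f}")
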
15 files changed:
convert_hf_to_gguf.py
convert_hf_to_gguf_update.py
gguf-py/gguf/constants.py
gguf-py/gguf/gguf_writer.py
gguf-py/gguf/tensor_mapping.py
src/CMakeLists.txt
src/llama-arch.cpp
src/llama-arch.h
src/llama-model-loader.cpp
src/llama-model-loader.h
src/llama-model.cpp
src/llama-model.h
src/llama-vocab.cpp
src/models/models.h
src/models/modern-bert.cpp [new file with mode: 0644]

index 22f703e6ad40a7787479a57781da82a15b442620..16c5acf346d5d862ac54e3b5c1aa667dfcb4cf20 100755 (executable)
@@ -1212,6 +1212,9 @@ class TextModel(ModelBase):
         if chkhsh == "a1e163ecab2e718a4c829d1148b6e86824ec36163bb71941c3dca9cd5ac25756":
             # ref: https://huggingface.co/JetBrains/Mellum-4b-base
             res = "mellum"
+        if chkhsh == "a0b64b4385f123663873756336c085744376d015ff328bb1d901598f63c44152":
+            # ref: https://huggingface.co/answerdotai/ModernBERT-base
+            res = "modern-bert"
         if chkhsh == "49fc0303c9e0d2c2c565c510f64b2d9b271276acdcdadff733249eda9f7d59df":
             # ref: https://huggingface.co/arcee-ai/Trinity-Tokenizer
             res = "afmoe"
@@ -9999,6 +10002,36 @@ class SmallThinkerModel(TextModel):
                 raise ValueError(f"Unprocessed experts: {experts}")
 
 
+@ModelBase.register("ModernBertModel", "ModernBertForMaskedLM", "ModernBertForSequenceClassification")
+class ModernBertModel(BertModel):
+    model_arch = gguf.MODEL_ARCH.MODERN_BERT
+
+    def set_vocab(self):
+        self.gguf_writer.add_add_bos_token(True)
+        self.gguf_writer.add_add_eos_token(True)
+        self.gguf_writer.add_add_sep_token(True)
+        self._set_vocab_gpt2()
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_sliding_window(self.hparams["local_attention"])
+        if (sliding_window_pattern := self.hparams.get("global_attn_every_n_layers")) is not None:
+            self.gguf_writer.add_sliding_window_pattern(sliding_window_pattern)
+        self.gguf_writer.add_rope_freq_base_swa(self.rope_parameters.get("sliding_attention", {"rope_theta": self.hparams.get("local_rope_theta")})["rope_theta"])
+        self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
+        self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # these layers act as MLM head, so we don't need them
+        if name.startswith("decoder."):
+            return []
+
+        if name.startswith("model."):
+            name = name[6:]
+
+        return super().modify_tensors(data_torch, name, bid)
+
+
 @ModelBase.register("ApertusForCausalLM")
 class ApertusModel(LlamaModel):
     model_arch = gguf.MODEL_ARCH.APERTUS
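
The nested fallback used above for `add_rope_freq_base_swa` reads the sliding-attention rope theta from `rope_parameters["sliding_attention"]` if present, otherwise from `hparams["local_rope_theta"]`; an equivalent, more explicit sketch (plain dicts, illustrative values):

    def swa_rope_theta(rope_parameters: dict, hparams: dict):
        # explicit form of the nested lookup in add_rope_freq_base_swa above
        sliding = rope_parameters.get("sliding_attention")
        if sliding is not None:
            return sliding["rope_theta"]
        return hparams.get("local_rope_theta")

    # illustrative ModernBERT-style config with only a local_rope_theta entry
    print(swa_rope_theta({}, {"local_rope_theta": 10000.0}))  # -> 10000.0
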
index 5e8456a7ea51bba5dae43e010b6125bbc76769e1..4378378309f36776b1eb9c3456ae3d1fb588bdc6 100755 (executable)
@@ -139,6 +139,7 @@ models = [
     {"name": "lfm2",             "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LiquidAI/LFM2-Tokenizer"},
     {"name": "exaone4",          "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-4.0-32B", },
     {"name": "mellum",           "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/JetBrains/Mellum-4b-base", },
+    {"name": "modern-bert",      "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/answerdotai/ModernBERT-base", },
     {"name": "afmoe",            "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/arcee-ai/Trinity-Tokenizer", },
     {"name": "bailingmoe2",      "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/Ling-mini-base-2.0", },
     {"name": "granite-docling",  "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ibm-granite/granite-docling-258M", },
index cab8f2901ae60d0c62e40d806b95356d5505efb6..41d3bd4faf2389e1e2171f573feb118da041a3f7 100644 (file)
@@ -181,6 +181,7 @@ class Keys:
         DIMENSION_COUNT          = "{arch}.rope.dimension_count"
         DIMENSION_SECTIONS       = "{arch}.rope.dimension_sections"
         FREQ_BASE                = "{arch}.rope.freq_base"
+        FREQ_BASE_SWA            = "{arch}.rope.freq_base_swa"
         SCALING_TYPE             = "{arch}.rope.scaling.type"
         SCALING_FACTOR           = "{arch}.rope.scaling.factor"
         SCALING_ATTN_FACTOR      = "{arch}.rope.scaling.attn_factor"
@@ -354,6 +355,7 @@ class MODEL_ARCH(IntEnum):
     STARCODER        = auto()
     REFACT           = auto()
     BERT             = auto()
+    MODERN_BERT      = auto()
     NOMIC_BERT       = auto()
     NOMIC_BERT_MOE   = auto()
     NEO_BERT         = auto()
@@ -747,6 +749,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.STARCODER:        "starcoder",
     MODEL_ARCH.REFACT:           "refact",
     MODEL_ARCH.BERT:             "bert",
+    MODEL_ARCH.MODERN_BERT:      "modern-bert",
     MODEL_ARCH.NOMIC_BERT:       "nomic-bert",
     MODEL_ARCH.NOMIC_BERT_MOE:   "nomic-bert-moe",
     MODEL_ARCH.NEO_BERT:         "neo-bert",
@@ -1367,6 +1370,19 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.CLS,
         MODEL_TENSOR.CLS_OUT,
     ],
+    MODEL_ARCH.MODERN_BERT: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.TOKEN_EMBD_NORM,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_QKV,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.CLS,
+        MODEL_TENSOR.CLS_OUT,
+    ],
     MODEL_ARCH.NOMIC_BERT: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.TOKEN_EMBD_NORM,
index 9e6ff3ac7779c6751242c8c0d39fc53b4677fbfe..6a4a504f8dcb8deaa26e31b26b52076840d269d9 100644 (file)
@@ -774,8 +774,12 @@ class GGUFWriter:
     def add_shared_kv_layers(self, value: int) -> None:
         self.add_uint32(Keys.Attention.SHARED_KV_LAYERS.format(arch=self.arch), value)
 
-    def add_sliding_window_pattern(self, value: Sequence[bool]) -> None:
-        self.add_array(Keys.Attention.SLIDING_WINDOW_PATTERN.format(arch=self.arch), value)
+    def add_sliding_window_pattern(self, value: int | Sequence[bool]) -> None:
+        key = Keys.Attention.SLIDING_WINDOW_PATTERN.format(arch=self.arch)
+        if isinstance(value, int):
+            self.add_uint32(key, value)
+        else:
+            self.add_array(key, value)
 
     def add_dense_features_dims(self, dense:str, in_f:int, out_f:int) -> None:
         self.add_uint32(Keys.LLM.DENSE_FEAT_IN_SIZE.format(arch=self.arch, dense=dense), in_f)
@@ -886,6 +890,9 @@ class GGUFWriter:
     def add_value_residual_mix_lora_rank(self, length: int) -> None:
         self.add_uint32(Keys.Attention.VALUE_RESIDUAL_MIX_LORA_RANK.format(arch=self.arch), length)
 
+    def add_rope_freq_base_swa(self, value: float) -> None:
+        self.add_float32(Keys.Rope.FREQ_BASE_SWA.format(arch=self.arch), value)
+
     def add_gate_lora_rank(self, length: int) -> None:
         self.add_uint32(Keys.Attention.GATE_LORA_RANK.format(arch=self.arch), length)
 
index 301aafa9102821f714dc35eeb9f084332714a541..276720fcde9ea18578789c110002b03ac952d28d 100644 (file)
@@ -17,6 +17,7 @@ class TensorNameMap:
             "embed_tokens",                              # embeddinggemma
             "tok_embeddings",                            # llama-pth
             "embeddings.word_embeddings",                # bert nomic-bert
+            "embeddings.tok_embeddings",                 # modern-bert
             "language_model.embedding.word_embeddings",  # persimmon
             "wte",                                       # gpt2
             "transformer.embd.wte",                      # phi2
@@ -46,6 +47,7 @@ class TensorNameMap:
         MODEL_TENSOR.TOKEN_EMBD_NORM: (
             "word_embeddings_layernorm",  # bloom
             "embeddings.LayerNorm",       # bert
+            "embeddings.norm",            # modern-bert
             "emb_ln",                     # nomic-bert
             "transformer.norm",           # openelm
             "rwkv.blocks.0.pre_ln",       # rwkv
@@ -75,6 +77,7 @@ class TensorNameMap:
             "head.out",                  # wavtokenizer
             "lm_head",                   # llama4
             "model.transformer.ff_out",  # llada
+            "head.decoder",              # modern-bert
         ),
         MODEL_TENSOR.DENSE_2_OUT: (
             "dense_2_out",  # embeddinggemma
@@ -104,6 +107,7 @@ class TensorNameMap:
             "backbone.final_layer_norm",               # wavtokenizer
             "model.norm",                              # llama4
             "model.transformer.ln_f",                  # llada
+            "final_norm",                              # modern-bert
             "model.norm",                              # cogvlm
         ),
 
@@ -151,6 +155,7 @@ class TensorNameMap:
             "model.layers.{bid}.input_layernorm",                   # llama4
             "layers.{bid}.input_layernorm",                         # embeddinggemma
             "transformer_encoder.{bid}.attention_norm",             # neobert
+            "layers.{bid}.attn_norm",                               # modern-bert
             "model.layers.{bid}.operator_norm",                     # lfm2
             "model.transformer.blocks.{bid}.attn_norm",             # llada
             "layers.{bid}.input_layernorm",                         # qwen3-embedding
@@ -187,6 +192,7 @@ class TensorNameMap:
             "encoder.layers.{bid}.self_attention.query_key_value",                 # chatglm
             "transformer.layers.{bid}.attn.qkv_proj",                              # openelm
             "transformer_encoder.{bid}.qkv",                                       # neobert
+            "layers.{bid}.attn.Wqkv",                                              # modern-bert
             "model.layers.{bid}.self_attn.language_expert_query_key_value",        # cogvlm
         ),
 
@@ -261,6 +267,7 @@ class TensorNameMap:
             "model.layers.{bid}.self_attn.linear_attn",                     # deci
             "layers.{bid}.attention.wo",                                    # llama-pth
             "encoder.layer.{bid}.attention.output.dense",                   # bert
+            "layers.{bid}.attn.Wo",                                         # modern-bert
             "transformer.layer.{bid}.attention.out_lin",                    # distillbert
             "transformer.h.{bid}.attn.out_proj",                            # gpt-j
             "language_model.encoder.layers.{bid}.self_attention.dense",     # persimmon
@@ -344,6 +351,7 @@ class TensorNameMap:
             "layers.{bid}.post_attention_layernorm",                         # qwen3-embedding
             "model.layers.{bid}.feedforward_layernorm",                      # apertus
             "model.layers.{bid}.pre_mlp_layernorm",                          # kormo
+            "layers.{bid}.mlp_norm"                                          # modern-bert
         ),
 
         # Pre feed-forward norm
@@ -407,6 +415,7 @@ class TensorNameMap:
             "layers.{bid}.mlp.up_proj",                               # embeddinggemma
             "layers.{bid}.feed_forward.w3",                           # llama-pth
             "encoder.layer.{bid}.intermediate.dense",                 # bert
+            "layers.{bid}.mlp.Wi",                                    # modern-bert
             "transformer.layer.{bid}.ffn.lin1",                       # distillbert
             "transformer.h.{bid}.mlp.fc_in",                          # gpt-j
             "transformer.h.{bid}.mlp.linear_3",                       # refact
@@ -521,6 +530,7 @@ class TensorNameMap:
             "layers.{bid}.mlp.down_proj",                             # embeddinggemma
             "layers.{bid}.feed_forward.w2",                           # llama-pth
             "encoder.layer.{bid}.output.dense",                       # bert
+            "layers.{bid}.mlp.Wo",                                    # modern-bert
             "transformer.layer.{bid}.ffn.lin2",                       # distillbert
             "transformer.h.{bid}.mlp.fc_out",                         # gpt-j
             "language_model.encoder.layers.{bid}.mlp.dense_4h_to_h",  # persimmon
@@ -1122,6 +1132,7 @@ class TensorNameMap:
             "classifier.dense", # roberta
             "pre_classifier",   # distillbert
             "dense",            # neobert
+            "head.dense",       # modern-bert
         ),
 
         MODEL_TENSOR.CLS_OUT: (
index 4192af7c0c3b3f5b5948818540b9ee210c513ccb..4ca897491607a90eab62e05c2bf4b9d5ee47e8db 100644 (file)
@@ -90,6 +90,7 @@ add_library(llama
             models/mamba.cpp
             models/minicpm3.cpp
             models/minimax-m2.cpp
+            models/modern-bert.cpp
             models/mpt.cpp
             models/nemotron-h.cpp
             models/nemotron.cpp
index d0eaf317f77c88cf77f966a4ba0cb9f4b67202a7..80f44ae1bfe78f2c3461aed0eb23076f4dfaacba 100644 (file)
@@ -20,6 +20,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_STARCODER,        "starcoder"        },
     { LLM_ARCH_REFACT,           "refact"           },
     { LLM_ARCH_BERT,             "bert"             },
+    { LLM_ARCH_MODERN_BERT,      "modern-bert"      },
     { LLM_ARCH_NOMIC_BERT,       "nomic-bert"       },
     { LLM_ARCH_NOMIC_BERT_MOE,   "nomic-bert-moe"   },
     { LLM_ARCH_NEO_BERT,         "neo-bert"         },
@@ -204,6 +205,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_ATTENTION_GATE_LORA_RANK,               "%s.attention.gate_lora_rank"               },
     { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,       "%s.attention.relative_buckets_count"       },
     { LLM_KV_ATTENTION_SLIDING_WINDOW,               "%s.attention.sliding_window"               },
+    { LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN,       "%s.attention.sliding_window_pattern"       },
     { LLM_KV_ATTENTION_SCALE,                        "%s.attention.scale"                        },
     { LLM_KV_ATTENTION_OUTPUT_SCALE,                 "%s.attention.output_scale"                 },
     { LLM_KV_ATTENTION_TEMPERATURE_LENGTH,           "%s.attention.temperature_length"           },
@@ -214,6 +216,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_ROPE_DIMENSION_COUNT,          "%s.rope.dimension_count"                 },
     { LLM_KV_ROPE_DIMENSION_SECTIONS,       "%s.rope.dimension_sections"              },
     { LLM_KV_ROPE_FREQ_BASE,                "%s.rope.freq_base"                       },
+    { LLM_KV_ROPE_FREQ_BASE_SWA,            "%s.rope.freq_base_swa"                   },
     { LLM_KV_ROPE_SCALE_LINEAR,             "%s.rope.scale_linear"                    },
     { LLM_KV_ROPE_SCALING_TYPE,             "%s.rope.scaling.type"                    },
     { LLM_KV_ROPE_SCALING_FACTOR,           "%s.rope.scaling.factor"                  },
@@ -778,6 +781,20 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
                 LLM_TENSOR_CLS,
                 LLM_TENSOR_CLS_OUT,
             };
+        case LLM_ARCH_MODERN_BERT:
+            return {
+                LLM_TENSOR_TOKEN_EMBD,
+                LLM_TENSOR_TOKEN_EMBD_NORM,
+                LLM_TENSOR_OUTPUT_NORM,
+                LLM_TENSOR_ATTN_NORM,
+                LLM_TENSOR_ATTN_OUT,
+                LLM_TENSOR_ATTN_QKV,
+                LLM_TENSOR_FFN_DOWN,
+                LLM_TENSOR_FFN_UP,
+                LLM_TENSOR_FFN_NORM,
+                LLM_TENSOR_CLS,
+                LLM_TENSOR_CLS_OUT,
+            };
         case LLM_ARCH_JINA_BERT_V2:
             return {
                 LLM_TENSOR_TOKEN_EMBD,
index 6cbf9b1f8961988b2a01d20a5dd57730c82d04cd..a53bc39d1833914cf85985c99387b70efa3db9f6 100644 (file)
@@ -24,6 +24,7 @@ enum llm_arch {
     LLM_ARCH_STARCODER,
     LLM_ARCH_REFACT,
     LLM_ARCH_BERT,
+    LLM_ARCH_MODERN_BERT,
     LLM_ARCH_NOMIC_BERT,
     LLM_ARCH_NOMIC_BERT_MOE,
     LLM_ARCH_NEO_BERT,
@@ -208,6 +209,7 @@ enum llm_kv {
     LLM_KV_ATTENTION_GATE_LORA_RANK,
     LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
     LLM_KV_ATTENTION_SLIDING_WINDOW,
+    LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN,
     LLM_KV_ATTENTION_SCALE,
     LLM_KV_ATTENTION_OUTPUT_SCALE,
     LLM_KV_ATTENTION_TEMPERATURE_LENGTH,
@@ -218,6 +220,7 @@ enum llm_kv {
     LLM_KV_ROPE_DIMENSION_COUNT,
     LLM_KV_ROPE_DIMENSION_SECTIONS,
     LLM_KV_ROPE_FREQ_BASE,
+    LLM_KV_ROPE_FREQ_BASE_SWA,
     LLM_KV_ROPE_SCALE_LINEAR,
     LLM_KV_ROPE_SCALING_TYPE,
     LLM_KV_ROPE_SCALING_FACTOR,
index 33a76dba4017d93cadfed3193c0cbe4a626ff5b5..5003b4fbf5301b46d4d3a4fcc63b96bbb6874474 100644 (file)
@@ -462,6 +462,29 @@ namespace GGUFMeta {
         return get_key_or_arr(llm_kv(kid), result, n, required);
     }
 
+    bool llama_model_loader::get_key_or_arr(enum llm_kv kid, uint32_t & result, bool required) {
+        const std::string key = llm_kv(kid);
+
+        const int id = gguf_find_key(meta.get(), key.c_str());
+
+        if (id < 0) {
+            if (required) {
+                throw std::runtime_error(format("key not found in model: %s", key.c_str()));
+            }
+            return false;
+        }
+
+        // throw an error if the type is an array
+        if (gguf_get_kv_type(meta.get(), id) == GGUF_TYPE_ARRAY) {
+            if (required) {
+                throw std::runtime_error(format("expected scalar, found array for key: %s", key.c_str()));
+            }
+            return false;
+        }
+
+        return get_key(key, result, required);
+    }
+
     // TODO: this is not very clever - figure out something better
     template bool llama_model_loader::get_key_or_arr<std::array<int, 4>>(enum llm_kv kid, std::array<int, 4> & result, uint32_t n, bool required);
     template bool llama_model_loader::get_key_or_arr<std::array<uint32_t, 512>>(enum llm_kv kid, std::array<uint32_t, 512> & result, uint32_t n, bool required);
index 0380c92fde0e3d9801832142701a0f9c54d90e43..d13299ad3f12819174826a8e732eee4f08ac5f80 100644 (file)
@@ -131,6 +131,8 @@ struct llama_model_loader {
     template<typename T>
     bool get_key_or_arr(enum llm_kv kid, T & result, uint32_t n, bool required = true);
 
+    bool get_key_or_arr(enum llm_kv kid, uint32_t & result, bool required = true);
+
     std::string get_arch_name() const;
 
     enum llm_arch get_arch() const;
index d2270e8f2daecff84b7c4df48fde3931f706700c..87fefd576cb0619643d1365984f987ca3e1dafc0 100644 (file)
@@ -31,12 +31,14 @@ const char * llm_type_name(llm_type type) {
         case LLM_TYPE_17M:           return "17M";
         case LLM_TYPE_22M:           return "22M";
         case LLM_TYPE_33M:           return "33M";
+        case LLM_TYPE_47M:           return "47M";
         case LLM_TYPE_60M:           return "60M";
         case LLM_TYPE_70M:           return "70M";
         case LLM_TYPE_80M:           return "80M";
         case LLM_TYPE_109M:          return "109M";
         case LLM_TYPE_137M:          return "137M";
         case LLM_TYPE_140M:          return "140M";
+        case LLM_TYPE_149M:          return "149M";
         case LLM_TYPE_160M:          return "160M";
         case LLM_TYPE_190M:          return "190M";
         case LLM_TYPE_220M:          return "220M";
@@ -46,6 +48,7 @@ const char * llm_type_name(llm_type type) {
         case LLM_TYPE_335M:          return "335M";
         case LLM_TYPE_350M:          return "350M";
         case LLM_TYPE_360M:          return "360M";
+        case LLM_TYPE_395M:          return "395M";
         case LLM_TYPE_410M:          return "410M";
         case LLM_TYPE_450M:          return "450M";
         case LLM_TYPE_475M:          return "475M";
@@ -875,6 +878,34 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_MODERN_BERT:
+            {
+                const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
+                if (found_swa && hparams.n_swa > 0) {
+                    uint32_t swa_period = 3;
+                    hparams.swa_type = LLAMA_SWA_TYPE_SYMMETRIC;
+
+                    ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa);
+                    ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false);
+                    hparams.set_swa_pattern(swa_period);
+                } else {
+                    hparams.swa_type = LLAMA_SWA_TYPE_NONE;
+                }
+
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
+                ml.get_key(LLM_KV_ATTENTION_CAUSAL,        hparams.causal_attn);
+                ml.get_key(LLM_KV_POOLING_TYPE,            hparams.pooling_type, false);
+
+                switch (hparams.n_layer) {
+                    case 12:
+                        type = LLM_TYPE_47M; break; // granite-embedding-small
+                    case 22:
+                        type = LLM_TYPE_149M; break; // modern-bert-base
+                    case 28:
+                        type = LLM_TYPE_395M; break; // modern-bert-large
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+            } break;
         case LLM_ARCH_JINA_BERT_V2:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS,    hparams.f_norm_eps);
@@ -3155,6 +3186,37 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         layer.layer_out_norm_b = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "bias", i),   {n_embd}, 0);
                     }
                 } break;
+            case LLM_ARCH_MODERN_BERT:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+                    tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
+
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+
+                    for(int i = 0; i < n_layer; ++i) {
+                        auto& layer = layers[i];
+
+                        if ( i != 0 ) {
+                            layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+                        } else{
+                            // layer 0 uses identity
+                            layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED);
+                        }
+
+
+                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, 3 * n_embd }, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT,   "weight", i), {n_embd, n_embd}, 0);
+
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, 2 * n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0);
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                    }
+
+                    cls       = create_tensor(tn(LLM_TENSOR_CLS,     "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED);
+                    cls_out   = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED);
+                    cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"),   {hparams.n_cls_out},         TENSOR_NOT_REQUIRED);
+
+                } break;
             case LLM_ARCH_NEO_BERT:
                 {
                     tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), {n_embd, n_vocab}, 0);
@@ -7089,6 +7151,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
         case LLM_ARCH_NOMIC_BERT_MOE:
         case LLM_ARCH_NEO_BERT:
         case LLM_ARCH_WAVTOKENIZER_DEC:
+        case LLM_ARCH_MODERN_BERT:
         case LLM_ARCH_GEMMA_EMBEDDING:
         case LLM_ARCH_DREAM:
         case LLM_ARCH_LLADA:
@@ -7248,6 +7311,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
             {
                 llm = std::make_unique<llm_build_bert>(*this, params);
             } break;
+        case LLM_ARCH_MODERN_BERT:
+            {
+                llm = std::make_unique<llm_build_modern_bert<true>>(*this, params);
+            } break;
         case LLM_ARCH_NEO_BERT:
             {
                 llm = std::make_unique<llm_build_neo_bert>(*this, params);
@@ -7816,6 +7883,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_DBRX:
         case LLM_ARCH_BERT:
         case LLM_ARCH_JINA_BERT_V3:
+        case LLM_ARCH_MODERN_BERT:
         case LLM_ARCH_NOMIC_BERT:
         case LLM_ARCH_NOMIC_BERT_MOE:
         case LLM_ARCH_STABLELM:
index c6eb9531886af8125194e5336e421b7d263f5c37..7f560d462fa474049b997a2a662b58dc35b56e0d 100644 (file)
@@ -24,12 +24,14 @@ enum llm_type {
     LLM_TYPE_17M,
     LLM_TYPE_22M,
     LLM_TYPE_33M,
+    LLM_TYPE_47M,
     LLM_TYPE_60M,
     LLM_TYPE_70M,
     LLM_TYPE_80M,
     LLM_TYPE_109M,
     LLM_TYPE_137M,
     LLM_TYPE_140M,
+    LLM_TYPE_149M,
     LLM_TYPE_160M,
     LLM_TYPE_190M,
     LLM_TYPE_220M,
@@ -39,6 +41,7 @@ enum llm_type {
     LLM_TYPE_335M,
     LLM_TYPE_350M,
     LLM_TYPE_360M,
+    LLM_TYPE_395M,
     LLM_TYPE_410M,
     LLM_TYPE_450M,
     LLM_TYPE_475M,
index 7b01a2edfe1f21cd09d603829666d6fa13efe54e..cd4092ca0772ad0af6ca01ece0a6f31eddf44a82 100644 (file)
@@ -1878,7 +1878,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                     tokenizer_pre == "jina-v2-es" ||
                     tokenizer_pre == "jina-v2-de" ||
                     tokenizer_pre == "a.x-4.0" ||
-                    tokenizer_pre == "mellum") {
+                    tokenizer_pre == "mellum"  ||
+                    tokenizer_pre == "modern-bert" ) {
                 pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2;
             } else if (
                     tokenizer_pre == "jina-v1-en" ||
@@ -2528,6 +2529,13 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
             for (const auto * token : {"<unk>", "<s>", "<|endoftext|>"}) {
                 _set_token_attr(token, LLAMA_TOKEN_ATTR_RSTRIP, false);
             }
+        } else if (_contains_any(model_name, {"modern-bert"})) {
+            if (token_to_id.count("[MASK]") == 0 ) {
+                LLAMA_LOG_WARN("%s: Mask token missing in vocab!\n", __func__);
+            }
+            else {
+                _set_token_attr("[MASK]", LLAMA_TOKEN_ATTR_LSTRIP, true);
+            }
         }
     }
 }
index ffb36acc616bef940012bc21bdb22d0a8394eced..53a5810659adcc1b37bc8871512fefa377411924 100644 (file)
@@ -327,6 +327,11 @@ struct llm_build_mistral3 : public llm_graph_context {
     llm_build_mistral3(const llama_model & model, const llm_graph_params & params);
 };
 
+template <bool iswa>
+struct llm_build_modern_bert : public llm_graph_context {
+    llm_build_modern_bert(const llama_model & model, const llm_graph_params & params);
+};
+
 struct llm_build_mpt : public llm_graph_context {
     llm_build_mpt(const llama_model & model, const llm_graph_params & params);
 };
diff --git a/src/models/modern-bert.cpp b/src/models/modern-bert.cpp
new file mode 100644 (file)
index 0000000..c7809bd
--- /dev/null
@@ -0,0 +1,126 @@
+#include "models.h"
+
+template <bool iswa>
+llm_build_modern_bert<iswa>::llm_build_modern_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+    const int64_t n_embd_head = hparams.n_embd_head_v;
+    const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+
+    GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+    ggml_tensor * cur;
+    ggml_tensor * inpL;
+    ggml_tensor * inp_pos = build_inp_pos();
+
+    // construct input embeddings (token, type, position)
+    inpL = build_inp_embd(model.tok_embd);
+    cb(inpL, "inp_embd", -1);
+
+    // embed layer norm
+    inpL = build_norm(inpL, model.tok_norm, nullptr, LLM_NORM, -1);
+    cb(inpL, "inp_norm", -1);
+
+    ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+    auto * inp_attn = build_attn_inp_no_cache();
+
+    for (int il = 0; il < n_layer; ++il) {
+        float freq_base_l  = 0.0f;
+
+        if constexpr (iswa) {
+            freq_base_l = model.get_rope_freq_base(cparams, il);
+        } else {
+            freq_base_l = freq_base;
+        }
+
+        cur = inpL;
+
+        // attention layer norm
+        if (model.layers[il].attn_norm) {
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM, il);
+            cb(cur, "attn_norm", il);
+        }
+
+        // self attention
+        cur = build_lora_mm(model.layers[il].wqkv, cur);
+        cb(cur, "wqkv", il);
+
+        const size_t type_size = ggml_type_size(cur->type);
+
+        ggml_tensor * Qcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head,    n_tokens, n_embd_head*type_size, cur->nb[1], 0*type_size*(n_embd));
+        ggml_tensor * Kcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*type_size, cur->nb[1], 1*type_size*(n_embd));
+        ggml_tensor * Vcur = ggml_view_3d(ctx0, cur, n_embd_head, n_head_kv, n_tokens, n_embd_head*type_size, cur->nb[1], 1*type_size*(n_embd + n_embd_gqa));
+
+        // RoPE
+        Qcur = ggml_rope_ext(
+                ctx0, Qcur, inp_pos, nullptr,
+                n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale,
+                ext_factor, attn_factor, beta_fast, beta_slow
+                );
+
+        Kcur = ggml_rope_ext(
+                ctx0, Kcur, inp_pos, nullptr,
+                n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale,
+                ext_factor, attn_factor, beta_fast, beta_slow
+                );
+
+        cb(Qcur, "Qcur", il);
+        cb(Kcur, "Kcur", il);
+        cb(Vcur, "Vcur", il);
+
+        cur = build_attn(inp_attn,
+                    model.layers[il].wo, nullptr,
+                    Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+        cb(cur, "kqv_out", il);
+
+        if (il == n_layer - 1 && inp_out_ids) {
+            cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
+            inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+        }
+
+        // re-add the layer input
+        ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
+        cb(ffn_inp, "ffn_inp", il);
+
+        // attention layer norm
+        cur = build_norm(ffn_inp,
+                model.layers[il].ffn_norm, NULL,
+                LLM_NORM, il);
+        cb(cur, "ffn_norm", il);
+
+        cur = build_ffn(cur,
+                model.layers[il].ffn_up,   NULL, NULL,
+                NULL,                      NULL, NULL,
+                model.layers[il].ffn_down, NULL, NULL,
+                NULL,
+                LLM_FFN_GEGLU, LLM_FFN_SEQ, il);
+
+        // attentions bypass the intermediate layer
+        cur = ggml_add(ctx0, cur, ffn_inp);
+
+        // input for next layer
+        inpL = cur;
+    }
+
+    cur = inpL;
+
+    cur = build_norm(cur,
+            model.output_norm, NULL,
+            LLM_NORM, -1);
+    cb(cur, "final_norm_out", -1);
+
+    if (hparams.pooling_type == LLAMA_POOLING_TYPE_CLS) {
+        // extracting cls token
+        cur = ggml_view_1d(ctx0, cur, hparams.n_embd, 0);
+        cb(cur, "cls_pooled_embd", -1);
+    }
+
+    cb(cur, "res_embd", -1);
+    res->t_embd = cur;
+    ggml_build_forward_expand(gf, cur);
+}
+
+// Explicit template instantiations
+template struct llm_build_modern_bert<false>;
+template struct llm_build_modern_bert<true>;
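
For reference, the feed-forward block above relies on `LLM_FFN_GEGLU` to split the fused `Wi` projection (shape {n_embd, 2*n_ff}) into gate and up halves; a rough NumPy sketch of that step, assuming the gate half comes first (row-major shapes, illustrative sizes):

    import numpy as np

    def geglu_ffn(x, w_i, w_o):
        # x: (n_tokens, n_embd), w_i: (n_embd, 2*n_ff), w_o: (n_ff, n_embd)
        h = x @ w_i                          # fused gate+up projection
        gate, up = np.split(h, 2, axis=-1)   # assumed ordering: gate first
        gelu = 0.5 * gate * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (gate + 0.044715 * gate**3)))
        return (gelu * up) @ w_o             # down projection back to n_embd

    x   = np.random.randn(4, 8)       # 4 tokens, toy n_embd = 8
    w_i = np.random.randn(8, 2 * 16)  # toy n_ff = 16
    w_o = np.random.randn(16, 8)
    print(geglu_ffn(x, w_i, w_o).shape)  # (4, 8)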