]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
model : Nomic Embed Text V2 with Mixture-of-Experts (MoE) architecture (#12466)
authorAT <redacted>
Mon, 28 Apr 2025 19:52:15 +0000 (15:52 -0400)
committerGitHub <redacted>
Mon, 28 Apr 2025 19:52:15 +0000 (22:52 +0300)
* Nomic Embed Text V2 with Mixture-of-Experts (MoE) architecture

- Adds MoE-based embedding model supporting multilingual embeddings.
- Selects architecture variant based on hyperparameter detection (MoE layers).
- Removes unnecessary subclass initialization checks for clarity.

https://www.nomic.ai/blog/posts/nomic-embed-text-v2

Co-authored-by: Jared Van Bortel <redacted>
* fix tokenizer

* don't rename this tensor

---------

Co-authored-by: Jared Van Bortel <redacted>
convert_hf_to_gguf.py
gguf-py/gguf/constants.py
gguf-py/gguf/gguf_writer.py
gguf-py/gguf/tensor_mapping.py
src/llama-arch.cpp
src/llama-arch.h
src/llama-graph.cpp
src/llama-hparams.h
src/llama-model.cpp

index d4fec408dd2020c2f05f769a284806e0809cc7e9..b9cea7e4699c6ebf799f88902a52975452737d8c 100755 (executable)
@@ -78,7 +78,7 @@ class ModelBase:
     # subclasses should define this!
     model_arch: gguf.MODEL_ARCH
 
-    def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool = False,
+    def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, *, is_big_endian: bool = False,
                  use_temp_file: bool = False, eager: bool = False,
                  metadata_override: Path | None = None, model_name: str | None = None,
                  split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False,
@@ -454,13 +454,6 @@ class ModelBase:
 
 
 class TextModel(ModelBase):
-    @classmethod
-    def __init_subclass__(cls):
-        # can't use an abstract property, because overriding it without type errors
-        # would require using decorated functions instead of simply defining the property
-        if "model_arch" not in cls.__dict__:
-            raise TypeError(f"Missing property 'model_arch' for {cls.__name__!r}")
-
     def set_vocab(self):
         self._set_vocab_gpt2()
 
@@ -3373,14 +3366,7 @@ class BertModel(TextModel):
 
         return [(self.map_tensor_name(name), data_torch)]
 
-
-@ModelBase.register("RobertaModel")
-class RobertaModel(BertModel):
-    model_arch = gguf.MODEL_ARCH.BERT
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
+    def _xlmroberta_tokenizer_init(self) -> None:
         # we need the pad_token_id to know how to chop down position_embd matrix
         if (pad_token_id := self.hparams.get("pad_token_id")) is not None:
             self._position_offset = 1 + pad_token_id
@@ -3389,82 +3375,7 @@ class RobertaModel(BertModel):
         else:
             self._position_offset = None
 
-    def set_vocab(self):
-        """Support BPE tokenizers for roberta models"""
-        bpe_tok_path = self.dir_model / "tokenizer.json"
-        if bpe_tok_path.exists():
-            self._set_vocab_gpt2()
-            self.gguf_writer.add_add_bos_token(True)
-            self.gguf_writer.add_add_eos_token(True)
-
-            # we need this to validate the size of the token_type embeddings
-            # though currently we are passing all zeros to the token_type embeddings
-            # "Sequence A" or "Sequence B"
-            self.gguf_writer.add_token_type_count(self.hparams.get("type_vocab_size", 1))
-
-        else:
-            return super().set_vocab()
-
-    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
-        # if name starts with "roberta.", remove the prefix
-        # e.g. https://huggingface.co/BAAI/bge-reranker-v2-m3/tree/main
-        if name.startswith("roberta."):
-            name = name[8:]
-
-        # position embeddings start at pad_token_id + 1, so just chop down the weight tensor
-        if name == "embeddings.position_embeddings.weight":
-            if self._position_offset is not None:
-                data_torch = data_torch[self._position_offset:,:]
-
-        return super().modify_tensors(data_torch, name, bid)
-
-
-@ModelBase.register("NomicBertModel")
-class NomicBertModel(BertModel):
-    model_arch = gguf.MODEL_ARCH.NOMIC_BERT
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-        # the HF config claims n_ctx=8192, but it uses RoPE scaling
-        self.hparams["n_ctx"] = 2048
-
-        # SwigLU activation
-        assert self.hparams["activation_function"] == "swiglu"
-        # this doesn't do anything in the HF version
-        assert self.hparams["causal"] is False
-        # no bias tensors
-        assert self.hparams["qkv_proj_bias"] is False
-        assert self.hparams["mlp_fc1_bias"] is False
-        assert self.hparams["mlp_fc2_bias"] is False
-        # norm at end of layer
-        assert self.hparams["prenorm"] is False
-        # standard RoPE
-        assert self.hparams["rotary_emb_fraction"] == 1.0
-        assert self.hparams["rotary_emb_interleaved"] is False
-        assert self.hparams["rotary_emb_scale_base"] is None
-
-    def set_gguf_parameters(self):
-        super().set_gguf_parameters()
-        self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"])
-
-
-@ModelBase.register("XLMRobertaModel", "XLMRobertaForSequenceClassification")
-class XLMRobertaModel(BertModel):
-    model_arch = gguf.MODEL_ARCH.BERT
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-        # we need the pad_token_id to know how to chop down position_embd matrix
-        if (pad_token_id := self.hparams.get("pad_token_id")) is not None:
-            self._position_offset = 1 + pad_token_id
-            if "max_position_embeddings" in self.hparams:
-                self.hparams["max_position_embeddings"] -= self._position_offset
-        else:
-            self._position_offset = None
-
-    def set_vocab(self):
+    def _xlmroberta_set_vocab(self) -> None:
         # to avoid TypeError: Descriptors cannot be created directly
         # exception when importing sentencepiece_model_pb2
         os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
@@ -3546,6 +3457,138 @@ class XLMRobertaModel(BertModel):
         self.gguf_writer.add_add_bos_token(True)
         self.gguf_writer.add_add_eos_token(True)
 
+
+@ModelBase.register("RobertaModel")
+class RobertaModel(BertModel):
+    model_arch = gguf.MODEL_ARCH.BERT
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # we need the pad_token_id to know how to chop down position_embd matrix
+        if (pad_token_id := self.hparams.get("pad_token_id")) is not None:
+            self._position_offset = 1 + pad_token_id
+            if "max_position_embeddings" in self.hparams:
+                self.hparams["max_position_embeddings"] -= self._position_offset
+        else:
+            self._position_offset = None
+
+    def set_vocab(self):
+        """Support BPE tokenizers for roberta models"""
+        bpe_tok_path = self.dir_model / "tokenizer.json"
+        if bpe_tok_path.exists():
+            self._set_vocab_gpt2()
+            self.gguf_writer.add_add_bos_token(True)
+            self.gguf_writer.add_add_eos_token(True)
+
+            # we need this to validate the size of the token_type embeddings
+            # though currently we are passing all zeros to the token_type embeddings
+            # "Sequence A" or "Sequence B"
+            self.gguf_writer.add_token_type_count(self.hparams.get("type_vocab_size", 1))
+
+        else:
+            return super().set_vocab()
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # if name starts with "roberta.", remove the prefix
+        # e.g. https://huggingface.co/BAAI/bge-reranker-v2-m3/tree/main
+        if name.startswith("roberta."):
+            name = name[8:]
+
+        # position embeddings start at pad_token_id + 1, so just chop down the weight tensor
+        if name == "embeddings.position_embeddings.weight":
+            if self._position_offset is not None:
+                data_torch = data_torch[self._position_offset:,:]
+
+        return super().modify_tensors(data_torch, name, bid)
+
+
+@ModelBase.register("NomicBertModel")
+class NomicBertModel(BertModel):
+    def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, **kwargs: Any):
+        hparams = kwargs.pop("hparams", None)
+        if hparams is None:
+            hparams = ModelBase.load_hparams(dir_model)
+
+        self.is_moe = bool(hparams.get("moe_every_n_layers"))
+        self.model_arch = gguf.MODEL_ARCH.NOMIC_BERT_MOE if self.is_moe else gguf.MODEL_ARCH.NOMIC_BERT
+
+        super().__init__(dir_model, ftype, fname_out, hparams=hparams, **kwargs)
+
+        self._tokenizer_is_xlmroberta = self._is_tokenizer_xlmroberta()
+        if self._tokenizer_is_xlmroberta:
+            self._xlmroberta_tokenizer_init()
+
+        # the HF config claims n_ctx=8192, but it uses RoPE scaling
+        self.hparams["n_ctx"] = 2048
+
+        assert self.hparams["activation_function"] == "gelu" if self.is_moe else "swiglu"
+
+        # this doesn't do anything in the HF version
+        assert self.hparams["causal"] is False
+        # no bias tensors unless MoE
+        assert self.hparams["qkv_proj_bias"] == self.is_moe
+        assert self.hparams["mlp_fc1_bias"]  == self.is_moe
+        assert self.hparams["mlp_fc2_bias"]  == self.is_moe
+
+        # norm at end of layer
+        assert self.hparams["prenorm"] is False
+        # standard RoPE
+        assert self.hparams["rotary_emb_fraction"] == 1.0
+        assert self.hparams["rotary_emb_interleaved"] is False
+        assert self.hparams["rotary_emb_scale_base"] is None
+
+    def set_vocab(self) -> None:
+        if self._tokenizer_is_xlmroberta:
+            return self._xlmroberta_set_vocab()
+        return super().set_vocab()
+
+    def modify_tensors(self, data_torch: torch.Tensor, name: str, bid: int | None) -> Iterable[tuple[str, torch.Tensor]]:
+        # If the tensor is an experts bias tensor, skip it by returning an empty list.
+        if "mlp.experts.bias" in name:
+            return []  # Explicitly return an empty list.
+
+        if "mlp.experts.mlp.w1" in name:
+            data_torch = data_torch.view(self.hparams["num_experts"], self.hparams["n_inner"], self.hparams["n_embd"])
+            name += ".weight"
+
+        if "mlp.experts.mlp.w2" in name:
+            data_torch = data_torch.view(self.hparams["num_experts"], self.hparams["n_inner"], self.hparams["n_embd"])
+            data_torch = data_torch.transpose(1, 2)
+            name += ".weight"
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"])
+        if self.is_moe:
+            self.gguf_writer.add_moe_every_n_layers(self.hparams["moe_every_n_layers"])
+            self.gguf_writer.add_expert_count(self.hparams["num_experts"])
+            self.gguf_writer.add_expert_used_count(self.hparams["moe_top_k"])
+
+    def _is_tokenizer_xlmroberta(self) -> bool:
+        with open(self.dir_model / "tokenizer.json") as f:
+            tokenizer_json = json.load(f)
+        toktyp = tokenizer_json["model"]["type"]
+        if toktyp == "Unigram":
+            return True
+        if toktyp == "WordPiece":
+            return False
+        raise ValueError(f"unknown tokenizer: {toktyp}")
+
+
+@ModelBase.register("XLMRobertaModel", "XLMRobertaForSequenceClassification")
+class XLMRobertaModel(BertModel):
+    model_arch = gguf.MODEL_ARCH.BERT
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._xlmroberta_tokenizer_init()
+
+    def set_vocab(self):
+        self._xlmroberta_set_vocab()
+
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         # if name starts with "roberta.", remove the prefix
         # e.g. https://huggingface.co/BAAI/bge-reranker-v2-m3/tree/main
index b81017b1425830794cbf934a01877042c49d8d2b..326ccdb071a79f225945ebc6d0b83ebe278b0d67 100644 (file)
@@ -104,6 +104,7 @@ class Keys:
         EXPERT_WEIGHTS_SCALE              = "{arch}.expert_weights_scale"
         EXPERT_WEIGHTS_NORM               = "{arch}.expert_weights_norm"
         EXPERT_GATING_FUNC                = "{arch}.expert_gating_func"
+        MOE_EVERY_N_LAYERS                = "{arch}.moe_every_n_layers"
         POOLING_TYPE                      = "{arch}.pooling_type"
         LOGIT_SCALE                       = "{arch}.logit_scale"
         DECODER_START_TOKEN_ID            = "{arch}.decoder_start_token_id"
@@ -267,6 +268,7 @@ class MODEL_ARCH(IntEnum):
     REFACT           = auto()
     BERT             = auto()
     NOMIC_BERT       = auto()
+    NOMIC_BERT_MOE   = auto()
     JINA_BERT_V2     = auto()
     BLOOM            = auto()
     STABLELM         = auto()
@@ -521,6 +523,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.REFACT:           "refact",
     MODEL_ARCH.BERT:             "bert",
     MODEL_ARCH.NOMIC_BERT:       "nomic-bert",
+    MODEL_ARCH.NOMIC_BERT_MOE:   "nomic-bert-moe",
     MODEL_ARCH.JINA_BERT_V2:     "jina-bert-v2",
     MODEL_ARCH.BLOOM:            "bloom",
     MODEL_ARCH.STABLELM:         "stablelm",
@@ -960,6 +963,22 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.FFN_UP,
         MODEL_TENSOR.LAYER_OUT_NORM,
     ],
+    MODEL_ARCH.NOMIC_BERT_MOE: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.TOKEN_EMBD_NORM,
+        MODEL_TENSOR.TOKEN_TYPES,
+        MODEL_TENSOR.POS_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.ATTN_OUT_NORM,
+        MODEL_TENSOR.ATTN_QKV,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
+        MODEL_TENSOR.LAYER_OUT_NORM,
+    ],
     MODEL_ARCH.JINA_BERT_V2: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.TOKEN_EMBD_NORM,
index 48e9a470b78d67dde6ea4973b17749af300ab424..f22a6d4a3472be0b13c89fbdc67068041f94071d 100644 (file)
@@ -728,6 +728,9 @@ class GGUFWriter:
     def add_expert_gating_func(self, value: ExpertGatingFuncType) -> None:
         self.add_uint32(Keys.LLM.EXPERT_GATING_FUNC.format(arch=self.arch), value.value)
 
+    def add_moe_every_n_layers(self, value: int) -> None:
+        self.add_uint32(Keys.LLM.MOE_EVERY_N_LAYERS.format(arch=self.arch), value)
+
     def add_swin_norm(self, value: bool) -> None:
         self.add_bool(Keys.LLM.SWIN_NORM.format(arch=self.arch), value)
 
index 1d70551973b01144247de6236d358579ec2b531f..311d1ff69c7999d25db65de948adde86e4ee194e 100644 (file)
@@ -290,6 +290,7 @@ class TensorNameMap:
             "transformer.blocks.{bid}.ffn.router.layer",        # dbrx
             "model.layers.{bid}.block_sparse_moe.router.layer", # granitemoe
             "language_model.model.layers.{bid}.feed_forward.router", # llama4
+            "encoder.layers.{bid}.mlp.router.layer",            # nomic-bert-moe
         ),
 
         MODEL_TENSOR.FFN_GATE_INP_SHEXP: (
@@ -322,6 +323,7 @@ class TensorNameMap:
             "model.layers.layers.{bid}.mlp.up_proj",                  # plamo
             "model.layers.{bid}.feed_forward.w3",                     # internlm2
             "encoder.layers.{bid}.mlp.fc11",                          # nomic-bert
+            "encoder.layers.{bid}.mlp.fc1",                           # nomic-bert-moe
             "model.layers.{bid}.mlp.c_fc",                            # starcoder2
             "encoder.layer.{bid}.mlp.gated_layers_v",                 # jina-bert-v2
             "model.layers.{bid}.residual_mlp.w3",                     # arctic
@@ -337,6 +339,7 @@ class TensorNameMap:
             "model.layers.{bid}.mlp.experts.up_proj",         # qwen2moe olmoe (merged)
             "model.layers.{bid}.block_sparse_moe.experts.w3", # phimoe (merged)
             "language_model.model.layers.{bid}.feed_forward.experts.up_proj", # llama4
+            "encoder.layers.{bid}.mlp.experts.mlp.w1",        # nomic-bert-moe
         ),
 
         MODEL_TENSOR.FFN_UP_SHEXP: (
@@ -418,6 +421,7 @@ class TensorNameMap:
             "model.layers.{bid}.block_sparse_moe.output_linear", # granitemoe
             "model.layers.{bid}.block_sparse_moe.experts.w2",    # phimoe (merged)
             "language_model.model.layers.{bid}.feed_forward.experts.down_proj", # llama4
+            "encoder.layers.{bid}.mlp.experts.mlp.w2",           # nomic-bert-moe
         ),
 
         MODEL_TENSOR.FFN_DOWN_SHEXP: (
index 62e1480bb5881aea613182ba522c890ac3511d2d..f2bc8ca76850278ce2f4b320300e503dec4158cc 100644 (file)
@@ -19,6 +19,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_REFACT,           "refact"           },
     { LLM_ARCH_BERT,             "bert"             },
     { LLM_ARCH_NOMIC_BERT,       "nomic-bert"       },
+    { LLM_ARCH_NOMIC_BERT_MOE,   "nomic-bert-moe"   },
     { LLM_ARCH_JINA_BERT_V2,     "jina-bert-v2"     },
     { LLM_ARCH_BLOOM,            "bloom"            },
     { LLM_ARCH_STABLELM,         "stablelm"         },
@@ -106,6 +107,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_EXPERT_WEIGHTS_SCALE,              "%s.expert_weights_scale"              },
     { LLM_KV_EXPERT_WEIGHTS_NORM,               "%s.expert_weights_norm"               },
     { LLM_KV_EXPERT_GATING_FUNC,                "%s.expert_gating_func"                },
+    { LLM_KV_MOE_EVERY_N_LAYERS,                "%s.moe_every_n_layers"                },
     { LLM_KV_POOLING_TYPE,                      "%s.pooling_type"                      },
     { LLM_KV_LOGIT_SCALE,                       "%s.logit_scale"                       },
     { LLM_KV_DECODER_START_TOKEN_ID,            "%s.decoder_start_token_id"            },
@@ -472,6 +474,24 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_NOMIC_BERT_MOE,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
+            { LLM_TENSOR_TOKEN_TYPES,     "token_types" },
+            { LLM_TENSOR_ATTN_OUT_NORM,   "blk.%d.attn_output_norm" },
+            { LLM_TENSOR_ATTN_QKV,        "blk.%d.attn_qkv" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_LAYER_OUT_NORM,  "blk.%d.layer_output_norm" },
+            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_GATE_INP,    "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_DOWN_EXPS,   "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS,     "blk.%d.ffn_up_exps" },
+        },
+    },
     {
         LLM_ARCH_JINA_BERT_V2,
         {
index 98ca00a1bd0b0aa3a3a2bca52e21341bd062b05c..41a023da3da6ee31a4637f84cac401b4f5de23b4 100644 (file)
@@ -23,6 +23,7 @@ enum llm_arch {
     LLM_ARCH_REFACT,
     LLM_ARCH_BERT,
     LLM_ARCH_NOMIC_BERT,
+    LLM_ARCH_NOMIC_BERT_MOE,
     LLM_ARCH_JINA_BERT_V2,
     LLM_ARCH_BLOOM,
     LLM_ARCH_STABLELM,
@@ -110,6 +111,7 @@ enum llm_kv {
     LLM_KV_EXPERT_WEIGHTS_SCALE,
     LLM_KV_EXPERT_WEIGHTS_NORM,
     LLM_KV_EXPERT_GATING_FUNC,
+    LLM_KV_MOE_EVERY_N_LAYERS,
     LLM_KV_POOLING_TYPE,
     LLM_KV_LOGIT_SCALE,
     LLM_KV_DECODER_START_TOKEN_ID,
index e6595fb18bc5b46d35fab2b5967ac83eab42c159..2706ea2635444f4f180411328d566372a08e1d80 100644 (file)
@@ -925,28 +925,35 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
     ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
     cb(up, "ffn_moe_up", il);
 
-    ggml_tensor * gate = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
-    cb(gate, "ffn_moe_gate", il);
+    ggml_tensor * experts = nullptr;
+    if (gate_exps) {
+        cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
+        cb(cur, "ffn_moe_gate", il);
+    } else {
+        cur = up;
+    }
 
     switch (type_op) {
         case LLM_FFN_SILU:
             {
-                gate = ggml_silu(ctx0, gate);
-                cb(gate, "ffn_moe_silu", il);
+                cur = ggml_silu(ctx0, cur);
+                cb(cur, "ffn_moe_silu", il);
             } break;
         case LLM_FFN_GELU:
             {
-                gate = ggml_gelu(ctx0, gate);
-                cb(gate, "ffn_moe_gelu", il);
+                cur = ggml_gelu(ctx0, cur);
+                cb(cur, "ffn_moe_gelu", il);
             } break;
         default:
             GGML_ABORT("fatal error");
     }
 
-    ggml_tensor * par = ggml_mul(ctx0, up, gate); // [n_ff, n_expert_used, n_tokens]
-    cb(par, "ffn_moe_gate_par", il);
+    if (gate_exps) {
+        cur = ggml_mul(ctx0, cur, up); // [n_ff, n_expert_used, n_tokens]
+        cb(cur, "ffn_moe_gate_par", il);
+    }
 
-    ggml_tensor * experts = build_lora_mm_id(down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens]
+    experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
     cb(experts, "ffn_moe_down", il);
 
     if (!weight_before_ffn) {
index 80fcd65df0d3c5c38405474d1dbb1f680949ca5b..7ee6a5b75ad1ef66a0e3d21a514e257ae6dcecf1 100644 (file)
@@ -66,6 +66,7 @@ struct llama_hparams {
     float    expert_weights_scale = 0.0;
     bool     expert_weights_norm  = false;
     uint32_t expert_gating_func   = LLAMA_EXPERT_GATING_FUNC_TYPE_NONE;
+    uint32_t moe_every_n_layers   = 0;
 
     float f_norm_eps;
     float f_norm_rms_eps;
index df2791002e9f9fcea66c413bc6a636743fecf04d..2ec55d55a37be46119fb57ac2bc3c425a573f247 100644 (file)
@@ -695,10 +695,12 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 }
             } break;
         case LLM_ARCH_NOMIC_BERT:
+        case LLM_ARCH_NOMIC_BERT_MOE:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS,    hparams.f_norm_eps);
                 ml.get_key(LLM_KV_ATTENTION_CAUSAL,           hparams.causal_attn);
                 ml.get_key(LLM_KV_POOLING_TYPE,               hparams.pooling_type);
+                ml.get_key(LLM_KV_MOE_EVERY_N_LAYERS,         hparams.moe_every_n_layers, 0);
 
                 if (hparams.n_layer == 12 && hparams.n_embd == 768) {
                     type = LLM_TYPE_137M;
@@ -2057,6 +2059,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                 } break;
             case LLM_ARCH_BERT:
             case LLM_ARCH_NOMIC_BERT:
+            case LLM_ARCH_NOMIC_BERT_MOE:
                 {
                     tok_embd     = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), {n_embd, n_vocab}, 0);
                     type_embd    = create_tensor(tn(LLM_TENSOR_TOKEN_TYPES, "weight"), {n_embd, n_token_types}, 0);
@@ -2090,20 +2093,31 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                             layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, 0);
                         }
 
+                        if (arch == LLM_ARCH_NOMIC_BERT_MOE) {
+                            layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, 0);
+                        }
+
                         layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT,      "weight", i), {n_embd, n_embd}, 0);
 
                         layer.attn_out_norm   = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "weight", i), {n_embd}, 0);
                         layer.attn_out_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_OUT_NORM, "bias", i),   {n_embd}, 0);
 
-                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,        "weight", i), {n_embd, n_ff}, 0);
-                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN,      "weight", i), {n_ff, n_embd}, 0);
-
-                        if (arch == LLM_ARCH_BERT) {
+                        if (hparams.moe_every_n_layers > 0 && i % hparams.moe_every_n_layers == 1) {
                             layer.bo         = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0);
-                            layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "bias", i), {n_ff}, 0);
-                            layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0);
+                            layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {  n_embd, n_ff,   n_expert}, 0);
+                            layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {  n_ff,   n_embd, n_expert}, 0);
+                            layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP,   "weight", i), {n_embd, n_expert}, 0);
                         } else {
-                            layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
+                            layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,        "weight", i), {n_embd, n_ff}, 0);
+                            layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN,      "weight", i), {n_ff, n_embd}, 0);
+
+                            if (arch == LLM_ARCH_BERT || arch == LLM_ARCH_NOMIC_BERT_MOE) {
+                                layer.bo         = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0);
+                                layer.ffn_up_b   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "bias", i), {n_ff}, 0);
+                                layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0);
+                            } else {
+                                layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
+                            }
                         }
 
                         layer.layer_out_norm   = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0);
@@ -5730,6 +5744,11 @@ struct llm_build_bert : public llm_graph_context {
                 cur = build_lora_mm(model.layers[il].wqkv, cur);
                 cb(cur, "wqkv", il);
 
+                if (model.arch == LLM_ARCH_NOMIC_BERT_MOE) {
+                    cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
+                    cb(cur, "bqkv", il);
+                }
+
                 Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
                 Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
                 Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
@@ -5782,13 +5801,29 @@ struct llm_build_bert : public llm_graph_context {
             cb(ffn_inp, "ffn_inp", il);
 
             // feed-forward network
-            if (model.arch == LLM_ARCH_BERT) {
+            if (hparams.moe_every_n_layers > 0 && il % hparams.moe_every_n_layers == 1) {
+                // MoE branch
+                cur = build_moe_ffn(cur,
+                        model.layers[il].ffn_gate_inp,
+                        model.layers[il].ffn_up_exps,
+                        nullptr,
+                        model.layers[il].ffn_down_exps,
+                        nullptr,
+                        hparams.n_expert,
+                        hparams.n_expert_used,
+                        LLM_FFN_GELU,
+                        false, false,
+                        0.0f,
+                        LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il);
+                cb(cur, "ffn_moe_out", il);
+            } else if (model.arch == LLM_ARCH_BERT || model.arch == LLM_ARCH_NOMIC_BERT_MOE) {
                 cur = build_ffn(cur,
                         model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
                         NULL,                      NULL,                        NULL,
                         model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
                         NULL,
                         LLM_FFN_GELU, LLM_FFN_SEQ, il);
+                cb(cur, "ffn_out", il);
             } else if (model.arch == LLM_ARCH_JINA_BERT_V2) {
                 cur = build_ffn(cur,
                         model.layers[il].ffn_up,   NULL,                        NULL,
@@ -5796,6 +5831,7 @@ struct llm_build_bert : public llm_graph_context {
                         model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
                         NULL,
                         LLM_FFN_GELU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
             } else {
                 cur = build_ffn(cur,
                         model.layers[il].ffn_up,   NULL, NULL,
@@ -5803,8 +5839,8 @@ struct llm_build_bert : public llm_graph_context {
                         model.layers[il].ffn_down, NULL, NULL,
                         NULL,
                         LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
             }
-            cb(cur, "ffn_out", il);
 
             // attentions bypass the intermediate layer
             cur = ggml_add(ctx0, cur, ffn_inp);
@@ -12843,6 +12879,7 @@ llm_graph_result_ptr llama_model::build_graph(
         case LLM_ARCH_BERT:
         case LLM_ARCH_JINA_BERT_V2:
         case LLM_ARCH_NOMIC_BERT:
+        case LLM_ARCH_NOMIC_BERT_MOE:
             {
                 llm = std::make_unique<llm_build_bert>(*this, params, gf);
             } break;
@@ -13201,6 +13238,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_DBRX:
         case LLM_ARCH_BERT:
         case LLM_ARCH_NOMIC_BERT:
+        case LLM_ARCH_NOMIC_BERT_MOE:
         case LLM_ARCH_STABLELM:
         case LLM_ARCH_BITNET:
         case LLM_ARCH_QWEN: