]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
model: Add support for PhiMoE arch (#11003)
authorPierrick Hymbert <redacted>
Thu, 9 Jan 2025 10:21:41 +0000 (11:21 +0100)
committerGitHub <redacted>
Thu, 9 Jan 2025 10:21:41 +0000 (11:21 +0100)
* model: support phimoe

* python linter

* doc: minor

Co-authored-by: ThiloteE <redacted>
* doc: minor

Co-authored-by: ThiloteE <redacted>
* doc: add phimoe as supported model

ggml-ci

---------

Co-authored-by: ThiloteE <redacted>
README.md
convert_hf_to_gguf.py
docs/development/HOWTO-add-model.md
gguf-py/gguf/constants.py
gguf-py/gguf/tensor_mapping.py
src/llama-arch.cpp
src/llama-arch.h
src/llama-model.cpp
src/llama-model.h
src/llama.cpp

index 0126da89c9d6fe37987dd3fc99f21c8d61b4c8c2..a7101525648b0bf750119c6968bd1ef80865ca93 100644 (file)
--- a/README.md
+++ b/README.md
@@ -69,6 +69,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
 - [x] [Qwen models](https://huggingface.co/models?search=Qwen/Qwen)
 - [x] [PLaMo-13B](https://github.com/ggerganov/llama.cpp/pull/3557)
 - [x] [Phi models](https://huggingface.co/models?search=microsoft/phi)
+- [x] [PhiMoE](https://github.com/ggerganov/llama.cpp/pull/11003)
 - [x] [GPT-2](https://huggingface.co/gpt2)
 - [x] [Orion 14B](https://github.com/ggerganov/llama.cpp/pull/5118)
 - [x] [InternLM2](https://huggingface.co/models?search=internlm2)
index 01b58f97600ebe65bf48883763effe03ab250451..5562499aa4925849d6c515c93ba63a0fd5524496 100755 (executable)
@@ -2562,6 +2562,63 @@ class Phi3MiniModel(Model):
         yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT), torch.tensor(short_factors, dtype=torch.float32))
 
 
+@Model.register("PhiMoEForCausalLM")
+class PhiMoeModel(Phi3MiniModel):
+    model_arch = gguf.MODEL_ARCH.PHIMOE
+
+    _experts: list[dict[str, Tensor]] | None = None
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_expert_used_count(self.hparams["num_experts_per_tok"])
+        self.gguf_writer.add_expert_count(self.hparams["num_local_experts"])
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # process the experts separately
+        if name.find("block_sparse_moe.experts") != -1:
+            n_experts = self.hparams["num_local_experts"]
+            assert bid is not None
+
+            if self._experts is None:
+                self._experts = [{} for _ in range(self.block_count)]
+
+            self._experts[bid][name] = data_torch
+
+            if len(self._experts[bid]) >= n_experts * 3:
+                tensors: list[tuple[str, Tensor]] = []
+
+                # merge the experts into a single 3d tensor
+                for w_name in ["w1", "w2", "w3"]:
+                    datas: list[Tensor] = []
+
+                    for xid in range(n_experts):
+                        ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.{w_name}.weight"
+                        datas.append(self._experts[bid][ename])
+                        del self._experts[bid][ename]
+
+                    data_torch = torch.stack(datas, dim=0)
+
+                    merged_name = f"model.layers.{bid}.block_sparse_moe.experts.{w_name}.weight"
+
+                    new_name = self.map_tensor_name(merged_name)
+
+                    tensors.append((new_name, data_torch))
+                return tensors
+            else:
+                return []
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+    def prepare_tensors(self):
+        super().prepare_tensors()
+
+        if self._experts is not None:
+            # flatten `list[dict[str, Tensor]]` into `list[str]`
+            experts = [k for d in self._experts for k in d.keys()]
+            if len(experts) > 0:
+                raise ValueError(f"Unprocessed experts: {experts}")
+
+
 @Model.register("PlamoForCausalLM")
 class PlamoModel(Model):
     model_arch = gguf.MODEL_ARCH.PLAMO
index 04c5ccbbe60c3ded50ec722323ac59181366c8b7..8fcd7081130f25c29f8936a7b1e26b8101a33a82 100644 (file)
@@ -28,7 +28,7 @@ The required steps to implement for an HF model are:
 ```python
 @Model.register("MyModelForCausalLM")
 class MyModel(Model):
-    model_arch = gguf.MODEL_ARCH.GROK
+    model_arch = gguf.MODEL_ARCH.MYMODEL
 ```
 
 2. Define the layout of the GGUF tensors in [constants.py](/gguf-py/gguf/constants.py)
@@ -79,14 +79,14 @@ Depending on the model configuration, tokenizer, code and tensors layout, you wi
 - `Model#set_vocab`
 - `Model#write_tensors`
 
-NOTE: Tensor names must end with `.weight` suffix, that is the convention and several tools like `quantize` expect this to proceed the weights.
+NOTE: Tensor names must end with `.weight` or `.bias` suffixes, that is the convention and several tools like `quantize` expect this to proceed the weights.
 
 ### 2. Define the model architecture in `llama.cpp`
 
 The model params and tensors layout must be defined in `llama.cpp`:
 1. Define a new `llm_arch`
 2. Define the tensors layout in `LLM_TENSOR_NAMES`
-3. Add any non standard metadata in `llm_load_hparams`
+3. Add any non-standard metadata in `llm_load_hparams`
 4. Create the tensors for inference in `llm_load_tensors`
 5. If the model has a RoPE operation, add the rope type in `llama_rope_type`
 
@@ -96,9 +96,9 @@ NOTE: The dimensions in `ggml` are typically in the reverse order of the `pytorc
 
 This is the funniest part, you have to provide the inference graph implementation of the new model architecture in `llama_build_graph`.
 
-Have a look at existing implementation like `build_llama`, `build_dbrx` or `build_bert`.
+Have a look at existing implementations like `build_llama`, `build_dbrx` or `build_bert`.
 
-When implementing a new graph, please note that the underlying `ggml` backends might not support them all, support for missing backend operations can be added in another PR.
+Some `ggml` backends do not support all operations. Backend implementations can be added in a separate PR.
 
 Note: to debug the inference graph: you can use [llama-eval-callback](/examples/eval-callback/).
 
index 9d0e7489f8208c0a268c42255b0ddbaec001f986..cf05bf47ece0830cbf699d8580df2015b5208522 100644 (file)
@@ -244,6 +244,7 @@ class MODEL_ARCH(IntEnum):
     QWEN2VL          = auto()
     PHI2             = auto()
     PHI3             = auto()
+    PHIMOE           = auto()
     PLAMO            = auto()
     CODESHELL        = auto()
     ORION            = auto()
@@ -428,6 +429,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.QWEN2VL:          "qwen2vl",
     MODEL_ARCH.PHI2:             "phi2",
     MODEL_ARCH.PHI3:             "phi3",
+    MODEL_ARCH.PHIMOE:           "phimoe",
     MODEL_ARCH.PLAMO:            "plamo",
     MODEL_ARCH.CODESHELL:        "codeshell",
     MODEL_ARCH.ORION:            "orion",
@@ -940,6 +942,24 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.PHIMOE: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ROPE_FACTORS_LONG,
+        MODEL_TENSOR.ROPE_FACTORS_SHORT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_QKV,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
+    ],
     MODEL_ARCH.CODESHELL: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.POS_EMBD,
index efe2a4aa4fe281638df047d1472df4dc540a9719..7616c468a53015f9a6f6d460ac447cf6e61915a9 100644 (file)
@@ -55,7 +55,7 @@ class TensorNameMap:
         # Output
         MODEL_TENSOR.OUTPUT: (
             "embed_out",                 # gptneox
-            "lm_head",                   # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone olmoe olmo2
+            "lm_head",                   # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone olmoe olmo2 phimoe
             "output",                    # llama-pth bloom internlm2
             "word_embeddings_for_head",  # persimmon
             "lm_head.linear",            # phi2
@@ -68,7 +68,7 @@ class TensorNameMap:
         MODEL_TENSOR.OUTPUT_NORM: (
             "gpt_neox.final_layer_norm",               # gptneox
             "transformer.ln_f",                        # gpt2 gpt-j falcon jais exaone
-            "model.norm",                              # llama-hf baichuan internlm2 olmoe olmo2
+            "model.norm",                              # llama-hf baichuan internlm2 olmoe olmo2 phimoe
             "norm",                                    # llama-pth
             "transformer.norm_f",                      # mpt dbrx
             "ln_f",                                    # refact bloom qwen gpt2
@@ -108,7 +108,7 @@ class TensorNameMap:
             "transformer.h.{bid}.input_layernorm",                  # falcon7b
             "h.{bid}.input_layernorm",                              # bloom
             "transformer.h.{bid}.ln_mlp",                           # falcon40b
-            "model.layers.{bid}.input_layernorm",                   # llama-hf nemotron olmoe
+            "model.layers.{bid}.input_layernorm",                   # llama-hf nemotron olmoe phimoe
             "layers.{bid}.attention_norm",                          # llama-pth
             "language_model.encoder.layers.{bid}.input_layernorm",  # persimmon
             "model.layers.{bid}.ln1",                               # yi
@@ -152,7 +152,7 @@ class TensorNameMap:
 
         # Attention query
         MODEL_TENSOR.ATTN_Q: (
-            "model.layers.{bid}.self_attn.q_proj",                       # llama-hf nemotron olmoe olmo2
+            "model.layers.{bid}.self_attn.q_proj",                       # llama-hf nemotron olmoe olmo2 phimoe
             "model.layers.{bid}.self_attn.q_proj_no_perm",               # llama-custom
             "layers.{bid}.attention.wq",                                 # llama-pth
             "encoder.layer.{bid}.attention.self.query",                  # bert
@@ -165,7 +165,7 @@ class TensorNameMap:
 
         # Attention key
         MODEL_TENSOR.ATTN_K: (
-            "model.layers.{bid}.self_attn.k_proj",                     # llama-hf nemotron olmoe olmo2
+            "model.layers.{bid}.self_attn.k_proj",                     # llama-hf nemotron olmoe olmo2 phimoe
             "model.layers.{bid}.self_attn.k_proj_no_perm",             # llama-custom
             "layers.{bid}.attention.wk",                               # llama-pth
             "encoder.layer.{bid}.attention.self.key",                  # bert
@@ -179,7 +179,7 @@ class TensorNameMap:
 
         # Attention value
         MODEL_TENSOR.ATTN_V: (
-            "model.layers.{bid}.self_attn.v_proj",                       # llama-hf nemotron olmoe olmo2
+            "model.layers.{bid}.self_attn.v_proj",                       # llama-hf nemotron olmoe olmo2 phimoe
             "layers.{bid}.attention.wv",                                 # llama-pth
             "encoder.layer.{bid}.attention.self.value",                  # bert
             "transformer.h.{bid}.attn.v_proj",                           # gpt-j
@@ -197,7 +197,7 @@ class TensorNameMap:
             "transformer.blocks.{bid}.attn.out_proj",                       # mpt
             "transformer.h.{bid}.self_attention.dense",                     # falcon
             "h.{bid}.self_attention.dense",                                 # bloom
-            "model.layers.{bid}.self_attn.o_proj",                          # llama-hf nemotron olmoe olmo2
+            "model.layers.{bid}.self_attn.o_proj",                          # llama-hf nemotron olmoe olmo2 phimoe
             "model.layers.{bid}.self_attn.linear_attn",                     # deci
             "layers.{bid}.attention.wo",                                    # llama-pth
             "encoder.layer.{bid}.attention.output.dense",                   # bert
@@ -242,7 +242,7 @@ class TensorNameMap:
             "transformer.h.{bid}.ln_2",                                      # gpt2 refact qwen jais exaone
             "h.{bid}.post_attention_layernorm",                              # bloom
             "transformer.blocks.{bid}.norm_2",                               # mpt
-            "model.layers.{bid}.post_attention_layernorm",                   # llama-hf nemotron olmoe
+            "model.layers.{bid}.post_attention_layernorm",                   # llama-hf nemotron olmoe phimoe
             "layers.{bid}.ffn_norm",                                         # llama-pth
             "language_model.encoder.layers.{bid}.post_attention_layernorm",  # persimmon
             "model.layers.{bid}.ln2",                                        # yi
@@ -265,7 +265,7 @@ class TensorNameMap:
 
         MODEL_TENSOR.FFN_GATE_INP: (
             "layers.{bid}.feed_forward.gate",                   # mixtral
-            "model.layers.{bid}.block_sparse_moe.gate",         # mixtral
+            "model.layers.{bid}.block_sparse_moe.gate",         # mixtral phimoe
             "model.layers.{bid}.mlp.gate",                      # qwen2moe olmoe
             "transformer.decoder_layer.{bid}.router",           # Grok
             "transformer.blocks.{bid}.ffn.router.layer",        # dbrx
@@ -310,10 +310,11 @@ class TensorNameMap:
         ),
 
         MODEL_TENSOR.FFN_UP_EXP: (
-            "layers.{bid}.feed_forward.experts.w3",          # mixtral (merged)
-            "transformer.decoder_layer.{bid}.moe.linear_v",  # Grok (merged)
-            "transformer.blocks.{bid}.ffn.experts.mlp.v1",   # dbrx
-            "model.layers.{bid}.mlp.experts.up_proj",        # qwen2moe olmoe (merged)
+            "layers.{bid}.feed_forward.experts.w3",           # mixtral (merged)
+            "transformer.decoder_layer.{bid}.moe.linear_v",   # Grok (merged)
+            "transformer.blocks.{bid}.ffn.experts.mlp.v1",    # dbrx
+            "model.layers.{bid}.mlp.experts.up_proj",         # qwen2moe olmoe (merged)
+            "model.layers.{bid}.block_sparse_moe.experts.w3", # phimoe (merged)
         ),
 
         MODEL_TENSOR.FFN_UP_SHEXP: (
@@ -342,10 +343,11 @@ class TensorNameMap:
         ),
 
         MODEL_TENSOR.FFN_GATE_EXP: (
-            "layers.{bid}.feed_forward.experts.w1",         # mixtral (merged)
-            "transformer.decoder_layer.{bid}.moe.linear",   # Grok (merged)
-            "transformer.blocks.{bid}.ffn.experts.mlp.w1",  # dbrx
-            "model.layers.{bid}.mlp.experts.gate_proj",     # qwen2moe olmoe (merged)
+            "layers.{bid}.feed_forward.experts.w1",           # mixtral (merged)
+            "transformer.decoder_layer.{bid}.moe.linear",     # Grok (merged)
+            "transformer.blocks.{bid}.ffn.experts.mlp.w1",    # dbrx
+            "model.layers.{bid}.mlp.experts.gate_proj",       # qwen2moe olmoe (merged)
+            "model.layers.{bid}.block_sparse_moe.experts.w1", # phimoe (merged)
         ),
 
         MODEL_TENSOR.FFN_GATE_SHEXP: (
@@ -387,6 +389,7 @@ class TensorNameMap:
             "transformer.blocks.{bid}.ffn.experts.mlp.w2",       # dbrx
             "model.layers.{bid}.mlp.experts.down_proj",          # qwen2moe olmoe (merged)
             "model.layers.{bid}.block_sparse_moe.output_linear", # granitemoe
+            "model.layers.{bid}.block_sparse_moe.experts.w2",    # phimoe (merged)
         ),
 
         MODEL_TENSOR.FFN_DOWN_SHEXP: (
index 007d79f8261bceed80a448285a68135b96ade599..eef66ed311d7a202e9ebbff78c59cd2ffda7d55a 100644 (file)
@@ -27,6 +27,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_QWEN2VL,          "qwen2vl"          },
     { LLM_ARCH_PHI2,             "phi2"             },
     { LLM_ARCH_PHI3,             "phi3"             },
+    { LLM_ARCH_PHIMOE,           "phimoe"           },
     { LLM_ARCH_PLAMO,            "plamo"            },
     { LLM_ARCH_CODESHELL,        "codeshell"        },
     { LLM_ARCH_ORION,            "orion"            },
@@ -584,6 +585,27 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP,             "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_PHIMOE,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,         "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,        "output_norm" },
+            { LLM_TENSOR_OUTPUT,             "output" },
+            { LLM_TENSOR_ROPE_FACTORS_LONG,  "rope_factors_long" },
+            { LLM_TENSOR_ROPE_FACTORS_SHORT, "rope_factors_short" },
+            { LLM_TENSOR_ATTN_NORM,          "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_QKV,           "blk.%d.attn_qkv" },
+            { LLM_TENSOR_ATTN_Q,             "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,             "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,             "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,           "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM,           "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE_INP,       "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_GATE_EXPS,      "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS,      "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS,        "blk.%d.ffn_up_exps" },
+        },
+    },
     {
         LLM_ARCH_PLAMO,
         {
index 45e458bb9ccb5f032646a12af9daadf50c8e0fc1..2e5f97b771d0ef0ace44f5f0b0d36c7ade35a632 100644 (file)
@@ -31,6 +31,7 @@ enum llm_arch {
     LLM_ARCH_QWEN2VL,
     LLM_ARCH_PHI2,
     LLM_ARCH_PHI3,
+    LLM_ARCH_PHIMOE,
     LLM_ARCH_PLAMO,
     LLM_ARCH_CODESHELL,
     LLM_ARCH_ORION,
index 7deb3683bbccb95d044ae5e26ef80bddbdd97800..7260cb155261b5282ba615b44b0919f7669b37dd 100644 (file)
@@ -76,6 +76,7 @@ const char * llm_type_name(llm_type type) {
         case MODEL_8x7B:          return "8x7B";
         case MODEL_8x22B:         return "8x22B";
         case MODEL_16x12B:        return "16x12B";
+        case MODEL_16x3_8B:       return "16x3.8B";
         case MODEL_10B_128x3_66B: return "10B+128x3.66B";
         case MODEL_57B_A14B:      return "57B.A14B";
         case MODEL_27B:           return "27B";
@@ -661,6 +662,15 @@ void llm_load_hparams(llama_model_loader & ml, llama_model & model) {
                     throw std::runtime_error("invalid value for sliding_window");
                 }
             } break;
+        case LLM_ARCH_PHIMOE:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+
+                switch (hparams.n_layer) {
+                    case 32: model.type = e_model::MODEL_16x3_8B; break;
+                    default: model.type = e_model::MODEL_UNKNOWN;
+                }
+            } break;
         case LLM_ARCH_PLAMO:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@@ -2094,6 +2104,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_OLMOE:
         case LLM_ARCH_PHI2:
         case LLM_ARCH_PHI3:
+        case LLM_ARCH_PHIMOE:
         case LLM_ARCH_GEMMA:
         case LLM_ARCH_GEMMA2:
         case LLM_ARCH_STARCODER2:
index ce038932d4e2d6c475671a45a92805b215fbd2f2..424cb0f521943af787e0cc2274e96dbb094eee12 100644 (file)
@@ -73,6 +73,7 @@ enum llm_type {
     MODEL_8x7B,
     MODEL_8x22B,
     MODEL_16x12B,
+    MODEL_16x3_8B,
     MODEL_10B_128x3_66B,
     MODEL_57B_A14B,
     MODEL_27B,
index 97e716cd65563d2c65fbbd02571751a1f6fbca10..ae375bcd3c8b1fb914da3cdc9187340853831367 100644 (file)
@@ -1212,6 +1212,50 @@ static bool llm_load_tensors(
                         layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0);
                         layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, 2 * n_ff }, 0);
 
+                        layer.rope_long  = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG,  "weight", i), { n_embd_head/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
+                        layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
+                    }
+                } break;
+            case LLM_ARCH_PHIMOE:
+                {
+                    const int64_t n_embd_head = n_embd / n_head;
+
+                    model.tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);
+
+                    // output
+                    model.output_norm   = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
+                    model.output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd}, 0);
+                    model.output        = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), { n_embd, n_vocab }, 0);
+                    model.output_b      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "bias"),   { n_vocab }, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = model.layers[i];
+
+                        layer.attn_norm   = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
+                        layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias",   i), { n_embd }, 0);
+
+                        layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, n_embd + 2 * n_embd_gqa }, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        if (layer.wqkv == nullptr) {
+                            layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0);
+                            layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias",   i), {n_embd}, 0);
+
+                            layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
+                            layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias",   i), {n_embd_gqa}, 0);
+
+                            layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
+                            layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias",   i), {n_embd_gqa}, 0);
+                        }
+                        layer.wo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd, n_embd }, 0);
+                        layer.bo   = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias",   i), { n_embd }, 0);
+
+                        layer.ffn_norm   = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0);
+                        layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias",   i), { n_embd }, 0);
+
+                        layer.ffn_gate_inp  = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP,  "weight", i), {n_embd, n_expert},         0);
+                        layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff,   n_expert}, 0);
+                        layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff,   n_embd, n_expert}, 0);
+                        layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd, n_ff,   n_expert}, 0);
+
                         layer.rope_long  = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG,  "weight", i), { n_embd_head/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
                         layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
                     }
@@ -6266,7 +6310,7 @@ struct llm_build_context {
 
                 struct ggml_tensor* attn_norm_output = llm_build_norm(ctx0, inpL, hparams,
                     model.layers[il].attn_norm,
-                    NULL,
+                    model.layers[il].attn_norm_b,
                     LLM_NORM_RMS, cb, il);
                 cb(attn_norm_output, "attn_norm", il);
 
@@ -6281,8 +6325,7 @@ struct llm_build_context {
                     Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0 * sizeof(float) * (n_embd)));
                     Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd)));
                     Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1 * sizeof(float) * (n_embd + n_embd_gqa)));
-                }
-                else {
+                } else {
                     Qcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, attn_norm_output), model.layers[il].bq);
                     Kcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, attn_norm_output), model.layers[il].bk);
                     Vcur = ggml_add(ctx0, llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, attn_norm_output), model.layers[il].bv);
@@ -6326,14 +6369,12 @@ struct llm_build_context {
             residual = cur;
 
             cur = llm_build_norm(ctx0, cur, hparams,
-                model.layers[il].ffn_norm, NULL,
+                model.layers[il].ffn_norm, model.layers[il].ffn_norm_b,
                 LLM_NORM_RMS, cb, il);
             cb(cur, "ffn_norm", il);
 
-            // FF
-            // special-case: the up and gate tensors are merged into a single tensor
-            // TOOD: support into llm_build_ffn
-            {
+            // feed-forward network
+            if (model.layers[il].ffn_gate_inp == nullptr) {
                 cur = llm_build_ffn(ctx0, lctx, cur,
                         model.layers[il].ffn_up,   NULL, NULL,
                         NULL,                      NULL, NULL,
@@ -6341,6 +6382,20 @@ struct llm_build_context {
                         NULL,
                         LLM_FFN_SWIGLU, LLM_FFN_SEQ, cb, il);
                 cb(cur, "ffn_out", il);
+            } else {
+                // MoE branch
+                cur = llm_build_moe_ffn(ctx0, lctx, cur,
+                        model.layers[il].ffn_gate_inp,
+                        model.layers[il].ffn_up_exps,
+                        model.layers[il].ffn_gate_exps,
+                        model.layers[il].ffn_down_exps,
+                        nullptr,
+                        n_expert, n_expert_used,
+                        LLM_FFN_SILU, true,
+                        false, 0.0,
+                        LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
+                        cb, il);
+                cb(cur, "ffn_moe_out", il);
             }
 
             cur = ggml_add(ctx0, residual, cur);
@@ -6353,11 +6408,16 @@ struct llm_build_context {
 
         cur = llm_build_norm(ctx0, inpL, hparams,
             model.output_norm,
-            NULL,
+            model.output_norm_b,
             LLM_NORM_RMS, cb, -1);
         cb(cur, "result_norm", -1);
 
         cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+
+        if (model.output_b != nullptr) {
+            cb(cur, "result_output_no_bias", -1);
+            cur = ggml_add(ctx0, cur, model.output_b);
+        }
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -10536,6 +10596,7 @@ static struct ggml_cgraph * llama_build_graph(
                 result = llm.build_phi2();
             } break;
         case LLM_ARCH_PHI3:
+        case LLM_ARCH_PHIMOE:
             {
                 result = llm.build_phi3();
             } break;