llama: Add option to merge gate and exp weights (#19139)
author Aman Gupta <redacted>
Thu, 26 Feb 2026 13:01:08 +0000 (21:01 +0800)
committer GitHub <redacted>
Thu, 26 Feb 2026 13:01:08 +0000 (21:01 +0800)
* llama: Add option to merge gate and exp weights

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <redacted>
* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <redacted>
* update constants.py

* add gate_up for all the MoE models

* convert: simplify merge tensor condition

* update constants.py

* reduce number of models, add create_tensor_gate_up helper

---------

Co-authored-by: Sigbjørn Skjæret <redacted>
12 files changed:
convert_hf_to_gguf.py
gguf-py/gguf/constants.py
gguf-py/gguf/tensor_mapping.py
src/llama-arch.cpp
src/llama-arch.h
src/llama-graph.cpp
src/llama-graph.h
src/llama-model.cpp
src/llama-model.h
src/models/deepseek2.cpp
src/models/qwen35moe.cpp
src/models/qwen3next.cpp

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 5709cb0e3c19679c369b8e627306ed4c9aff3e85..09544173981afb424ce54e6463da0db003d72935 100755 (executable)
@@ -116,7 +116,8 @@ class ModelBase:
                  split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False,
                  small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None,
                  disable_mistral_community_chat_template: bool = False,
-                 sentence_transformers_dense_modules: bool = False):
+                 sentence_transformers_dense_modules: bool = False,
+                 fuse_gate_up_exps: bool = False):
         if type(self) is ModelBase or \
                 type(self) is TextModel or \
                 type(self) is MmprojModel:
@@ -135,6 +136,9 @@ class ModelBase:
         self.dry_run = dry_run
         self.remote_hf_model_id = remote_hf_model_id
         self.sentence_transformers_dense_modules = sentence_transformers_dense_modules
+        self.fuse_gate_up_exps = fuse_gate_up_exps
+        self._gate_exp_buffer: dict[int, Tensor] = {}
+        self._up_exp_buffer: dict[int, Tensor] = {}
         self.hparams = ModelBase.load_hparams(self.dir_model, self.is_mistral_format) if hparams is None else hparams
         self.model_tensors = self.index_tensors(remote_hf_model_id=remote_hf_model_id)
         self.metadata_override = metadata_override
@@ -512,8 +516,31 @@ class ModelBase:
         raise NotImplementedError("set_gguf_parameters() must be implemented in subclasses")
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
-        del bid # unused
-        return [(self.map_tensor_name(name), data_torch)]
+        new_name = self.map_tensor_name(name)
+
+        # Handle gate/up expert tensor fusion if enabled
+        if self.fuse_gate_up_exps and bid is not None:
+            if self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.FFN_GATE_EXP, bid):
+                self._gate_exp_buffer[bid] = data_torch
+            elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.FFN_UP_EXP, bid):
+                self._up_exp_buffer[bid] = data_torch
+
+            # Check if both gate and up are buffered for this layer
+            if bid in self._gate_exp_buffer and bid in self._up_exp_buffer:
+                gate_data = self._gate_exp_buffer.pop(bid)
+                up_data = self._up_exp_buffer.pop(bid)
+                # gate/up shape: (n_expert, n_ff, n_embd), concatenate to (n_expert, n_ff*2, n_embd)
+                fused_data = torch.cat([gate_data, up_data], dim=1)
+                fused_name = self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_UP_EXP, bid)
+                logger.info(f"Fused gate_exps and up_exps for layer {bid}")
+                return [(fused_name, fused_data)]
+
+            # If we buffered a gate/up tensor, wait for the other
+            if self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.FFN_GATE_EXP, bid) or \
+               self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.FFN_UP_EXP, bid):
+                return []
+
+        return [(new_name, data_torch)]
 
     def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
         del name, new_name, bid, n_dims  # unused
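
The hunk above buffers the gate and up expert tensors per layer and, once both are seen, concatenates them along dim=1: two (n_expert, n_ff, n_embd) tensors become a single (n_expert, 2*n_ff, n_embd) tensor emitted under the FFN_GATE_UP_EXP name. A minimal sketch of that shape arithmetic (sizes below are purely illustrative):

    import torch

    n_expert, n_ff, n_embd = 8, 1024, 512   # illustrative sizes, not from any real model

    gate = torch.randn(n_expert, n_ff, n_embd)
    up   = torch.randn(n_expert, n_ff, n_embd)

    # same concatenation as modify_tensors(): gate rows first, then up rows
    fused = torch.cat([gate, up], dim=1)
    assert fused.shape == (n_expert, 2 * n_ff, n_embd)

    # both halves remain recoverable as views, which is what the C++ side relies on
    assert torch.equal(fused[:, :n_ff, :], gate)
    assert torch.equal(fused[:, n_ff:, :], up)
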
@@ -11942,6 +11969,11 @@ def parse_args() -> argparse.Namespace:
               "Default these modules are not included.")
     )
 
+    parser.add_argument(
+        "--fuse-gate-up-exps", action="store_true",
+        help="Fuse gate_exps and up_exps tensors into a single gate_up_exps tensor for MoE models.",
+    )
+
     args = parser.parse_args()
     if not args.print_supported_models and args.model is None:
         parser.error("the following arguments are required: model")
@@ -12079,7 +12111,8 @@ def main() -> None:
                                      split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run,
                                      small_first_shard=args.no_tensor_first_split,
                                      remote_hf_model_id=hf_repo_id, disable_mistral_community_chat_template=disable_mistral_community_chat_template,
-                                     sentence_transformers_dense_modules=args.sentence_transformers_dense_modules
+                                     sentence_transformers_dense_modules=args.sentence_transformers_dense_modules,
+                                     fuse_gate_up_exps=args.fuse_gate_up_exps
                                      )
 
         if args.vocab_only:
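
With the plumbing above, a conversion that emits fused expert tensors would look roughly like this (model path and output name are hypothetical):

    python convert_hf_to_gguf.py /path/to/hf-model --outfile model-f16.gguf --outtype f16 --fuse-gate-up-exps
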
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index e9ef97d553c8ebdefcb0b985986527038a4b3d90..839c6e787fc0db9db0eb5ff94b9894dd04bea961 100644 (file)
@@ -532,6 +532,7 @@ class MODEL_TENSOR(IntEnum):
     FFN_GATE_EXP         = auto()
     FFN_DOWN_EXP         = auto()
     FFN_UP_EXP           = auto()
+    FFN_GATE_UP_EXP      = auto()
     FFN_GATE_SHEXP       = auto()
     FFN_DOWN_SHEXP       = auto()
     FFN_UP_SHEXP         = auto()
@@ -980,6 +981,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
     MODEL_TENSOR.FFN_GATE_EXP:              "blk.{bid}.ffn_gate_exps",
     MODEL_TENSOR.FFN_DOWN_EXP:              "blk.{bid}.ffn_down_exps",
     MODEL_TENSOR.FFN_UP_EXP:                "blk.{bid}.ffn_up_exps",
+    MODEL_TENSOR.FFN_GATE_UP_EXP:           "blk.{bid}.ffn_gate_up_exps",
     MODEL_TENSOR.FFN_EXP_PROBS_B:           "blk.{bid}.exp_probs_b",
     MODEL_TENSOR.LAYER_OUT_NORM:            "blk.{bid}.layer_output_norm",
     MODEL_TENSOR.PER_LAYER_TOKEN_EMBD:      "per_layer_token_embd",           # gemma3n
@@ -1820,6 +1822,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.FFN_DOWN_EXP,
         MODEL_TENSOR.FFN_UP_EXP,
         MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.FFN_GATE_UP_EXP,
         MODEL_TENSOR.SSM_A,
         MODEL_TENSOR.SSM_CONV1D,
         MODEL_TENSOR.SSM_DT,
@@ -1909,6 +1912,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.FFN_DOWN_EXP,
         MODEL_TENSOR.FFN_UP_EXP,
         MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.FFN_GATE_UP_EXP,
         MODEL_TENSOR.SSM_A,
         MODEL_TENSOR.SSM_CONV1D,
         MODEL_TENSOR.SSM_DT,
@@ -2610,6 +2614,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.FFN_GATE_EXP,
         MODEL_TENSOR.FFN_DOWN_EXP,
         MODEL_TENSOR.FFN_UP_EXP,
+        MODEL_TENSOR.FFN_GATE_UP_EXP,
         MODEL_TENSOR.FFN_GATE_SHEXP,
         MODEL_TENSOR.FFN_DOWN_SHEXP,
         MODEL_TENSOR.FFN_UP_SHEXP,
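
The new enum value resolves through TENSOR_NAMES to the blk.{bid}.ffn_gate_up_exps template, alongside the existing per-expert names. A quick way to see the expansion (assuming a gguf-py build that includes this change):

    import gguf

    name = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.FFN_GATE_UP_EXP].format(bid=3)
    print(name)  # expected: blk.3.ffn_gate_up_exps
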
diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index fc468d07745a6edd98898f01716f432c094e0e70..e575610900e7557100b8c5b1344ab28c90f27822 100644 (file)
@@ -567,6 +567,10 @@ class TensorNameMap:
             "model.layers.{bid}.mlp.chunk_experts.gate_proj",           # grovemoe
         ),
 
+        MODEL_TENSOR.FFN_GATE_UP_EXP: (
+            "model.layers.{bid}.mlp.experts.gate_up_proj",
+        ),
+
         # Feed-forward down
         MODEL_TENSOR.FFN_DOWN: (
             "gpt_neox.layers.{bid}.mlp.dense_4h_to_h",                # gptneox
diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp
index 106ea133ac0bcc3c8895fff8a876afb5e805d059..47e8d5278acdc7521c537021ba77fb17124d7e6c 100644 (file)
@@ -349,6 +349,7 @@ static const std::map<llm_tensor, const char *> LLM_TENSOR_NAMES = {
     { LLM_TENSOR_FFN_DOWN_EXP,                           "blk.%d.ffn_down.%d" },
     { LLM_TENSOR_FFN_UP_EXP,                             "blk.%d.ffn_up.%d" },
     { LLM_TENSOR_FFN_GATE_EXPS,                          "blk.%d.ffn_gate_exps" },
+    { LLM_TENSOR_FFN_GATE_UP_EXPS,                       "blk.%d.ffn_gate_up_exps" },
     { LLM_TENSOR_FFN_DOWN_EXPS,                          "blk.%d.ffn_down_exps" },
     { LLM_TENSOR_FFN_UP_EXPS,                            "blk.%d.ffn_up_exps" },
     { LLM_TENSOR_ATTN_POST_NORM,                         "blk.%d.post_attention_norm" },
@@ -1004,6 +1005,7 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
                 LLM_TENSOR_FFN_GATE_EXPS,
                 LLM_TENSOR_FFN_DOWN_EXPS,
                 LLM_TENSOR_FFN_UP_EXPS,
+                LLM_TENSOR_FFN_GATE_UP_EXPS,
                 LLM_TENSOR_FFN_GATE_INP_SHEXP,
                 LLM_TENSOR_FFN_GATE_SHEXP,
                 LLM_TENSOR_FFN_DOWN_SHEXP,
@@ -1061,6 +1063,7 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
                 LLM_TENSOR_FFN_GATE_EXPS,
                 LLM_TENSOR_FFN_DOWN_EXPS,
                 LLM_TENSOR_FFN_UP_EXPS,
+                LLM_TENSOR_FFN_GATE_UP_EXPS,
                 LLM_TENSOR_FFN_GATE_INP_SHEXP,
                 LLM_TENSOR_FFN_GATE_SHEXP,
                 LLM_TENSOR_FFN_DOWN_SHEXP,
@@ -1601,6 +1604,7 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
                 LLM_TENSOR_FFN_GATE_EXPS,
                 LLM_TENSOR_FFN_DOWN_EXPS,
                 LLM_TENSOR_FFN_UP_EXPS,
+                LLM_TENSOR_FFN_GATE_UP_EXPS,
                 LLM_TENSOR_FFN_GATE_INP_SHEXP,
                 LLM_TENSOR_FFN_GATE_SHEXP,
                 LLM_TENSOR_FFN_DOWN_SHEXP,
@@ -2685,6 +2689,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_FFN_DOWN_EXPS,              {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
     {LLM_TENSOR_FFN_GATE_EXPS,              {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
     {LLM_TENSOR_FFN_UP_EXPS,                {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
+    {LLM_TENSOR_FFN_GATE_UP_EXPS,           {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
     {LLM_TENSOR_FFN_DOWN_CHEXPS,            {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
     {LLM_TENSOR_FFN_GATE_CHEXPS,            {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
     {LLM_TENSOR_FFN_UP_CHEXPS,              {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
diff --git a/src/llama-arch.h b/src/llama-arch.h
index 42a6ea38f38d4811d1e1428f5be0655586ed09c7..6d1b1df31c0bd1338100b718ff77bd0910e39618 100644 (file)
@@ -373,6 +373,7 @@ enum llm_tensor {
     LLM_TENSOR_FFN_DOWN_EXPS, // merged experts
     LLM_TENSOR_FFN_GATE_EXPS,
     LLM_TENSOR_FFN_UP_EXPS,
+    LLM_TENSOR_FFN_GATE_UP_EXPS,
     LLM_TENSOR_FFN_DOWN_SHEXP,
     LLM_TENSOR_FFN_GATE_SHEXP,
     LLM_TENSOR_FFN_UP_SHEXP,
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index dc58c0826aee1d63dd456f2d94b248205453b216..23a86ea2905fba1fb26354c159d0221bb8b99ff2 100644 (file)
@@ -1165,7 +1165,8 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
                float   w_scale,
          llama_expert_gating_func_type gating_op,
                  int   il,
-         ggml_tensor * probs_in) const {
+         ggml_tensor * probs_in,
+         ggml_tensor * gate_up_exps) const {
     return build_moe_ffn(
         cur,
         gate_inp,  /* gate_inp_b  */ nullptr,
@@ -1181,7 +1182,8 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
         w_scale,
         gating_op,
         il,
-        probs_in
+        probs_in,
+        gate_up_exps
     );
 }
 
@@ -1204,7 +1206,9 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
                float   w_scale,
         llama_expert_gating_func_type gating_op,
                  int   il,
-         ggml_tensor * probs_in) const {
+         ggml_tensor * probs_in,
+         ggml_tensor * gate_up_exps,
+         ggml_tensor * gate_up_exps_b) const {
     const int64_t n_embd   = cur->ne[0];
     const int64_t n_tokens = cur->ne[1];
     const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN
@@ -1343,27 +1347,49 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
         cb(cur, "ffn_moe_weighted", il);
     }
 
-    ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
-    cb(up, "ffn_moe_up", il);
+    ggml_tensor * up = nullptr;
+    ggml_tensor * experts = nullptr;
 
-    if (up_exps_b) {
-        up = ggml_add_id(ctx0, up, up_exps_b, selected_experts);
-        cb(up, "ffn_moe_up_biased", il);
-    }
+    if (gate_up_exps) {
+        // merged gate_up path: one mul_mat_id, then split into gate and up views
+        ggml_tensor * gate_up = build_lora_mm_id(gate_up_exps, cur, selected_experts); // [n_ff*2, n_expert_used, n_tokens]
+        cb(gate_up, "ffn_moe_gate_up", il);
 
-    ggml_tensor * experts = nullptr;
-    if (gate_exps) {
-        cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
+        if (gate_up_exps_b) {
+            gate_up = ggml_add_id(ctx0, gate_up, gate_up_exps_b, selected_experts);
+            cb(gate_up, "ffn_moe_gate_up_biased", il);
+        }
+
+        const int64_t n_ff = gate_up->ne[0] / 2;
+        cur = ggml_view_3d(ctx0, gate_up, n_ff, gate_up->ne[1], gate_up->ne[2], gate_up->nb[1], gate_up->nb[2], 0);
         cb(cur, "ffn_moe_gate", il);
+        up  = ggml_view_3d(ctx0, gate_up, n_ff, gate_up->ne[1], gate_up->ne[2], gate_up->nb[1], gate_up->nb[2], n_ff * gate_up->nb[0]);
+        cb(up, "ffn_moe_up", il);
     } else {
-        cur = up;
-    }
+        // separate gate and up path
+        up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
+        cb(up, "ffn_moe_up", il);
+
+        if (up_exps_b) {
+            up = ggml_add_id(ctx0, up, up_exps_b, selected_experts);
+            cb(up, "ffn_moe_up_biased", il);
+        }
 
-    if (gate_exps_b) {
-        cur = ggml_add_id(ctx0, cur, gate_exps_b, selected_experts);
-        cb(cur, "ffn_moe_gate_biased", il);
+        if (gate_exps) {
+            cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
+            cb(cur, "ffn_moe_gate", il);
+        } else {
+            cur = up;
+        }
+
+        if (gate_exps_b) {
+            cur = ggml_add_id(ctx0, cur, gate_exps_b, selected_experts);
+            cb(cur, "ffn_moe_gate_biased", il);
+        }
     }
 
+    const bool has_gate = gate_exps || gate_up_exps;
+
     switch (type_op) {
         case LLM_FFN_SILU:
             if (gate_exps) {
@@ -1385,7 +1411,9 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
                         break;
                     }
                 }
+            }
 
+            if (has_gate) {
                 cur = ggml_swiglu_split(ctx0, cur, up);
                 cb(cur, "ffn_moe_swiglu", il);
             } else {
@@ -1393,7 +1421,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
                 cb(cur, "ffn_moe_silu", il);
             } break;
         case LLM_FFN_GELU:
-            if (gate_exps) {
+            if (has_gate) {
                 cur = ggml_geglu_split(ctx0, cur, up);
                 cb(cur, "ffn_moe_geglu", il);
             } else {
@@ -1409,7 +1437,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
                 cb(cur, "ffn_moe_swiglu_oai", il);
             } break;
         case LLM_FFN_RELU:
-            if (gate_exps) {
+            if (has_gate) {
                 cur = ggml_reglu_split(ctx0, cur, up);
                 cb(cur, "ffn_moe_reglu", il);
             } else {
@@ -1417,7 +1445,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
                 cb(cur, "ffn_moe_relu", il);
             } break;
         case LLM_FFN_RELU_SQR:
-            if (gate_exps) {
+            if (has_gate) {
                 // TODO: add support for gated squared relu
                 GGML_ABORT("fatal error: gated squared relu not implemented");
             } else {
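
In the merged path above, a single mul_mat_id produces an [n_ff*2, n_expert_used, n_tokens] activation; ggml_view_3d then takes the gate half at offset 0 and the up half at byte offset n_ff * nb[0], with both views sharing the fused tensor's strides. The same split, sketched with ordinary tensor slicing rather than ggml views (illustrative only; torch's last dimension plays the role of ggml's ne[0]):

    import torch

    n_ff, n_expert_used, n_tokens = 1024, 4, 7   # illustrative sizes

    # stand-in for the fused mul_mat_id result, ggml shape [n_ff*2, n_expert_used, n_tokens]
    gate_up = torch.randn(n_tokens, n_expert_used, 2 * n_ff)

    gate = gate_up[..., :n_ff]   # first half, like the view at offset 0
    up   = gate_up[..., n_ff:]   # second half, like the view at offset n_ff * nb[0]

    # SwiGLU as in ggml_swiglu_split: silu(gate) * up
    out = torch.nn.functional.silu(gate) * up
    assert out.shape == (n_tokens, n_expert_used, n_ff)
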
diff --git a/src/llama-graph.h b/src/llama-graph.h
index 22d11a8385706b44877e8af6213b60fc4f86c883..e8f006977d2e4b3d02af4a1ce8a90bfc55744788 100644 (file)
@@ -814,7 +814,8 @@ struct llm_graph_context {
                    float   w_scale,
             llama_expert_gating_func_type gating_op,
                      int   il,
-             ggml_tensor * probs_in = nullptr) const;
+             ggml_tensor * probs_in = nullptr,
+             ggml_tensor * gate_up_exps = nullptr) const;
 
     ggml_tensor * build_moe_ffn(
              ggml_tensor * cur,
@@ -835,7 +836,9 @@ struct llm_graph_context {
                    float   w_scale,
             llama_expert_gating_func_type gating_op,
                      int   il,
-             ggml_tensor * probs_in = nullptr) const;
+             ggml_tensor * probs_in = nullptr,
+             ggml_tensor * gate_up_exps = nullptr,
+             ggml_tensor * gate_up_exps_b = nullptr) const;
 
     //
     // inputs
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index c8ef1b5e7c2906a417f172fb6161f0148c738c65..dabf3b3086ef8827a09615532a090fbdd8d6d9c7 100644 (file)
@@ -2980,6 +2980,15 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
 
         // TODO: move to a separate function
         const auto tn = LLM_TN(arch);
+
+        // helper: try merged gate_up_exps first, fall back to separate gate and up
+        auto create_tensor_gate_up_exps = [&](llama_layer & layer, int bid, int64_t n_embd_, int64_t n_ff_, int64_t n_expert_, int flags) {
+            layer.ffn_gate_up_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_UP_EXPS, "weight", bid), {n_embd_, n_ff_ * 2, n_expert_}, TENSOR_NOT_REQUIRED);
+            if (layer.ffn_gate_up_exps == nullptr) {
+                layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", bid), {n_embd_, n_ff_, n_expert_}, flags);
+                layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", bid), {n_embd_, n_ff_, n_expert_}, flags);
+            }
+        };
         switch (arch) {
             case LLM_ARCH_LLAMA:
             case LLM_ARCH_REFACT:
@@ -5221,9 +5230,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                             }
 
                             // MoE branch
-                            layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
                             layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp,   n_embd, n_expert}, 0);
-                            layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
+                            create_tensor_gate_up_exps(layer, i, n_embd, n_ff_exp, n_expert, 0);
 
                             // Shared expert branch
                             layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
@@ -7425,9 +7433,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         }
 
                         layer.ffn_gate_inp  = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP,  "weight", i), { n_embd, n_expert }, 0);
-                        layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0);
                         layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0);
-                        layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), { n_embd, n_ff_exp, n_expert }, 0);
+                        create_tensor_gate_up_exps(layer, i, n_embd, n_ff_exp, n_expert, 0);
 
                         // Shared experts
                         layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), { n_embd }, 0);
@@ -7491,9 +7498,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         }
 
                         layer.ffn_gate_inp  = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP,  "weight", i), { n_embd, n_expert }, 0);
-                        layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0);
                         layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0);
-                        layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), { n_embd, n_ff_exp, n_expert }, 0);
+                        create_tensor_gate_up_exps(layer, i, n_embd, n_ff_exp, n_expert, 0);
 
                         // Shared experts
                         const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff;
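
Because create_tensor_gate_up_exps loads the fused weight with TENSOR_NOT_REQUIRED and falls back to separate gate/up tensors, a converted GGUF may contain either layout. A small check using gguf-py's reader (the file path is hypothetical):

    from gguf import GGUFReader

    reader = GGUFReader("model-f16.gguf")  # hypothetical converted file
    names = {t.name for t in reader.tensors}

    if "blk.0.ffn_gate_up_exps.weight" in names:
        print("fused gate/up expert tensors")
    elif "blk.0.ffn_gate_exps.weight" in names:
        print("separate gate and up expert tensors")
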
diff --git a/src/llama-model.h b/src/llama-model.h
index 96e407a0b307c6ab916934a82d9e8410a5d29460..d7c3e7d1c1a32f378d6602044ae919981ff9e39f 100644 (file)
@@ -280,14 +280,16 @@ struct llama_layer {
     struct ggml_tensor * ffn_up_enc   = nullptr;
 
     // ff MoE
-    struct ggml_tensor * ffn_gate_inp    = nullptr;
-    struct ggml_tensor * ffn_gate_exps   = nullptr;
-    struct ggml_tensor * ffn_down_exps   = nullptr;
-    struct ggml_tensor * ffn_up_exps     = nullptr;
-    struct ggml_tensor * ffn_gate_inp_b  = nullptr;
-    struct ggml_tensor * ffn_gate_exps_b = nullptr;
-    struct ggml_tensor * ffn_down_exps_b = nullptr;
-    struct ggml_tensor * ffn_up_exps_b   = nullptr;
+    struct ggml_tensor * ffn_gate_inp      = nullptr;
+    struct ggml_tensor * ffn_gate_exps     = nullptr;
+    struct ggml_tensor * ffn_down_exps     = nullptr;
+    struct ggml_tensor * ffn_up_exps       = nullptr;
+    struct ggml_tensor * ffn_gate_up_exps  = nullptr;
+    struct ggml_tensor * ffn_gate_inp_b    = nullptr;
+    struct ggml_tensor * ffn_gate_exps_b   = nullptr;
+    struct ggml_tensor * ffn_down_exps_b   = nullptr;
+    struct ggml_tensor * ffn_up_exps_b     = nullptr;
+    struct ggml_tensor * ffn_gate_up_exps_b = nullptr;
 
     // ff shared expert (shexp)
     struct ggml_tensor * ffn_gate_inp_shexp = nullptr;
diff --git a/src/models/deepseek2.cpp b/src/models/deepseek2.cpp
index b2c1f16060130064c59423740e08ca5aa14d06d0..b608396e50ea73a2806ef4dc1c5d39e23136842e 100644 (file)
@@ -218,7 +218,9 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
                 LLM_FFN_SILU, hparams.expert_weights_norm,
                 hparams.expert_weights_scale, hparams.expert_weights_scale,
                 (llama_expert_gating_func_type) hparams.expert_gating_func,
-                il);
+                il,
+                nullptr,
+                model.layers[il].ffn_gate_up_exps);
             cb(moe_out, "ffn_moe_out", il);
 
             // FFN shared expert
diff --git a/src/models/qwen35moe.cpp b/src/models/qwen35moe.cpp
index 77f18b5aeb8e4abdc14d60b3fd857fd6097342f9..22d708f20622261be4a90ae150f86e8c4c6834d2 100644 (file)
@@ -380,7 +380,8 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_ffn(ggml_tensor * cur, const int
             model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps,
             nullptr,
             n_expert, n_expert_used, LLM_FFN_SILU,
-            true, false, 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il);
+            true, false, 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il,
+            nullptr, model.layers[il].ffn_gate_up_exps);
     cb(moe_out, "ffn_moe_out", il);
 
     // Add shared experts if present - following Qwen3Next reference implementation
diff --git a/src/models/qwen3next.cpp b/src/models/qwen3next.cpp
index 9d3a68cfe5e86c6d7966e7b25b6cbcf224147777..f2621200f2337e95547956213438d4222b59bfbc 100644 (file)
@@ -479,7 +479,8 @@ ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const int
                 model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps,
                 nullptr,
                 n_expert, n_expert_used, LLM_FFN_SILU,
-                true, false, 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il);
+                true, false, 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il,
+                nullptr, model.layers[il].ffn_gate_up_exps);
         cb(moe_out, "ffn_moe_out", il);
 
         // Add shared experts if present - following Qwen3Next reference implementation