]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
model: support MiMo-V2-Flash (#18328)
authorXuan-Son Nguyen <redacted>
Wed, 24 Dec 2025 22:07:08 +0000 (23:07 +0100)
committerGitHub <redacted>
Wed, 24 Dec 2025 22:07:08 +0000 (23:07 +0100)
* mimov2: convert ok

* rename mimov2 --> mimo2

* fix conversion

* runnable not incorrect

* use sink

* add_sliding_window_pattern

* add swa and per-layer n_head_kv

* correct params

* somewhat working

* correct gating func

* nits

* mimo2: wire RMS eps + MoE bias + converter guards

* add co-author

Co-authored-by: Aaryan-Kapoor <redacted>
* use add_rope_freq_base_swa

---------

Co-authored-by: Aaryan Kapoor <redacted>
Co-authored-by: Aaryan-Kapoor <redacted>
convert_hf_to_gguf.py
gguf-py/gguf/constants.py
gguf-py/gguf/tensor_mapping.py
src/CMakeLists.txt
src/llama-arch.cpp
src/llama-arch.h
src/llama-hparams.h
src/llama-model.cpp
src/llama-model.h
src/models/mimo2-iswa.cpp [new file with mode: 0644]
src/models/models.h

index c638e3398665ea914aa4a2edbac39c7dbb4ef210..69abb7367d4882044866d7fbad251dff6aaa4727 100755 (executable)
@@ -7362,6 +7362,90 @@ class MiniMaxM2Model(TextModel):
         return super().modify_tensors(data_torch, name, bid)
 
 
+@ModelBase.register("MiMoV2FlashForCausalLM")
+class MimoV2Model(TextModel):
+    model_arch = gguf.MODEL_ARCH.MIMO2
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+
+        assert self.hparams["swa_head_dim"] == self.hparams["head_dim"]
+        assert self.hparams["swa_num_attention_heads"] == self.hparams["num_attention_heads"]
+        assert self.hparams["swa_v_head_dim"] == self.hparams["v_head_dim"]
+        assert self.hparams["topk_method"] == "noaux_tc"
+
+        n_head_kv = self.hparams["num_key_value_heads"]
+        n_head_kv_swa = self.hparams["swa_num_key_value_heads"]
+        n_head_kv_arr = [n_head_kv_swa if use_swa == 1 else n_head_kv for use_swa in self.hparams["hybrid_layer_pattern"]]
+        self.gguf_writer.add_head_count_kv(n_head_kv_arr)
+
+        self.gguf_writer.add_sliding_window(self.hparams["sliding_window"])
+        self.gguf_writer.add_sliding_window_pattern(self.hparams["hybrid_layer_pattern"])
+        self.gguf_writer.add_rope_freq_base_swa(self.hparams["swa_rope_theta"])
+        self.gguf_writer.add_value_length(self.hparams["v_head_dim"])
+        self.gguf_writer.add_expert_count(self.hparams["n_routed_experts"])
+        self.gguf_writer.add_expert_feed_forward_length(self.hparams["moe_intermediate_size"])
+
+        rope_dim = int(self.hparams["head_dim"] * self.hparams["partial_rotary_factor"])
+        self.gguf_writer.add_rope_dimension_count(rope_dim)
+
+        self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("layernorm_epsilon", 1e-5))
+
+    _experts: list[dict[str, Tensor]] | None = None
+
+    def modify_tensors(self, data_torch, name, bid):
+        if name.endswith("e_score_correction_bias"):
+            name = name.replace("e_score_correction_bias", "e_score_correction.bias")
+
+        if "attention_sink" in name and not name.endswith(".weight"):
+            name += ".weight"
+
+        # TODO: mimo v2 does not indicate the number of next-token-prediction layers, therefore we cannot do the same way as GLM4_MOE
+        if "model.mtp." in name:
+            return []
+
+        # process the experts separately
+        if name.find("mlp.experts") != -1:
+            n_experts = self.hparams["n_routed_experts"]
+            assert bid is not None
+
+            if self._experts is None:
+                self._experts = [{} for _ in range(self.block_count)]
+
+            self._experts[bid][name] = data_torch
+
+            if len(self._experts[bid]) >= n_experts * 3:
+                tensors: list[tuple[str, Tensor]] = []
+
+                # merge the experts into a single 3d tensor
+                for w_name in ["gate_proj", "up_proj", "down_proj"]:
+                    datas: list[Tensor] = []
+
+                    for xid in range(n_experts):
+                        ename_to_retrieve = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
+                        datas.append(self._experts[bid][ename_to_retrieve])
+                        del self._experts[bid][ename_to_retrieve]
+
+                    data_torch = torch.stack(datas, dim=0)
+                    merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
+                    new_name = self.map_tensor_name(merged_name)
+                    tensors.append((new_name, data_torch))
+
+                return tensors
+            else:
+                return []
+        return [(self.map_tensor_name(name), data_torch)]
+
+    def prepare_tensors(self):
+        super().prepare_tensors()
+
+        if self._experts is not None:
+            # flatten `list[dict[str, Tensor]]` into `list[str]`
+            experts = [k for d in self._experts for k in d.keys()]
+            if len(experts) > 0:
+                raise ValueError(f"Unprocessed experts: {experts}")
+
+
 @ModelBase.register("PanguEmbeddedForCausalLM")
 class PanguEmbeddedModel(TextModel):
     model_arch = gguf.MODEL_ARCH.PANGU_EMBED
index baff8547abe101485c6a56229814b4aecdf15907..27578daaf9da7272e00c5f6b396f15daed967bd7 100644 (file)
@@ -449,6 +449,7 @@ class MODEL_ARCH(IntEnum):
     RND1             = auto()
     PANGU_EMBED      = auto()
     MISTRAL3         = auto()
+    MIMO2            = auto()
     LLAMA_EMBED      = auto()
 
 
@@ -845,6 +846,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.RND1:             "rnd1",
     MODEL_ARCH.PANGU_EMBED:      "pangu-embedded",
     MODEL_ARCH.MISTRAL3:         "mistral3",
+    MODEL_ARCH.MIMO2:            "mimo2",
     MODEL_ARCH.LLAMA_EMBED:      "llama-embed",
 }
 
@@ -3198,6 +3200,26 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.FFN_DOWN_EXP,
         MODEL_TENSOR.FFN_UP_EXP,
     ],
+    MODEL_ARCH.MIMO2: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_SINKS,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
+        MODEL_TENSOR.FFN_EXP_PROBS_B,
+    ],
     MODEL_ARCH.LLAMA_EMBED: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
@@ -3217,7 +3239,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.FFN_GATE_EXP,
         MODEL_TENSOR.FFN_DOWN_EXP,
         MODEL_TENSOR.FFN_UP_EXP,
-    ]
+    ],
     # TODO
 }
 
index 276720fcde9ea18578789c110002b03ac952d28d..1690d991f2b65125a0eeae3f045a93862aada32b 100644 (file)
@@ -320,6 +320,7 @@ class TensorNameMap:
 
         MODEL_TENSOR.ATTN_SINKS: (
             "model.layers.{bid}.self_attn.sinks", # openai-moe
+            "model.layers.{bid}.self_attn.attention_sink_bias", # mimov2
         ),
 
         MODEL_TENSOR.ATTN_GATE: (
index 4ca897491607a90eab62e05c2bf4b9d5ee47e8db..1e155534bd16d95904f293420a9607b9913d51b2 100644 (file)
@@ -88,6 +88,7 @@ add_library(llama
             models/llama-iswa.cpp
             models/llama.cpp
             models/mamba.cpp
+            models/mimo2-iswa.cpp
             models/minicpm3.cpp
             models/minimax-m2.cpp
             models/modern-bert.cpp
index 73420d3c9e16d882ba2325e4c0515f4e33f90925..75013d8d33088e98d1b4c24a62d0d81dd2a87fed 100644 (file)
@@ -115,6 +115,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_RND1,             "rnd1"             },
     { LLM_ARCH_PANGU_EMBED,      "pangu-embedded"   },
     { LLM_ARCH_MISTRAL3,         "mistral3"         },
+    { LLM_ARCH_MIMO2,            "mimo2"           },
     { LLM_ARCH_LLAMA_EMBED,      "llama-embed"      },
     { LLM_ARCH_UNKNOWN,          "(unknown)"        },
 };
@@ -2190,6 +2191,27 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
                 LLM_TENSOR_VISEXP_FFN_DOWN,
                 LLM_TENSOR_VISEXP_FFN_UP,
             };
+        case LLM_ARCH_MIMO2:
+            return {
+                LLM_TENSOR_TOKEN_EMBD,
+                LLM_TENSOR_OUTPUT_NORM,
+                LLM_TENSOR_OUTPUT,
+                LLM_TENSOR_ATTN_NORM,
+                LLM_TENSOR_ATTN_Q,
+                LLM_TENSOR_ATTN_K,
+                LLM_TENSOR_ATTN_V,
+                LLM_TENSOR_ATTN_SINKS,
+                LLM_TENSOR_ATTN_OUT,
+                LLM_TENSOR_FFN_NORM,
+                LLM_TENSOR_FFN_GATE,
+                LLM_TENSOR_FFN_DOWN,
+                LLM_TENSOR_FFN_UP,
+                LLM_TENSOR_FFN_GATE_INP,
+                LLM_TENSOR_FFN_GATE_EXPS,
+                LLM_TENSOR_FFN_DOWN_EXPS,
+                LLM_TENSOR_FFN_UP_EXPS,
+                LLM_TENSOR_FFN_EXP_PROBS_B,
+            };
         case LLM_ARCH_GPTJ:
         case LLM_ARCH_UNKNOWN:
             return {
index 433ee4bc18fb9d307fc23ffd6e191dfe945d26e2..27bdedc83c802f0726daebae89a09c34d53ca5ee 100644 (file)
@@ -119,6 +119,7 @@ enum llm_arch {
     LLM_ARCH_RND1,
     LLM_ARCH_PANGU_EMBED,
     LLM_ARCH_MISTRAL3,
+    LLM_ARCH_MIMO2,
     LLM_ARCH_LLAMA_EMBED,
     LLM_ARCH_UNKNOWN,
 };
index f6e95b5d2a6ecbc997f56f822d9fcb16c30794ed..42def73f06f772ea546dc27f59f2b226b19f495e 100644 (file)
@@ -123,10 +123,11 @@ struct llama_hparams {
     llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
     // the size of the sliding window (0 - no SWA)
     uint32_t n_swa = 0;
-    // if swa_layers[il] == true, then layer il is SWA
-    // if swa_layers[il] == false, then layer il is dense (i.e. non-SWA)
+    // if swa_layers[il] == 1, then layer il is SWA
+    // if swa_layers[il] == 0, then layer il is dense (i.e. non-SWA)
     // by default, all layers are dense
-    std::array<bool, LLAMA_MAX_LAYERS> swa_layers;
+    // note: using uint32_t type for compatibility reason
+    std::array<uint32_t, LLAMA_MAX_LAYERS> swa_layers;
 
     // for State Space Models
     uint32_t ssm_d_conv  = 0;
index 9fada915d7d1729e296690ff604abb3c075c166d..69075742c915111f898d791f52509331c69903a4 100644 (file)
@@ -130,6 +130,7 @@ const char * llm_type_name(llm_type type) {
         case LLM_TYPE_230B_A10B:     return "230B.A10B";
         case LLM_TYPE_235B_A22B:     return "235B.A22B";
         case LLM_TYPE_300B_A47B:     return "300B.A47B";
+        case LLM_TYPE_310B_A15B:     return "310B.A15B";
         case LLM_TYPE_355B_A32B:     return "355B.A32B";
         case LLM_TYPE_E2B:           return "E2B";
         case LLM_TYPE_E4B:           return "E4B";
@@ -2339,6 +2340,22 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_MIMO2:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+
+                hparams.swa_type = LLAMA_SWA_TYPE_STANDARD;
+
+                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
+                ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW,   hparams.n_swa);
+                ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA,         hparams.rope_freq_base_train_swa);
+                ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, hparams.swa_layers, hparams.n_layer);
+
+                switch (hparams.n_layer) {
+                    case 48: type = LLM_TYPE_310B_A15B; break;
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+            } break;
         default: throw std::runtime_error("unsupported model architecture");
     }
 
@@ -6648,6 +6665,44 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         layer.ffn_down_shexp     = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP,     "weight", i), { hparams.n_ff_shexp, n_embd }, 0);
                     }
                 } break;
+            case LLM_ARCH_MIMO2:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+                        uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i);
+                        uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i);
+                        uint32_t n_head = hparams.n_head(i);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_v * n_head, n_embd }, 0);
+
+                        layer.attn_norm  = create_tensor(tn(LLM_TENSOR_ATTN_NORM,  "weight", i), {n_embd}, 0);
+                        layer.attn_sinks = create_tensor(tn(LLM_TENSOR_ATTN_SINKS, "weight", i), {n_head}, TENSOR_NOT_REQUIRED);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+                        // non-MoE branch
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, TENSOR_NOT_REQUIRED);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, TENSOR_NOT_REQUIRED);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, TENSOR_NOT_REQUIRED);
+
+                        // MoE branch
+                        int64_t n_ff_exp = hparams.n_ff_exp;
+                        layer.ffn_gate_inp  = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP,  "weight", i), {n_embd, n_expert}, TENSOR_NOT_REQUIRED);
+                        layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp,   n_expert}, TENSOR_NOT_REQUIRED);
+                        layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp,   n_embd, n_expert}, TENSOR_NOT_REQUIRED);
+                        layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {n_embd, n_ff_exp,   n_expert}, TENSOR_NOT_REQUIRED);
+                        layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED);
+                    }
+                } break;
             default:
                 throw std::runtime_error("unknown architecture");
         }
@@ -7710,6 +7765,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
             {
                 llm = std::make_unique<llm_build_mistral3>(*this, params);
             } break;
+        case LLM_ARCH_MIMO2:
+            {
+                llm = std::make_unique<llm_build_mimo2_iswa>(*this, params);
+            } break;
         default:
             GGML_ABORT("fatal error");
     }
@@ -7940,6 +7999,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_PANGU_EMBED:
         case LLM_ARCH_AFMOE:
         case LLM_ARCH_QWEN3NEXT:
+        case LLM_ARCH_MIMO2:
             return LLAMA_ROPE_TYPE_NEOX;
 
         case LLM_ARCH_QWEN2VL:
index 7f560d462fa474049b997a2a662b58dc35b56e0d..9c00eec75f51ed97e2f74c466672fb88fd976db6 100644 (file)
@@ -123,6 +123,7 @@ enum llm_type {
     LLM_TYPE_230B_A10B, // Minimax M2
     LLM_TYPE_235B_A22B,
     LLM_TYPE_300B_A47B, // Ernie MoE big
+    LLM_TYPE_310B_A15B, // /MiMo-V2-Flash
     LLM_TYPE_355B_A32B, // GLM-4.5
     LLM_TYPE_E2B,
     LLM_TYPE_E4B,
diff --git a/src/models/mimo2-iswa.cpp b/src/models/mimo2-iswa.cpp
new file mode 100644 (file)
index 0000000..edc87cc
--- /dev/null
@@ -0,0 +1,123 @@
+
+#include "models.h"
+
+llm_build_mimo2_iswa::llm_build_mimo2_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+    ggml_tensor * cur;
+    ggml_tensor * inpL;
+
+    inpL = build_inp_embd(model.tok_embd);
+
+    ggml_tensor * inp_pos = build_inp_pos();
+    auto * inp_attn = build_attn_inp_kv_iswa();
+    ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+    for (int il = 0; il < n_layer; ++il) {
+        ggml_tensor * inpSA = inpL;
+
+        uint32_t n_head_l    = hparams.n_head(il);
+        uint32_t n_head_kv_l = hparams.n_head_kv(il);
+        const float freq_base_l  = model.get_rope_freq_base(cparams, il);
+        const float freq_scale_l = model.get_rope_freq_scale(cparams, il);
+
+        cur = inpL;
+
+        // self_attention
+        {
+            cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // compute Q and K and RoPE them
+            ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+            cb(Qcur, "Qcur", il);
+
+            ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+            cb(Kcur, "Kcur", il);
+
+            ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+            cb(Vcur, "Vcur", il);
+
+            Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head_l,    n_tokens);
+            Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv_l, n_tokens);
+            Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head_v, n_head_kv_l, n_tokens);
+
+            Qcur = ggml_rope_ext(
+                ctx0, Qcur, inp_pos, nullptr,
+                n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
+                ext_factor, attn_factor, beta_fast, beta_slow
+                );
+
+            Kcur = ggml_rope_ext(
+                ctx0, Kcur, inp_pos, nullptr,
+                n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l,
+                ext_factor, attn_factor, beta_fast, beta_slow
+                );
+
+            cb(Qcur, "Qcur", il);
+            cb(Kcur, "Kcur", il);
+            cb(Vcur, "Vcur", il);
+
+            ggml_tensor * sinks = model.layers[il].attn_sinks;
+
+            cur = build_attn(inp_attn,
+                    model.layers[il].wo, NULL,
+                    Qcur, Kcur, Vcur, nullptr, sinks, nullptr, 1.0f/sqrtf(float(n_embd_head_k)), il);
+        }
+
+        if (il == n_layer - 1 && inp_out_ids) {
+            cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+            inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+        }
+
+        ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+        cb(ffn_inp, "ffn_inp", il);
+
+        cur = build_norm(ffn_inp,
+                model.layers[il].ffn_norm, NULL,
+                LLM_NORM_RMS, il);
+        cb(cur, "ffn_norm", il);
+
+        // feed-forward network
+        if (model.layers[il].ffn_gate_inp == nullptr) {
+            // dense branch
+            cur = build_ffn(cur,
+                    model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                    model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
+                    model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                    NULL,
+                    LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(cur, "ffn_out", il);
+        } else {
+            // MoE branch
+            cur = build_moe_ffn(cur, model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps,
+                                model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps,
+                                model.layers[il].ffn_exp_probs_b, n_expert, n_expert_used, LLM_FFN_SILU, true, false,
+                                0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID, il);
+            cb(cur, "ffn_moe_out", il);
+        }
+
+        cur = ggml_add(ctx0, cur, ffn_inp);
+
+        cur = build_cvec(cur, il);
+        cb(cur, "l_out", il);
+
+        // input for next layer
+        inpL = cur;
+    }
+
+    cur = inpL;
+
+    cur = build_norm(cur,
+            model.output_norm, NULL,
+            LLM_NORM_RMS, -1);
+
+    cb(cur, "result_norm", -1);
+    res->t_embd = cur;
+
+    // lm_head
+    cur = build_lora_mm(model.output, cur);
+
+    cb(cur, "result_output", -1);
+    res->t_logits = cur;
+
+    ggml_build_forward_expand(gf, cur);
+}
index fca505b30ab98acbb639c7612947e3e30dac43b9..dd0e286eda2386b450dc73697c99a6995c326753 100644 (file)
@@ -316,6 +316,10 @@ struct llm_build_mamba : public llm_graph_context_mamba {
     llm_build_mamba(const llama_model & model, const llm_graph_params & params);
 };
 
+struct llm_build_mimo2_iswa : public llm_graph_context {
+    llm_build_mimo2_iswa(const llama_model & model, const llm_graph_params & params);
+};
+
 struct llm_build_minicpm3 : public llm_graph_context {
     llm_build_minicpm3(const llama_model & model, const llm_graph_params & params);
 };