]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
model : add dots.llm1 architecture support (#14044) (#14118)
authorMikko Juola <redacted>
Sun, 15 Jun 2025 07:52:06 +0000 (00:52 -0700)
committerGitHub <redacted>
Sun, 15 Jun 2025 07:52:06 +0000 (09:52 +0200)
Adds:

* Dots1Model to convert_hf_to_gguf.py

* Computation graph code to llama-model.cpp

* Chat template to llama-chat.cpp to detect this model's template.

---

The model is called "dots.llm1" (I decided to shorten it to dots1 or
DOTS1 in the code generally) architecture.

The only models that exist as of writing of this commit that follow this
architecture are "dots.llm1.inst" and "dots.llm1.base" from here:

* https://huggingface.co/rednote-hilab/dots.llm1.inst

* https://huggingface.co/rednote-hilab/dots.llm1.base

The model architecture is a combination of Qwen and Deepseek parts, as
seen here:

https://github.com/huggingface/transformers/blob/ffe12627b4e84489d2ab91dd0ec00614855edc79/src/transformers/models/dots1/modular_dots1.py

convert_hf_to_gguf.py
gguf-py/gguf/constants.py
gguf-py/gguf/tensor_mapping.py
src/llama-arch.cpp
src/llama-arch.h
src/llama-chat.cpp
src/llama-chat.h
src/llama-model.cpp
src/llama-model.h

index 173a103badc6019658db45312d469bad6d4ad31c..cff72c85fab6971f8916366f8ad35bef78a99142 100755 (executable)
@@ -5262,6 +5262,34 @@ class DeepseekV2Model(TextModel):
                 raise ValueError(f"Unprocessed experts: {experts}")
 
 
+@ModelBase.register("Dots1ForCausalLM")
+class Dots1Model(Qwen2MoeModel):
+    model_arch = gguf.MODEL_ARCH.DOTS1
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.hparams["num_experts"] = self.hparams["n_routed_experts"]
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_leading_dense_block_count(self.hparams["first_k_dense_replace"])
+        self.gguf_writer.add_expert_shared_count(self.hparams["n_shared_experts"])
+        self.gguf_writer.add_expert_weights_scale(self.hparams["routed_scaling_factor"])
+        self.gguf_writer.add_expert_weights_norm(self.hparams["norm_topk_prob"])
+
+        if self.hparams["scoring_func"] == "noaux_tc":
+            self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
+        else:
+            raise ValueError(f"Unsupported scoring_func value: {self.hparams['scoring_func']}")
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
+        if name.endswith("e_score_correction_bias"):
+            name = name.replace("e_score_correction_bias", "e_score_correction.bias")
+        if "shared_experts" in name:
+            return [(self.map_tensor_name(name), data_torch)]
+        return super().modify_tensors(data_torch, name, bid)
+
+
 @ModelBase.register("PLMForCausalLM")
 class PLMModel(TextModel):
     model_arch = gguf.MODEL_ARCH.PLM
index 3ee2b2064e1b429b74a0b8e48a38092d3c84f0ef..8de2f7a531967b325a4237879f353ab6e0477ad8 100644 (file)
@@ -343,6 +343,7 @@ class MODEL_ARCH(IntEnum):
     WAVTOKENIZER_DEC = auto()
     PLM              = auto()
     BAILINGMOE       = auto()
+    DOTS1            = auto()
 
 
 class VISION_PROJECTOR_TYPE(IntEnum):
@@ -623,6 +624,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.WAVTOKENIZER_DEC: "wavtokenizer-dec",
     MODEL_ARCH.PLM:              "plm",
     MODEL_ARCH.BAILINGMOE:       "bailingmoe",
+    MODEL_ARCH.DOTS1:            "dots1"
 }
 
 VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
@@ -2044,6 +2046,30 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.FFN_DOWN_SHEXP,
         MODEL_TENSOR.FFN_UP_SHEXP,
     ],
+    MODEL_ARCH.DOTS1: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_Q_NORM,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_K_NORM,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_EXP_PROBS_B,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_GATE_SHEXP,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        MODEL_TENSOR.FFN_DOWN_SHEXP,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.FFN_UP_EXP,
+        MODEL_TENSOR.FFN_UP_SHEXP,
+    ],
     # TODO
 }
 
index 439fc1afeeb0cb9cbd53b2912959b2f487f20581..5e3f01754bf0794f9b7b7f79b127aa066ef95d43 100644 (file)
@@ -305,7 +305,7 @@ class TensorNameMap:
         ),
 
         MODEL_TENSOR.FFN_EXP_PROBS_B: (
-            "model.layers.{bid}.mlp.gate.e_score_correction", # deepseek-v3
+            "model.layers.{bid}.mlp.gate.e_score_correction", # deepseek-v3 dots1
         ),
 
         # Feed-forward up
index 43fa60a8070b765f09dff5ef73124bb342cb0d52..f8f76eedd4fa687b6c31eddf4f4bf3a5512d9659 100644 (file)
@@ -72,6 +72,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" },
     { LLM_ARCH_PLM,              "plm"              },
     { LLM_ARCH_BAILINGMOE,       "bailingmoe"       },
+    { LLM_ARCH_DOTS1,            "dots1"            },
     { LLM_ARCH_UNKNOWN,          "(unknown)"        },
 };
 
@@ -1555,6 +1556,34 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP_SHEXP,       "blk.%d.ffn_up_shexp" },
         },
     },
+    {
+        LLM_ARCH_DOTS1,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,         "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,        "output_norm" },
+            { LLM_TENSOR_OUTPUT,             "output" },
+            { LLM_TENSOR_ATTN_NORM,          "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,             "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_Q_NORM,        "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K,             "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_K_NORM,        "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_ATTN_V,             "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,           "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM,           "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE,           "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_UP,             "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_DOWN,           "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_GATE_INP,       "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_GATE_EXPS,      "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS,      "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS,        "blk.%d.ffn_up_exps" },
+            { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" },
+            { LLM_TENSOR_FFN_GATE_SHEXP,     "blk.%d.ffn_gate_shexp" },
+            { LLM_TENSOR_FFN_DOWN_SHEXP,     "blk.%d.ffn_down_shexp" },
+            { LLM_TENSOR_FFN_UP_SHEXP,       "blk.%d.ffn_up_shexp" },
+            { LLM_TENSOR_FFN_EXP_PROBS_B,    "blk.%d.exp_probs_b" },
+        }
+    },
     {
         LLM_ARCH_UNKNOWN,
         {
index f3825528aefdb93cf9f571fb479de36fe5aadee4..18f6d6b94f13716119d3f069e28c24306dd1b4a4 100644 (file)
@@ -76,6 +76,7 @@ enum llm_arch {
     LLM_ARCH_WAVTOKENIZER_DEC,
     LLM_ARCH_PLM,
     LLM_ARCH_BAILINGMOE,
+    LLM_ARCH_DOTS1,
     LLM_ARCH_UNKNOWN,
 };
 
index d12743e6b9a0cf1bc12ebd698b71cef0a92db62b..bc4fa05a74ef470796099125835dcc32b0ad52d9 100644 (file)
@@ -183,6 +183,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
         return LLM_CHAT_TEMPLATE_BAILING;
     } else if (tmpl_contains("<|header_start|>") && tmpl_contains("<|header_end|>")) {
         return LLM_CHAT_TEMPLATE_LLAMA4;
+    } else if (tmpl_contains("<|endofuserprompt|>")) {
+        return LLM_CHAT_TEMPLATE_DOTS1;
     }
     return LLM_CHAT_TEMPLATE_UNKNOWN;
 }
@@ -643,6 +645,21 @@ int32_t llm_chat_apply_template(
         if (add_ass) {
             ss << "Assistant:";
         }
+    } else if (tmpl == LLM_CHAT_TEMPLATE_DOTS1) {
+        // dots.llm1.inst (DOTS1)
+        for (auto message : chat) {
+            std::string role(message->role);
+            if (role == "system") {
+                ss << "<|system|>" << message->content << "<|endofsystem|>";
+            } else if (role == "user") {
+                ss << "<|userprompt|>" << message->content << "<|endofuserprompt|>";
+            } else {
+                ss << "<|response|>" << message->content << "<|endofresponse|>";
+            }
+        }
+        if (add_ass) {
+            ss << "<|response|>";
+        }
     } else {
         // template not supported
         return -1;
index db24ade21e2ad76671732d3c6b625ea76357120e..38800010ae48b5da3474fb47ee4c490d693db68f 100644 (file)
@@ -43,6 +43,7 @@ enum llm_chat_template {
     LLM_CHAT_TEMPLATE_BAILING,
     LLM_CHAT_TEMPLATE_LLAMA4,
     LLM_CHAT_TEMPLATE_SMOLVLM,
+    LLM_CHAT_TEMPLATE_DOTS1,
     LLM_CHAT_TEMPLATE_UNKNOWN,
 };
 
index c64bf9de939f4395335e01f9fd0d1a766ab3993c..fdd5fefd6e77863d60510d9ffc68497c7122a9b8 100644 (file)
@@ -80,6 +80,7 @@ const char * llm_type_name(llm_type type) {
         case LLM_TYPE_40B:           return "40B";
         case LLM_TYPE_65B:           return "65B";
         case LLM_TYPE_70B:           return "70B";
+        case LLM_TYPE_142B:          return "142B";
         case LLM_TYPE_236B:          return "236B";
         case LLM_TYPE_290B:          return "290B";
         case LLM_TYPE_314B:          return "314B";
@@ -1444,6 +1445,20 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_DOTS1:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT,   hparams.n_layer_dense_lead);
+                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH,  hparams.n_ff_exp);
+                ml.get_key(LLM_KV_EXPERT_SHARED_COUNT,         hparams.n_expert_shared);
+                ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE,        hparams.expert_weights_scale);
+                ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM,         hparams.expert_weights_norm, false);
+                ml.get_key(LLM_KV_EXPERT_GATING_FUNC,          hparams.expert_gating_func, false);
+                switch (hparams.n_layer) {
+                    case 62: type = LLM_TYPE_142B; break;
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+            } break;
         default: throw std::runtime_error("unsupported model architecture");
     }
 
@@ -4123,6 +4138,58 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         layer.ffn_up_shexp   = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP,   "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
                     }
                 } break;
+            case LLM_ARCH_DOTS1:
+                {
+                    const int64_t n_ff_exp        = hparams.n_ff_exp;
+                    const int64_t n_expert_shared = hparams.n_expert_shared;
+
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, 0);
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
+
+                        layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
+                        layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+                        if (i < (int) hparams.n_layer_dense_lead) {
+                            layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                            layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                            layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                        } else {
+                            layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
+                            layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED);
+
+                            if (n_expert == 0) {
+                                throw std::runtime_error("n_expert must be > 0");
+                            }
+                            if (n_expert_used == 0) {
+                                throw std::runtime_error("n_expert_used must be > 0");
+                            }
+
+                            // MoE branch
+                            layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
+                            layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp,   n_embd, n_expert}, 0);
+                            layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
+
+                            // Shared expert branch
+                            layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
+                            layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {        n_ff_exp * n_expert_shared, n_embd}, 0);
+                            layer.ffn_up_shexp   = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP,   "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0);
+                        }
+                    }
+                } break;
             default:
                 throw std::runtime_error("unknown architecture");
         }
@@ -13194,6 +13261,156 @@ struct llm_build_bailingmoe : public llm_graph_context {
     }
 };
 
+struct llm_build_dots1 : public llm_graph_context {
+    llm_build_dots1(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified();
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self_attention
+            {
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+                Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
+                cb(Qcur, "Qcur_normed", il);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, Qcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+
+                Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
+                cb(Kcur, "Kcur_normed", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, Kcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // MoE branch
+            cur = build_norm(ffn_inp,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            if ((uint32_t) il < hparams.n_layer_dense_lead) {
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+            } else {
+                ggml_tensor * moe_out =
+                    build_moe_ffn(cur,
+                            model.layers[il].ffn_gate_inp,
+                            model.layers[il].ffn_up_exps,
+                            model.layers[il].ffn_gate_exps,
+                            model.layers[il].ffn_down_exps,
+                            model.layers[il].ffn_exp_probs_b,
+                            n_expert, n_expert_used,
+                            LLM_FFN_SILU, hparams.expert_weights_norm,
+                            true, hparams.expert_weights_scale,
+                            (llama_expert_gating_func_type) hparams.expert_gating_func,
+                            il);
+                cb(moe_out, "ffn_moe_out", il);
+
+                {
+                    ggml_tensor * ffn_shexp = build_ffn(cur,
+                            model.layers[il].ffn_up_shexp,   NULL, NULL,
+                            model.layers[il].ffn_gate_shexp, NULL, NULL,
+                            model.layers[il].ffn_down_shexp, NULL, NULL,
+                            NULL,
+                            LLM_FFN_SILU, LLM_FFN_PAR, il);
+                    cb(ffn_shexp, "ffn_shexp", il);
+
+                    cur = ggml_add(ctx0, moe_out, ffn_shexp);
+                    cb(cur, "ffn_out", il);
+                }
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
 llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const {
     llama_memory_i * res;
 
@@ -13532,6 +13749,10 @@ llm_graph_result_ptr llama_model::build_graph(
             {
                 llm = std::make_unique<llm_build_bailingmoe>(*this, params, gf);
             } break;
+        case LLM_ARCH_DOTS1:
+            {
+                llm = std::make_unique<llm_build_dots1>(*this, params, gf);
+            } break;
         default:
             GGML_ABORT("fatal error");
     }
@@ -13714,6 +13935,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_NEMOTRON:
         case LLM_ARCH_EXAONE:
         case LLM_ARCH_MINICPM3:
+        case LLM_ARCH_DOTS1:
             return LLAMA_ROPE_TYPE_NEOX;
 
         case LLM_ARCH_QWEN2VL:
index 18b714620bbcf899e23828eb9653fcb80b728e49..06e6c687943cc23e615bd1f49f773347d2b4247b 100644 (file)
@@ -73,6 +73,7 @@ enum llm_type {
     LLM_TYPE_40B,
     LLM_TYPE_65B,
     LLM_TYPE_70B,
+    LLM_TYPE_142B,
     LLM_TYPE_236B,
     LLM_TYPE_290B,
     LLM_TYPE_314B,