]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
model : add support for SmallThinker series (#14898)
authorDongliang Wei <redacted>
Mon, 28 Jul 2025 11:47:00 +0000 (19:47 +0800)
committerGitHub <redacted>
Mon, 28 Jul 2025 11:47:00 +0000 (13:47 +0200)
* support smallthinker

* support 20b softmax, 4b no sliding window

* new build_moe_ffn_from_probs, and can run 4b

* fix 4b rope bug

* fix python type check

* remove is_moe judge

* remove set_dense_start_swa_pattern function and modify set_swa_pattern function

* trim trailing whitespace

* remove get_vocab_base of SmallThinkerModel in convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <redacted>
* better whitespace

Apply suggestions from code review

Co-authored-by: Sigbjørn Skjæret <redacted>
* use GGML_ASSERT for expert count validation

Co-authored-by: Sigbjørn Skjæret <redacted>
* Improve null pointer check for probs

Co-authored-by: Sigbjørn Skjæret <redacted>
* use template parameter for SWA attention logic

* better whitespace

Co-authored-by: Georgi Gerganov <redacted>
* move the creation of inp_out_ids before the layer loop

* remove redundant judge for probs

---------

Co-authored-by: Sigbjørn Skjæret <redacted>
Co-authored-by: Georgi Gerganov <redacted>
convert_hf_to_gguf.py
gguf-py/gguf/constants.py
gguf-py/gguf/tensor_mapping.py
src/llama-arch.cpp
src/llama-arch.h
src/llama-graph.cpp
src/llama-graph.h
src/llama-hparams.cpp
src/llama-hparams.h
src/llama-model.cpp

index 2c51b6c42d4c595f779ad22140abd73526af7956..9461206cf1967fc059926e3c31dd06255f73c3d1 100755 (executable)
@@ -7589,6 +7589,88 @@ class LFM2Model(TextModel):
         return [(self.map_tensor_name(name), data_torch)]
 
 
+@ModelBase.register("SmallThinkerForCausalLM")
+class SmallThinkerModel(TextModel):
+    model_arch = gguf.MODEL_ARCH.SMALLTHINKER
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        if (n_experts := self.hparams.get("num_experts", self.hparams.get("moe_num_primary_experts"))) is not None:
+            self.gguf_writer.add_expert_count(n_experts)
+        if (n_experts_used := self.hparams.get("num_experts_per_tok", self.hparams.get("moe_num_active_primary_experts"))) is not None:
+            self.gguf_writer.add_expert_used_count(n_experts_used)
+        if (moe_intermediate_size := self.hparams.get("moe_ffn_hidden_size")) is not None:
+            self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
+            self.gguf_writer.add_feed_forward_length(moe_intermediate_size)
+            logger.info(f"gguf: expert feed forward length = {moe_intermediate_size}")
+        if (self.hparams.get('moe_primary_router_apply_softmax')):
+            self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
+        else:
+            self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
+        # YaRN is not enabled by default
+        # To enable it, please refer to this guide: https://huggingface.co/Qwen/Qwen3-30B-A3B#processing-long-texts
+        rope_scaling = self.hparams.get("rope_scaling") or {}
+        if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling:
+            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
+            self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
+            self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
+
+        sliding_window_layout = self.hparams.get("sliding_window_layout")
+        if sliding_window_layout:
+            for i in sliding_window_layout:
+                if i != 0:
+                    sliding_window = self.hparams.get("sliding_window_size")
+                    if sliding_window:
+                        self.gguf_writer.add_sliding_window(sliding_window)
+                    break
+
+    _experts: list[dict[str, Tensor]] | None = None
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # process the experts separately
+        if name.find("experts") != -1:
+            n_experts = self.hparams.get("num_experts", self.hparams.get("moe_num_primary_experts"))
+            assert bid is not None
+
+            if self._experts is None:
+                self._experts = [{} for _ in range(self.block_count)]
+
+            self._experts[bid][name] = data_torch
+
+            if len(self._experts[bid]) >= n_experts * 3:
+                tensors: list[tuple[str, Tensor]] = []
+
+                # merge the experts into a single 3d tensor
+                for w_name in ["down", "gate", "up"]:
+                    datas: list[Tensor] = []
+
+                    for xid in range(n_experts):
+                        ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.{w_name}.weight"
+                        datas.append(self._experts[bid][ename])
+                        del self._experts[bid][ename]
+
+                    data_torch = torch.stack(datas, dim=0)
+
+                    merged_name = f"model.layers.{bid}.block_sparse_moe.experts.{w_name}.weight"
+
+                    new_name = self.map_tensor_name(merged_name)
+
+                    tensors.append((new_name, data_torch))
+                return tensors
+            else:
+                return []
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+    def prepare_tensors(self):
+        super().prepare_tensors()
+
+        if self._experts is not None:
+            # flatten `list[dict[str, Tensor]]` into `list[str]`
+            experts = [k for d in self._experts for k in d.keys()]
+            if len(experts) > 0:
+                raise ValueError(f"Unprocessed experts: {experts}")
+
 ###### CONVERSION LOGIC ######
 
 
index 680210db7e9d5ce616e236ebe33df4754ee5c55d..7ac1a1971938b0f214e4600658a292c345cf873e 100644 (file)
@@ -376,6 +376,7 @@ class MODEL_ARCH(IntEnum):
     SMOLLM3          = auto()
     LFM2             = auto()
     DREAM            = auto()
+    SMALLTHINKER     = auto()
 
 
 class VISION_PROJECTOR_TYPE(IntEnum):
@@ -695,6 +696,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.SMOLLM3:          "smollm3",
     MODEL_ARCH.LFM2:             "lfm2",
     MODEL_ARCH.DREAM:            "dream",
+    MODEL_ARCH.SMALLTHINKER:     "smallthinker",
 }
 
 VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
@@ -2483,6 +2485,24 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.ATTN_V,
         MODEL_TENSOR.ATTN_OUT,
     ],
+    MODEL_ARCH.SMALLTHINKER: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
+    ],
     # TODO
 }
 
index 7fbda422f0fe93d9f80883ac9d977c4e686941cf..bfd4fd37a3f6868a44a5271d607dc7726c5c9bb3 100644 (file)
@@ -317,6 +317,7 @@ class TensorNameMap:
             "model.layers.{bid}.feed_forward.router",           # llama4 jamba
             "encoder.layers.{bid}.mlp.router.layer",            # nomic-bert-moe
             "model.layers.{bid}.mlp.gate.wg",                   # hunyuan
+            "model.layers.{bid}.block_sparse_moe.primary_router", # smallthinker
         ),
 
         MODEL_TENSOR.FFN_GATE_INP_SHEXP: (
@@ -362,6 +363,7 @@ class TensorNameMap:
             "transformer.h.{bid}.mlp.c_fc_1",                         # exaone
             "model.layers.{bid}.feed_forward.up_proj",                # llama4 jamba granite-hybrid
             "transformer_encoder.{bid}.ffn.w12",                      # neobert
+            "model.layers.{bid}.block_sparse_moe.up",                 # smallthinker
         ),
 
         MODEL_TENSOR.FFN_UP_EXP: (
@@ -372,6 +374,7 @@ class TensorNameMap:
             "model.layers.{bid}.block_sparse_moe.experts.w3",       # phimoe (merged)
             "model.layers.{bid}.feed_forward.experts.up_proj",      # llama4
             "encoder.layers.{bid}.mlp.experts.mlp.w1",              # nomic-bert-moe
+            "model.layers.{bid}.block_sparse_moe.experts.up", # smallthinker
         ),
 
         MODEL_TENSOR.FFN_UP_SHEXP: (
@@ -401,6 +404,7 @@ class TensorNameMap:
             "model.layers.{bid}.residual_mlp.w1",         # arctic
             "transformer.h.{bid}.mlp.c_fc_0",             # exaone
             "model.layers.{bid}.feed_forward.gate_proj",  # llama4 jamba granite-hybrid
+            "model.layers.{bid}.block_sparse_moe.gate",   # smallthinker
         ),
 
         MODEL_TENSOR.FFN_GATE_EXP: (
@@ -410,6 +414,7 @@ class TensorNameMap:
             "model.layers.{bid}.mlp.experts.gate_proj",                 # qwen2moe olmoe (merged) ernie4.5-moe
             "model.layers.{bid}.block_sparse_moe.experts.w1",           # phimoe (merged)
             "model.layers.{bid}.feed_forward.experts.gate_proj",        # llama4
+            "model.layers.{bid}.block_sparse_moe.experts.gate",         # smallthinker
         ),
 
         MODEL_TENSOR.FFN_GATE_SHEXP: (
@@ -448,6 +453,7 @@ class TensorNameMap:
             "model.layers.h.{bid}.mlp.c_proj",                        # exaone
             "model.layers.{bid}.feed_forward.down_proj",              # llama4 jamba granite-hybrid
             "transformer_encoder.{bid}.ffn.w3",                       # neobert
+            "model.layers.{bid}.block_sparse_moe.down",               # smallthinker
         ),
 
         MODEL_TENSOR.FFN_DOWN_EXP: (
@@ -459,6 +465,7 @@ class TensorNameMap:
             "model.layers.{bid}.block_sparse_moe.experts.w2",       # phimoe (merged)
             "model.layers.{bid}.feed_forward.experts.down_proj",    # llama4
             "encoder.layers.{bid}.mlp.experts.mlp.w2",              # nomic-bert-moe
+            "model.layers.{bid}.block_sparse_moe.experts.down",     # smallthinker
         ),
 
         MODEL_TENSOR.FFN_DOWN_SHEXP: (
index 062a99776781f6146783b4670fe14df9ebde32d0..dbf977443ae85f2870ff2d7eb134ca1b615385e1 100644 (file)
@@ -88,6 +88,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_SMOLLM3,          "smollm3"          },
     { LLM_ARCH_LFM2,             "lfm2"             },
     { LLM_ARCH_DREAM,            "dream"            },
+    { LLM_ARCH_SMALLTHINKER,     "smallthinker"     },
     { LLM_ARCH_UNKNOWN,          "(unknown)"        },
 };
 
@@ -1933,6 +1934,27 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_TOKEN_EMBD_NORM,   "token_embd_norm" },
         }
     },
+    {
+        LLM_ARCH_SMALLTHINKER,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,         "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,        "output_norm" },
+            { LLM_TENSOR_OUTPUT,             "output" },
+            { LLM_TENSOR_ATTN_NORM,          "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,             "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,             "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,             "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,           "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM,           "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE,           "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,           "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,             "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_GATE_INP,       "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_GATE_EXPS,      "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS,      "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS,        "blk.%d.ffn_up_exps" }
+        },
+    },
     {
         LLM_ARCH_DREAM,
         {
index d09b7d7810b03a2cebb5abc463ca9744805fa100..8267a8d3aa49128c932dcf44098ff544a1c3a3f2 100644 (file)
@@ -92,6 +92,7 @@ enum llm_arch {
     LLM_ARCH_SMOLLM3,
     LLM_ARCH_LFM2,
     LLM_ARCH_DREAM,
+    LLM_ARCH_SMALLTHINKER,
     LLM_ARCH_UNKNOWN,
 };
 
index b63a41053b488b23025b73df495b1f2f2e8c43d3..1b9cc4aec0632f93b97377e114d6ccd60f17652d 100644 (file)
@@ -938,6 +938,100 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
     return moe_out;
 }
 
+ggml_tensor * llm_graph_context::build_moe_ffn_from_probs(
+         ggml_tensor * cur,
+         ggml_tensor * probs,
+         ggml_tensor * up_exps,
+         ggml_tensor * gate_exps,
+         ggml_tensor * down_exps,
+         ggml_tensor * exp_probs_b,
+             int64_t   n_expert,
+             int64_t   n_expert_used,
+             llama_expert_gating_func_type gating_op,
+                 int   il) const {
+    const int64_t n_embd   = cur->ne[0];
+    const int64_t n_tokens = cur->ne[1];
+
+    // add experts selection bias - introduced in DeepSeek V3
+    // leave probs unbiased as it's later used to get expert weights
+    ggml_tensor * selection_probs = probs;
+    if (exp_probs_b != nullptr) {
+        selection_probs = ggml_add(ctx0, probs, exp_probs_b);
+        cb(selection_probs, "ffn_moe_probs_biased", il);
+    }
+
+    // select experts
+    ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
+    cb(selected_experts->src[0], "ffn_moe_argsort", il);
+    cb(selected_experts, "ffn_moe_topk", il);
+
+    ggml_tensor * weights = ggml_get_rows(ctx0,
+            ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
+    cb(weights, "ffn_moe_weights", il);
+
+    weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
+     if (gating_op == LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX) {
+        weights = ggml_soft_max(ctx0, weights);
+    } else {
+        weights = ggml_sigmoid(ctx0, weights);
+        ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights); // [1, n_tokens]
+        cb(weights_sum, "ffn_moe_weights_sum", il);
+
+        weights = ggml_div(ctx0, weights, weights_sum); // [n_expert_used, n_tokens]
+        cb(weights, "ffn_moe_weights_norm", il);
+    }
+
+    weights = ggml_reshape_3d(ctx0, weights, 1, n_expert_used, n_tokens);
+
+    cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
+
+    ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
+    cb(up, "ffn_moe_up", il);
+
+    ggml_tensor * experts = nullptr;
+    cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
+    cb(cur, "ffn_moe_gate", il);
+
+    cur = ggml_reglu_split(ctx0, cur, up);
+    cb(cur, "ffn_moe_reglu", il);
+
+    experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
+    cb(experts, "ffn_moe_down", il);
+
+    experts = ggml_mul(ctx0, experts, weights);
+    cb(cur, "ffn_moe_weighted", il);
+
+    ggml_tensor * cur_experts[LLAMA_MAX_EXPERTS] = { nullptr };
+
+    assert(n_expert_used > 0);
+
+    // order the views before the adds
+    for (uint32_t i = 0; i < hparams.n_expert_used; ++i) {
+        cur_experts[i] = ggml_view_2d(ctx0, experts, n_embd, n_tokens, experts->nb[2], i*experts->nb[1]);
+
+        ggml_build_forward_expand(gf, cur_experts[i]);
+    }
+
+    // aggregate experts
+    // note: here we explicitly use hparams.n_expert_used instead of n_expert_used
+    //       to avoid potentially a large number of add nodes during warmup
+    //       ref: https://github.com/ggml-org/llama.cpp/pull/14753
+    ggml_tensor * moe_out = cur_experts[0];
+
+    for (uint32_t i = 1; i < hparams.n_expert_used; ++i) {
+        moe_out = ggml_add(ctx0, moe_out, cur_experts[i]);
+    }
+
+    if (n_expert_used == 1) {
+        // avoid returning a non-contiguous tensor
+        moe_out = ggml_cont(ctx0, moe_out);
+    }
+
+    cb(moe_out, "ffn_moe_out", il);
+
+    return moe_out;
+}
+
 // input embeddings with optional lora
 ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
     const int64_t n_embd = hparams.n_embd;
index a28a8c4bddad838514e0fe301f982ab0d4b9890c..d4d565754c5003b77264fa9614397ac4e5ff0c6e 100644 (file)
@@ -625,6 +625,18 @@ struct llm_graph_context {
             llama_expert_gating_func_type gating_op,
                      int   il) const;
 
+    ggml_tensor * build_moe_ffn_from_probs(
+             ggml_tensor * cur,
+             ggml_tensor * probs,
+             ggml_tensor * up_exps,
+             ggml_tensor * gate_exps,
+             ggml_tensor * down_exps,
+             ggml_tensor * exp_probs_b,
+                 int64_t   n_expert,
+                 int64_t   n_expert_used,
+            llama_expert_gating_func_type gating_op,
+                     int   il) const;
+
     //
     // inputs
     //
index c6c67d26f9392cd4c81e50160e52690ec96b7f10..7a06368dcda68e1f133fae49f59008db4fb8af07 100644 (file)
@@ -2,9 +2,15 @@
 
 #include "ggml.h"
 
-void llama_hparams::set_swa_pattern(uint32_t n_pattern) {
-    for (uint32_t il = 0; il < n_layer; ++il) {
-        swa_layers[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1));
+void llama_hparams::set_swa_pattern(uint32_t n_pattern, bool dense_first) {
+    if (dense_first) {
+        for (uint32_t il = 0; il < n_layer; ++il) {
+            swa_layers[il] = n_pattern == 0 || (il % n_pattern != 0);
+        }
+    } else {
+        for (uint32_t il = 0; il < n_layer; ++il) {
+            swa_layers[il] = n_pattern == 0 || (il % n_pattern < (n_pattern - 1));
+        }
     }
 }
 
index ec7fd6a42bf54d2006da127e375b200046c851fb..8b7e2a1130755ac418a326e1c8a6ab5f16b1f7a7 100644 (file)
@@ -140,7 +140,7 @@ struct llama_hparams {
     // for Classifiers
     uint32_t n_cls_out = 1;
 
-    // llama4
+    // llama4 smallthinker
     uint32_t n_moe_layer_step        = 0;
     uint32_t n_no_rope_layer_step    = 4;
     uint32_t n_attn_temp_floor_scale = 8192;
@@ -161,9 +161,10 @@ struct llama_hparams {
     enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE;
 
     // this value n_pattern means that every nth layer is dense (i.e. non-SWA)
+    // dense_first means whether the pattern is start with a dense layer
     // note that if n_pattern == 0, all layers are SWA
     //           if n_pattern == 1, all layers are dense
-    // example: n_pattern = 3
+    // example 1: n_pattern = 3, dense_first = false
     //   il == 0: swa
     //   il == 1: swa
     //   il == 2: dense
@@ -172,7 +173,13 @@ struct llama_hparams {
     //   il == 5: dense
     //   il == 6: swa
     //   etc ...
-    void set_swa_pattern(uint32_t n_pattern);
+    // example 2: n_pattern = 2, dense_first = true
+    //   il == 0: dense
+    //   il == 1: swa
+    //   il == 2: dense
+    //   il == 3: swa
+    //   etc ...
+    void set_swa_pattern(uint32_t n_pattern, bool dense_first = false);
 
     // return true if one of the layers is SWA
     bool is_swa_any() const;
index 71f89e19072ded81e794f7c781ec0f077719475e..e3aa9e6f91af92b7412f09c566f5c9eb8cf5756d 100644 (file)
@@ -1768,6 +1768,29 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default:   type = LLM_TYPE_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_SMALLTHINKER:
+            {
+                const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false);
+
+                if (found_swa && hparams.n_swa > 0) {
+                    hparams.swa_type      = LLAMA_SWA_TYPE_STANDARD;
+                    hparams.n_swa         = 4096;
+                    hparams.set_swa_pattern(4, true);
+                } else {
+                    hparams.swa_type             = LLAMA_SWA_TYPE_NONE;
+                    hparams.n_no_rope_layer_step = hparams.n_layer;
+                }
+
+                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH,  hparams.n_ff_exp, false);
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_EXPERT_GATING_FUNC,          hparams.expert_gating_func, false);
+
+                switch (hparams.n_layer) {
+                    case 32: type = LLM_TYPE_4B;  break;
+                    case 52: type = LLM_TYPE_20B; break;
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+            } break;
         default: throw std::runtime_error("unsupported model architecture");
     }
 
@@ -5165,6 +5188,42 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         }
                     }
                 } break;
+            case LLM_ARCH_SMALLTHINKER:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
+                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+
+                    // if output is NULL, init from the input tok embed
+                    if (output == NULL) {
+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_gqa }, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_gqa }, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0);
+
+                        GGML_ASSERT(n_expert > 0 && "n_expert must be > 0 for SMALLTHINKER");
+                        GGML_ASSERT(n_expert_used > 0 && "n_expert_used must be > 0 for SMALLTHINKER");
+
+                        // MoE branch
+                        const int64_t n_ff_exp = hparams.n_ff_exp;
+                        layer.ffn_gate_inp  = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0);
+                        layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0);
+                        layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0);
+                        layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0);
+                    }
+                } break;
             default:
                 throw std::runtime_error("unknown architecture");
         }
@@ -5490,6 +5549,11 @@ void llama_model::print_info() const {
         LLAMA_LOG_INFO("%s: expert_weights_norm  = %d\n",     __func__, hparams.expert_weights_norm);
     }
 
+    if (arch == LLM_ARCH_SMALLTHINKER) {
+        LLAMA_LOG_INFO("%s: n_ff_exp             = %d\n",     __func__, hparams.n_ff_exp);
+        LLAMA_LOG_INFO("%s: expert_gating_func   = %s\n",     __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func));
+    }
+
     vocab.print_info();
 }
 
@@ -17011,6 +17075,119 @@ struct llm_build_lfm2 : public llm_graph_context {
     }
 };
 
+template <bool iswa>
+struct llm_build_smallthinker : public llm_graph_context{
+    llm_build_smallthinker(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params){
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_unified_iswa, llm_graph_input_attn_kv_unified>;
+        inp_attn_type * inp_attn = nullptr;
+
+        if constexpr (iswa) {
+            inp_attn = build_attn_inp_kv_unified_iswa();
+        } else {
+            inp_attn = build_attn_inp_kv_unified();
+        }
+
+        ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA  = inpL;
+            ggml_tensor * probs  = nullptr;
+
+            probs = build_lora_mm(model.layers[il].ffn_gate_inp, inpL);  // [n_expert, n_tokens]
+            cb(probs, "ffn_moe_logits", il);
+
+            // norm
+            cur = build_norm(inpL,model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self_attention
+            {
+                // compute Q and K and RoPE them
+                struct ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+
+                struct ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+
+                struct ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+                if (hparams.n_no_rope_layer_step == n_layer || il % hparams.n_no_rope_layer_step != 0) {
+                    Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                                     ext_factor, attn_factor, beta_fast, beta_slow);
+
+                    Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                                     ext_factor, attn_factor, beta_fast, beta_slow);
+                }
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+
+                cur = build_attn(inp_attn,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1 && inp_out_ids) {
+                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+                probs = ggml_get_rows(ctx0, probs, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // MoE branch
+            cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            ggml_tensor * ffn_out = build_moe_ffn_from_probs(cur, probs, model.layers[il].ffn_up_exps,
+                                                model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps,
+                                                nullptr, n_expert, n_expert_used,
+                                                static_cast<llama_expert_gating_func_type>(hparams.expert_gating_func), il);
+
+            cb(ffn_out, "ffn_out", il);
+            cur = ffn_out;
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
+        cb(cur, "result_norm", -1);
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
 llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const {
     llama_memory_i * res;
 
@@ -17449,6 +17626,14 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
             {
                 llm = std::make_unique<llm_build_lfm2>(*this, params);
             } break;
+        case LLM_ARCH_SMALLTHINKER:
+            {
+                if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) {
+                    llm = std::make_unique<llm_build_smallthinker<true>> (*this, params);
+                } else {
+                    llm = std::make_unique<llm_build_smallthinker<false>>(*this, params);
+                }
+            } break;
         default:
             GGML_ABORT("fatal error");
     }
@@ -17647,6 +17832,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_DOTS1:
         case LLM_ARCH_HUNYUAN_MOE:
         case LLM_ARCH_LFM2:
+        case LLM_ARCH_SMALLTHINKER:
             return LLAMA_ROPE_TYPE_NEOX;
 
         case LLM_ARCH_QWEN2VL: