]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
model : add SmolLM3 (#14581)
authorXuan-Son Nguyen <redacted>
Tue, 8 Jul 2025 16:07:01 +0000 (18:07 +0200)
committerGitHub <redacted>
Tue, 8 Jul 2025 16:07:01 +0000 (18:07 +0200)
* Init - first pass.

* Model -> ModelBase.

* fix errors in conversion.

* Update the graph.

* up.

* up.

* wip

* cgraph ok

* rm redundant code

---------

Co-authored-by: Vaibhavs10 <redacted>
convert_hf_to_gguf.py
docs/development/HOWTO-add-model.md
gguf-py/gguf/constants.py
src/llama-arch.cpp
src/llama-arch.h
src/llama-model.cpp

index cd18c9a43800aaba89acfcb7b9b9e8f24dd0721f..3f3dfb416c1fc24a32128ba1cbd38b00abcd8314 100755 (executable)
@@ -6687,6 +6687,11 @@ class HunYuanMoEModel(TextModel):
             if len(experts) > 0:
                 raise ValueError(f"Unprocessed experts: {experts}")
 
+
+@ModelBase.register("SmolLM3ForCausalLM")
+class SmolLM3Model(LlamaModel):
+    model_arch = gguf.MODEL_ARCH.SMOLLM3
+
 ###### CONVERSION LOGIC ######
 
 
index 7f71e0247ddc75c14367e159d044fafb5ac63ee3..51e0b0b20f58d1bf44ef7c087c63159e1b0cfd99 100644 (file)
@@ -83,20 +83,22 @@ NOTE: Tensor names must end with `.weight` or `.bias` suffixes, that is the conv
 
 ### 2. Define the model architecture in `llama.cpp`
 
-The model params and tensors layout must be defined in `llama.cpp`:
-1. Define a new `llm_arch`
-2. Define the tensors layout in `LLM_TENSOR_NAMES`
-3. Add any non-standard metadata in `llm_load_hparams`
-4. Create the tensors for inference in `llm_load_tensors`
-5. If the model has a RoPE operation, add the rope type in `llama_rope_type`
+The model params and tensors layout must be defined in `llama.cpp` source files:
+1. Define a new `llm_arch` enum value in `src/llama-arch.h`.
+2. In `src/llama-arch.cpp`:
+    - Add the architecture name to the `LLM_ARCH_NAMES` map.
+    - Add the tensor mappings to the `LLM_TENSOR_NAMES` map.
+3. Add any non-standard metadata loading in the `llama_model_loader` constructor in `src/llama-model-loader.cpp`.
+4. If the model has a RoPE operation, add a case for the architecture in `llama_model_rope_type` function in `src/llama-model.cpp`.
 
 NOTE: The dimensions in `ggml` are typically in the reverse order of the `pytorch` dimensions.
 
 ### 3. Build the GGML graph implementation
 
-This is the funniest part, you have to provide the inference graph implementation of the new model architecture in `llama_build_graph`.
-
-Have a look at existing implementations like `build_llama`, `build_dbrx` or `build_bert`.
+This is the funniest part, you have to provide the inference graph implementation of the new model architecture in `src/llama-model.cpp`.
+Create a new struct that inherits from `llm_graph_context` and implement the graph-building logic in its constructor.
+Have a look at existing implementations like `llm_build_llama`, `llm_build_dbrx` or `llm_build_bert`.
+Then, in the `llama_model::build_graph` method, add a case for your architecture to instantiate your new graph-building struct.
 
 Some `ggml` backends do not support all operations. Backend implementations can be added in a separate PR.
 
index 729bec927c6f379123175cca9e988605097eaed9..e938f8fa664dfea1ed5d7ad91724017bebff30a8 100644 (file)
@@ -358,6 +358,7 @@ class MODEL_ARCH(IntEnum):
     ARCEE            = auto()
     ERNIE4_5         = auto()
     HUNYUAN_MOE      = auto()
+    SMOLLM3          = auto()
 
 
 class VISION_PROJECTOR_TYPE(IntEnum):
@@ -662,6 +663,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.ARCEE:            "arcee",
     MODEL_ARCH.ERNIE4_5:         "ernie4_5",
     MODEL_ARCH.HUNYUAN_MOE:      "hunyuan-moe",
+    MODEL_ARCH.SMOLLM3:          "smollm3",
 }
 
 VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
@@ -2234,6 +2236,22 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.FFN_DOWN_SHEXP,
         MODEL_TENSOR.FFN_UP_SHEXP,
     ],
+    MODEL_ARCH.SMOLLM3: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_ROT_EMBD,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
     # TODO
 }
 
index f1e443ec21b20305e340136a07b70bd235892b10..9af9c2ad604d56aa42844be8ba79b6e58a497aaa 100644 (file)
@@ -79,6 +79,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_ARCEE,            "arcee"            },
     { LLM_ARCH_ERNIE4_5,         "ernie4_5"         },
     { LLM_ARCH_HUNYUAN_MOE,      "hunyuan-moe"      },
+    { LLM_ARCH_SMOLLM3,          "smollm3"          },
     { LLM_ARCH_UNKNOWN,          "(unknown)"        },
 };
 
@@ -1724,6 +1725,23 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
         },
     },
+    {
+        LLM_ARCH_SMOLLM3,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,     "token_embd"            },
+            { LLM_TENSOR_OUTPUT_NORM,    "output_norm"           },
+            { LLM_TENSOR_OUTPUT,         "output"                },
+            { LLM_TENSOR_ATTN_NORM,      "blk.%d.attn_norm"      },
+            { LLM_TENSOR_ATTN_Q,         "blk.%d.attn_q"         },
+            { LLM_TENSOR_ATTN_K,         "blk.%d.attn_k"         },
+            { LLM_TENSOR_ATTN_V,         "blk.%d.attn_v"         },
+            { LLM_TENSOR_ATTN_OUT,       "blk.%d.attn_output"    },
+            { LLM_TENSOR_FFN_NORM,       "blk.%d.ffn_norm"       },
+            { LLM_TENSOR_FFN_GATE,       "blk.%d.ffn_gate"       },
+            { LLM_TENSOR_FFN_DOWN,       "blk.%d.ffn_down"       },
+            { LLM_TENSOR_FFN_UP,         "blk.%d.ffn_up"         },
+        },
+    },
 };
 
 static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
index f1261f69ff7a4fb06e1a4d9cfc3c06324bf1ed79..ba5d03fa24ebe2c5a45237e8c6d095432107409a 100644 (file)
@@ -83,6 +83,7 @@ enum llm_arch {
     LLM_ARCH_ARCEE,
     LLM_ARCH_ERNIE4_5,
     LLM_ARCH_HUNYUAN_MOE,
+    LLM_ARCH_SMOLLM3,
     LLM_ARCH_UNKNOWN,
 };
 
index c9f58d44146d2ef02561854b5c213d7e28140ba9..fc4e9a5af004d235153a9bf4263ec4fa894b7637 100644 (file)
@@ -1561,6 +1561,16 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_SMOLLM3:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                hparams.n_no_rope_layer_step = 4;
+
+                switch (hparams.n_layer) {
+                    case 36: type = LLM_TYPE_3B; break;
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+            } break;
         default: throw std::runtime_error("unsupported model architecture");
     }
 
@@ -4524,6 +4534,35 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, 0);
                     }
                 } break;
+            case LLM_ARCH_SMOLLM3:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+
+                    // if output is NULL, init from the input tok embed
+                    if (output == NULL) {
+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd,   n_ff}, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                    }
+                } break;
             default:
                 throw std::runtime_error("unknown architecture");
         }
@@ -14846,6 +14885,142 @@ struct llm_build_hunyuan_moe : public llm_graph_context {
         cb(cur, "result_norm", -1);
         res->t_embd = cur;
 
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
+struct llm_build_smollm3 : public llm_graph_context {
+    llm_build_smollm3(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified();
+
+        const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
+
+        ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            const bool use_rope = (il + 1) % hparams.n_no_rope_layer_step != 0;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+                if (use_rope) {
+                    Qcur = ggml_rope_ext(
+                            ctx0, Qcur, inp_pos, nullptr,
+                            n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                            ext_factor, attn_factor, beta_fast, beta_slow
+                            );
+
+                    Kcur = ggml_rope_ext(
+                            ctx0, Kcur, inp_pos, nullptr,
+                            n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                            ext_factor, attn_factor, beta_fast, beta_slow
+                            );
+                }
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
+                cb(cur, "attn_out", il);
+            }
+
+            if (il == n_layer - 1 && inp_out_ids) {
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            {
+                cur = build_norm(ffn_inp,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = build_ffn(cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,   NULL,
+                        model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "ffn_out", il);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
         // lm_head
         cur = build_lora_mm(model.output, cur);
 
@@ -15240,6 +15415,10 @@ llm_graph_result_ptr llama_model::build_graph(
             {
                 llm = std::make_unique<llm_build_hunyuan_moe>(*this, params, gf);
             } break;
+        case LLM_ARCH_SMOLLM3:
+            {
+                llm = std::make_unique<llm_build_smollm3>(*this, params, gf);
+            } break;
         default:
             GGML_ABORT("fatal error");
     }
@@ -15391,6 +15570,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_CHAMELEON:
         case LLM_ARCH_BAILINGMOE:
         case LLM_ARCH_NEO_BERT:
+        case LLM_ARCH_SMOLLM3:
         case LLM_ARCH_ARCEE:
         case LLM_ARCH_ERNIE4_5:
             return LLAMA_ROPE_TYPE_NORM;