]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
model: mistral small 4 support (#20649)
authorXuan-Son Nguyen <redacted>
Mon, 16 Mar 2026 23:31:14 +0000 (00:31 +0100)
committerGitHub <redacted>
Mon, 16 Mar 2026 23:31:14 +0000 (00:31 +0100)
* model: mistral small 4 support

* fix test

* fix test (2)

* Apply suggestions from code review

Co-authored-by: Sigbjørn Skjæret <redacted>
* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <redacted>
* change newline

---------

Co-authored-by: Sigbjørn Skjæret <redacted>
convert_hf_to_gguf.py
gguf-py/gguf/constants.py
src/llama-arch.cpp
src/llama-arch.h
src/llama-model.cpp
tests/test-llama-archs.cpp

index b4ff8dd95952e7ec539fd75cb9d1798a5d0cb20a..46469c862000065d2b5892a0388134a7e55e8b6e 100755 (executable)
@@ -298,11 +298,16 @@ class ModelBase:
                 scale = scale.float()
 
                 if block_size is not None:
+                    dim_offset = scale.ndim - len(block_size)
                     for i, size in enumerate(block_size):
-                        scale = scale.repeat_interleave(size, i)
+                        scale = scale.repeat_interleave(size, dim_offset + i)
                     # unpad the scale (e.g. when the tensor size isn't a multiple of the block size)
                     scale = scale[tuple(slice(0, size) for size in weight.shape)]
 
+                # align scale dims to weight for correct broadcasting (e.g. [128] -> [128, 1, 1])
+                while scale.ndim < weight.ndim:
+                    scale = scale.unsqueeze(-1)
+
                 return weight.float() * scale
 
             # ref: https://github.com/ModelCloud/GPTQModel/blob/037c5c0f6c9e33c500d975b038d02e7ca437546d/gptqmodel/nn_modules/qlinear/__init__.py#L437-L476
@@ -393,7 +398,7 @@ class ModelBase:
             elif quant_method == "fp8":
                 block_size = quant_config.get("weight_block_size")
                 for name in self.model_tensors.keys():
-                    if name.endswith(".weight_scale_inv"):
+                    if name.endswith("_scale_inv"):
                         weight_name = name.removesuffix("_scale_inv")
                         w = self.model_tensors[weight_name]
                         s = self.model_tensors[name]
@@ -401,6 +406,8 @@ class ModelBase:
                         tensors_to_remove.append(name)
                     if name.endswith(".activation_scale"):  # unused
                         tensors_to_remove.append(name)
+                    if name.endswith("_activation_scale"):  # Mistral-Small-4-119B-2602, unused
+                        tensors_to_remove.append(name)
                     # mistral format
                     if name.endswith(".qscale_weight"):
                         weight_name = name.removesuffix("qscale_weight") + "weight"
@@ -3031,10 +3038,16 @@ class LlavaVisionModel(MmprojModel):
     def get_token_id(self, token: str) -> int:
         tokenizer_config_file = self.dir_model / 'tokenizer_config.json'
         with open(tokenizer_config_file, "r", encoding="utf-8") as f:
-            added_tokens_decoder = json.load(f)['added_tokens_decoder']
+            added_tokens_decoder = json.load(f).get('added_tokens_decoder') or {}
             for id_, token_data in added_tokens_decoder.items():
-                if token_data["content"] == token:
+                if token_data.get("content") == token:
                     return int(id_)
+            # fallthrough to tokenizer.json
+        with open(self.dir_model / "tokenizer.json", "r", encoding="utf-8") as f:
+            tokenizer_json = json.load(f)
+            for token_data in tokenizer_json["added_tokens"]:
+                if token_data["content"] == token:
+                    return int(token_data["id"])
         raise ValueError(f"Token '{token}' not found in tokenizer config.")
 
     def set_gguf_parameters(self):
@@ -3198,40 +3211,6 @@ class Llama4VisionModel(MmprojModel):
                 yield from super().modify_tensors(data_torch, name, bid)
 
 
-@ModelBase.register(
-    "Mistral3ForConditionalGeneration",
-    "Ministral3ForCausalLM",
-)
-class Mistral3Model(LlamaModel):
-    model_arch = gguf.MODEL_ARCH.MISTRAL3
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        # for compatibility, we use LLAMA arch for older models
-        # TODO: remove this once everyone has migrated to newer version of llama.cpp
-        if self.hparams.get("model_type") != "ministral3":
-            self.model_arch = gguf.MODEL_ARCH.LLAMA
-            self.gguf_writer.arch = gguf.MODEL_ARCH_NAMES[self.model_arch]
-            self.gguf_writer.add_architecture()
-            self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
-
-    def set_gguf_parameters(self):
-        super().set_gguf_parameters()
-        rope_params = self.rope_parameters
-        if self.hparams.get("model_type") == "ministral3":
-            assert rope_params, "ministral3 must have 'rope_parameters' config"
-            assert rope_params["rope_type"] == "yarn", "ministral3 rope_type must be 'yarn'"
-            self.gguf_writer.add_rope_scaling_yarn_log_mul(rope_params["mscale_all_dim"])
-            self.gguf_writer.add_attn_temperature_scale(rope_params["llama_4_scaling_beta"])
-
-    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
-        name = name.replace("language_model.", "")
-        if "multi_modal_projector" in name or "vision_tower" in name:
-            return
-
-        yield from super().modify_tensors(data_torch, name, bid)
-
-
 @ModelBase.register("DeciLMForCausalLM")
 class DeciModel(TextModel):
     model_arch = gguf.MODEL_ARCH.DECI
@@ -8271,6 +8250,8 @@ class DeepseekV2Model(TextModel):
     # TODO @ngxson : remove this when we support MTP for deepseek models
     skip_mtp = True
 
+    merge_expert = True
+
     def set_vocab(self):
         try:
             self._set_vocab_gpt2()
@@ -8409,7 +8390,7 @@ class DeepseekV2Model(TextModel):
                 return
 
         # process the experts separately
-        if name.find("mlp.experts") != -1:
+        if self.merge_expert and name.find("mlp.experts") != -1:
             n_experts = self.hparams["n_routed_experts"]
             assert bid is not None
 
@@ -8468,6 +8449,69 @@ class DeepseekV2Model(TextModel):
                 raise ValueError(f"Unprocessed experts: {experts}")
 
 
+@ModelBase.register(
+    "Mistral3ForConditionalGeneration",
+    "Ministral3ForCausalLM",
+)
+class Mistral3Model(TextModel):
+    class Ministral3Model(LlamaModel):
+        model_arch = gguf.MODEL_ARCH.MISTRAL3
+
+        def set_gguf_parameters(self):
+            super().set_gguf_parameters()
+            rope_params = self.rope_parameters
+            if self.hparams.get("model_type") == "ministral3":
+                assert rope_params, "ministral3 must have 'rope_parameters' config"
+                assert rope_params["rope_type"] == "yarn", "ministral3 rope_type must be 'yarn'"
+                self.gguf_writer.add_rope_scaling_yarn_log_mul(rope_params["mscale_all_dim"])
+                self.gguf_writer.add_attn_temperature_scale(rope_params["llama_4_scaling_beta"])
+
+        def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
+            name = name.replace("language_model.", "")
+            if "multi_modal_projector" in name or "vision_tower" in name:
+                return
+
+            yield from super().modify_tensors(data_torch, name, bid)
+
+    class Mistral4Model(DeepseekV2Model):
+        model_arch = gguf.MODEL_ARCH.MISTRAL4
+        skip_mtp = False # model contains no MTP layers, so no need to skip
+        merge_expert = False # experts are already stacked as 3D
+
+        def modify_tensors(self, data_torch, name, bid):
+            if name.endswith(".down_proj") or name.endswith(".gate_up_proj"):
+                name = name + ".weight"
+            yield from super().modify_tensors(data_torch, name, bid)
+
+    model_arch = gguf.MODEL_ARCH.MISTRAL3 # unused
+    impl: TextModel
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        if self.hparams.get("model_type") == "mistral4":
+            self.impl = Mistral3Model.Mistral4Model(*args, **kwargs)
+        else:
+            self.impl = Mistral3Model.Ministral3Model(*args, **kwargs)
+
+    def set_vocab(self):
+        self.impl.set_vocab()
+
+    def set_gguf_parameters(self):
+        self.impl.set_gguf_parameters()
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
+        yield from self.impl.modify_tensors(data_torch, name, bid)
+
+    def prepare_tensors(self):
+        self.impl.prepare_tensors()
+
+    def write_vocab(self):
+        self.impl.write_vocab()
+
+    def write(self):
+        self.impl.write()
+
+
 @ModelBase.register("MiniMaxM2ForCausalLM")
 class MiniMaxM2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.MINIMAXM2
index bf617382d0ac90bf9350abcfdd6a63cefa4cc5a8..0a032e9039cd8b21d5ee0f814a61e6877f036574 100644 (file)
@@ -478,6 +478,7 @@ class MODEL_ARCH(IntEnum):
     RND1             = auto()
     PANGU_EMBED      = auto()
     MISTRAL3         = auto()
+    MISTRAL4         = auto()
     PADDLEOCR        = auto()
     MIMO2            = auto()
     STEP35           = auto()
@@ -924,6 +925,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.RND1:             "rnd1",
     MODEL_ARCH.PANGU_EMBED:      "pangu-embedded",
     MODEL_ARCH.MISTRAL3:         "mistral3",
+    MODEL_ARCH.MISTRAL4:         "mistral4",
     MODEL_ARCH.PADDLEOCR:        "paddleocr",
     MODEL_ARCH.MIMO2:            "mimo2",
     MODEL_ARCH.STEP35:           "step35",
@@ -3538,6 +3540,37 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.FFN_DOWN_EXP,
         MODEL_TENSOR.FFN_UP_EXP,
     ],
+    MODEL_ARCH.MISTRAL4: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_Q_A,
+        MODEL_TENSOR.ATTN_Q_B,
+        MODEL_TENSOR.ATTN_KV_A_MQA,
+        MODEL_TENSOR.ATTN_KV_B,
+        MODEL_TENSOR.ATTN_K_B,
+        MODEL_TENSOR.ATTN_V_B,
+        MODEL_TENSOR.ATTN_Q_A_NORM,
+        MODEL_TENSOR.ATTN_KV_A_NORM,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_ROT_EMBD,
+        MODEL_TENSOR.FFN_GATE_INP,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.FFN_GATE_EXP,
+        MODEL_TENSOR.FFN_DOWN_EXP,
+        MODEL_TENSOR.FFN_UP_EXP,
+        MODEL_TENSOR.FFN_GATE_UP_EXP,
+        MODEL_TENSOR.FFN_GATE_SHEXP,
+        MODEL_TENSOR.FFN_DOWN_SHEXP,
+        MODEL_TENSOR.FFN_UP_SHEXP,
+        MODEL_TENSOR.FFN_EXP_PROBS_B,
+    ],
     MODEL_ARCH.MIMO2: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
index 799d16167ba781b7ab593dd37e3ca99489a35a42..84dc6d8f1b6543088ec096fa82084bd5a9c84846 100644 (file)
@@ -123,6 +123,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_RND1,             "rnd1"             },
     { LLM_ARCH_PANGU_EMBED,      "pangu-embedded"   },
     { LLM_ARCH_MISTRAL3,         "mistral3"         },
+    { LLM_ARCH_MISTRAL4,         "mistral4"         },
     { LLM_ARCH_PADDLEOCR,        "paddleocr"        },
     { LLM_ARCH_MIMO2,            "mimo2"            },
     { LLM_ARCH_STEP35,           "step35"           },
@@ -1589,6 +1590,7 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
                 LLM_TENSOR_FFN_UP_SHEXP,
             };
         case LLM_ARCH_DEEPSEEK2:
+        case LLM_ARCH_MISTRAL4:
             return {
                 LLM_TENSOR_TOKEN_EMBD,
                 LLM_TENSOR_OUTPUT_NORM,
index b1b1dcf18839849e22d8f5aa607d8748a44bf294..9b9eec2f5c8e3151be6e0ad143718106070f89a7 100644 (file)
@@ -127,6 +127,7 @@ enum llm_arch {
     LLM_ARCH_RND1,
     LLM_ARCH_PANGU_EMBED,
     LLM_ARCH_MISTRAL3,
+    LLM_ARCH_MISTRAL4,
     LLM_ARCH_PADDLEOCR,
     LLM_ARCH_MIMO2,
     LLM_ARCH_STEP35,
index bae02e32b1730451b228e4a78c5b3884a01e6a8d..85db938a7ad417eae39cfdb40c899ff34ef70918 100644 (file)
@@ -1587,6 +1587,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 }
             } break;
         case LLM_ARCH_DEEPSEEK2:
+        case LLM_ARCH_MISTRAL4:
             {
                 // lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B, Kanana-2-30B-A3B
                 const bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26 || (hparams.n_layer == 48 && n_vocab == 128256));
@@ -4883,6 +4884,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                     }
                 } break;
             case LLM_ARCH_DEEPSEEK2:
+            case LLM_ARCH_MISTRAL4:
                 {
                     const bool is_mla = hparams.is_mla();
 
@@ -7850,7 +7852,7 @@ void llama_model::print_info() const {
         LLAMA_LOG_INFO("%s: expert_weights_scale  = %.1f\n",   __func__, hparams.expert_weights_scale);
     }
 
-    if (arch == LLM_ARCH_DEEPSEEK2 || arch == LLM_ARCH_GLM_DSA) {
+    if (arch == LLM_ARCH_DEEPSEEK2 || arch == LLM_ARCH_GLM_DSA || arch == LLM_ARCH_MISTRAL4) {
         LLAMA_LOG_INFO("%s: n_layer_dense_lead    = %d\n",     __func__, hparams.n_layer_dense_lead);
         LLAMA_LOG_INFO("%s: n_lora_q              = %d\n",     __func__, hparams.n_lora_q);
         LLAMA_LOG_INFO("%s: n_lora_kv             = %d\n",     __func__, hparams.n_lora_kv);
@@ -8428,6 +8430,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
             } break;
         case LLM_ARCH_DEEPSEEK2:
         case LLM_ARCH_GLM_DSA:
+        case LLM_ARCH_MISTRAL4:
             {
                 llm = std::make_unique<llm_build_deepseek2>(*this, params);
             } break;
@@ -8839,6 +8842,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_ERNIE4_5:
         case LLM_ARCH_ERNIE4_5_MOE:
         case LLM_ARCH_MISTRAL3:
+        case LLM_ARCH_MISTRAL4:
         case LLM_ARCH_LLAMA_EMBED:
         case LLM_ARCH_MAINCODER:
         case LLM_ARCH_GLM_DSA:
index 014b3f2b149928cc1c282f2600a7fcfc67386f25..d51c09e99f52db154018664a397025d9cce7ca45 100644 (file)
@@ -90,7 +90,10 @@ static gguf_context_ptr get_gguf_ctx(const llm_arch arch, const bool moe) {
         n_embd = 64;
         n_head = 1;
         n_ff   = 96;
-    } else if (arch == LLM_ARCH_DEEPSEEK2 || arch == LLM_ARCH_GLM_DSA || arch == LLM_ARCH_KIMI_LINEAR) {
+    } else if (arch == LLM_ARCH_DEEPSEEK2
+            || arch == LLM_ARCH_GLM_DSA
+            || arch == LLM_ARCH_KIMI_LINEAR
+            || arch == LLM_ARCH_MISTRAL4) {
         n_embd = 128;
         n_head = 1;
         n_ff   = 192;
@@ -145,7 +148,10 @@ static gguf_context_ptr get_gguf_ctx(const llm_arch arch, const bool moe) {
     }
 
     ms.add_kv(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, 8.0f);
-    if (arch == LLM_ARCH_DEEPSEEK2 || arch == LLM_ARCH_GLM_DSA  || arch == LLM_ARCH_KIMI_LINEAR) {
+    if (arch == LLM_ARCH_DEEPSEEK2
+            || arch == LLM_ARCH_GLM_DSA
+            || arch == LLM_ARCH_KIMI_LINEAR
+            || arch == LLM_ARCH_MISTRAL4) {
         ms.add_kv(LLM_KV_ATTENTION_KEY_LENGTH,       uint32_t(576));
         ms.add_kv(LLM_KV_ATTENTION_VALUE_LENGTH,     uint32_t(512));
         ms.add_kv(LLM_KV_ROPE_DIMENSION_COUNT,       uint32_t(64));
@@ -319,6 +325,7 @@ static bool moe_mandatory(const llm_arch arch) {
         case LLM_ARCH_MIMO2:
         case LLM_ARCH_KIMI_LINEAR:
         case LLM_ARCH_STEP35:
+        case LLM_ARCH_MISTRAL4:
             return true;
         default:
             return false;