]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
mtmd : add **vision** support for Mistral Small 3.1 (#13231)
authorXuan-Son Nguyen <redacted>
Thu, 1 May 2025 15:05:42 +0000 (17:05 +0200)
committerGitHub <redacted>
Thu, 1 May 2025 15:05:42 +0000 (17:05 +0200)
* convert ok

* load ok, missing patch merger

* ah sheet it works

* update llava/readme

* add test

* fix test

convert_hf_to_gguf.py
examples/llava/README.md
examples/llava/clip-impl.h
examples/llava/clip.cpp
examples/llava/mtmd-cli.cpp
examples/llava/tests.sh
gguf-py/gguf/constants.py
gguf-py/gguf/gguf_writer.py
gguf-py/gguf/tensor_mapping.py

index 123df801bf095b67e768e1ca1311e939c475425b..04ca646b503cab7974053367f9066dbf79c10a7f 100755 (executable)
@@ -1899,7 +1899,10 @@ class LlamaModel(TextModel):
                 raise ValueError(f"Unprocessed experts: {experts}")
 
 
-@ModelBase.register("LlavaForConditionalGeneration")
+@ModelBase.register(
+    "LlavaForConditionalGeneration", # pixtral
+    "Mistral3ForConditionalGeneration", # mistral small 3.1
+)
 class LlavaVisionModel(VisionModel):
     img_break_tok_id = -1
 
@@ -1908,17 +1911,38 @@ class LlavaVisionModel(VisionModel):
         if self.hparams["model_type"] == "pixtral":
             # layer_norm_eps is not in config.json, it is hard-coded in modeling_pixtral.py
             self.hparams["layer_norm_eps"] = self.hparams.get("layer_norm_eps", 1e-5)
-            self.img_break_tok_id = 12 # see tokenizer_config.json
+            self.img_break_tok_id = self.get_token_id("[IMG_BREAK]")
+            logger.info(f"Image break token id: {self.img_break_tok_id}")
         else:
             raise ValueError(f"Unsupported model type: {self.hparams['model_type']}")
 
+    def get_token_id(self, token: str) -> int:
+        tokenizer_config_file = self.dir_model / 'tokenizer_config.json'
+        with open(tokenizer_config_file, "r", encoding="utf-8") as f:
+            added_tokens_decoder = json.load(f)['added_tokens_decoder']
+            for id_, token_data in added_tokens_decoder.items():
+                if token_data["content"] == token:
+                    return int(id_)
+        raise ValueError(f"Token '{token}' not found in tokenizer config.")
+
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
         hparams = self.hparams
         if hparams["model_type"] == "pixtral":
             self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.PIXTRAL)
             self.gguf_writer.add_vision_attention_layernorm_eps(hparams["layer_norm_eps"])
-            self.gguf_writer.add_vision_use_silu(True)
+
+            # hidden_act
+            if hparams["hidden_act"] == "silu":
+                self.gguf_writer.add_vision_use_silu(True)
+            elif hparams["hidden_act"] == "gelu":
+                self.gguf_writer.add_vision_use_gelu(True)
+            else:
+                raise ValueError(f"Unsupported hidden_act: {hparams['hidden_act']}")
+
+            # spatial_merge_size
+            if "spatial_merge_size" in self.global_config:
+                self.gguf_writer.add_vision_spatial_merge_size(self.global_config["spatial_merge_size"])
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         del bid  # unused
index f58d9de7107e868e2334866de48b8d8ba91451c7..3b62627ce829fec6d88fa7e65be643463b8f396e 100644 (file)
@@ -34,6 +34,9 @@ llama-mtmd-cli -hf ggml-org/SmolVLM2-500M-Video-Instruct-GGUF
 
 # Pixtral 12B
 llama-mtmd-cli -hf ggml-org/pixtral-12b-GGUF
+
+# Mistral Small 3.1 24B (IQ2_M quantization)
+llama-mtmd-cli -hf ggml-org/Mistral-Small-3.1-24B-Instruct-2503-GGUF --chat-template mistral-v7
 ```
 
 ## How it works and what is `mmproj`?
@@ -73,3 +76,4 @@ For the following models, you can use `convert_hf_to_gguf.py`with `--mmproj` fla
 - SmolVLM (from [HuggingFaceTB](https://huggingface.co/HuggingFaceTB))
 - SmolVLM2 (from [HuggingFaceTB](https://huggingface.co/HuggingFaceTB))
 - [Pixtral 12B](https://huggingface.co/mistral-community/pixtral-12b) - only works with `transformers`-compatible checkpoint
+- [Mistral Small 3.1 24B](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503)
index 66cb21ef1a756c7f5d489ffb1c00711b538a4bd5..b575ca4d7c2a9a9b8ae16b183aa2d419bf886dc6 100644 (file)
@@ -31,6 +31,7 @@
 #define KEY_FEATURE_LAYER       "clip.vision.feature_layer"
 #define KEY_PROJ_SCALE_FACTOR   "clip.vision.projector.scale_factor"
 #define KEY_PROJ_TYPE           "clip.projector_type"
+#define KEY_SPATIAL_MERGE_SIZE  "clip.vision.spatial_merge_size"
 
 #define KEY_USE_GLU_MLP         "clip.use_glu_mlp"  // for qwen2.5vl
 #define KEY_USE_RMS_NORM        "clip.use_rms_norm" // for qwen2.5vl
 #define TN_MVLM_PROJ_BLOCK "mm.model.mb_block.%d.block.%d.%s"
 #define TN_MVLM_PROJ_PEG   "mm.model.peg.%d.%s"
 #define TN_IMAGE_NEWLINE   "model.image_newline"
+#define TN_MM_INP_NORM     "mm.input_norm.weight"
 #define TN_MM_INP_PROJ     "mm.input_projection.weight" // gemma3
 #define TN_MM_SOFT_EMB_N   "mm.soft_emb_norm.weight"    // gemma3
 #define TN_MM_PROJECTOR    "mm.model.fc.weight"         // idefics3
+#define TN_MM_PATCH_MERGER "mm.patch_merger.weight"     // mistral small 3.1
 #define TN_TOK_IMG_BREAK   "v.token_embd.img_break"     // pixtral
 
 // mimicpmv
index ad3e7df1d8a3a8c241f673df5601a68ede08619f..984e300e7538add1c2243769bd456b3ca9a3ee0d 100644 (file)
@@ -172,6 +172,7 @@ struct clip_hparams {
     std::unordered_set<int32_t> vision_feature_layer;
     int32_t attn_window_size = 0;
     int32_t n_wa_pattern = 0;
+    int32_t spatial_merge_size = 0;
 };
 
 struct clip_layer {
@@ -232,6 +233,7 @@ struct clip_vision_model {
     struct ggml_tensor * projection;
 
     // LLaVA projection
+    struct ggml_tensor * mm_input_norm_w = nullptr;
     struct ggml_tensor * mm_0_w = nullptr;
     struct ggml_tensor * mm_0_b = nullptr;
     struct ggml_tensor * mm_2_w = nullptr;
@@ -311,6 +313,7 @@ struct clip_vision_model {
 
     // pixtral
     struct ggml_tensor * token_embd_img_break = nullptr;
+    struct ggml_tensor * mm_patch_merger_w = nullptr;
 };
 
 struct clip_ctx {
@@ -637,6 +640,7 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
     const int d_head      = hidden_size / n_head;
     const int n_layer     = hparams.n_layer;
     const float eps       = hparams.eps;
+    const int n_merge     = hparams.spatial_merge_size;
 
     struct ggml_init_params params = {
         /*.mem_size   =*/ ctx->buf_compute_meta.size(),
@@ -721,7 +725,13 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
         {
             ggml_tensor * gate_proj = ggml_mul_mat(ctx0, model.layers[il].ff_gate_w, cur);
             ggml_tensor * up_proj   = ggml_mul_mat(ctx0, model.layers[il].ff_up_w,   cur);
-            gate_proj = ggml_silu(ctx0, gate_proj); // pixtral uses silu
+            if (ctx->use_silu) {
+                gate_proj = ggml_silu(ctx0, gate_proj);
+            } else if (ctx->use_gelu) {
+                gate_proj = ggml_gelu(ctx0, gate_proj);
+            } else {
+                GGML_ABORT("Pixtral: Unsupported activation");
+            }
             cur = ggml_mul(ctx0, up_proj, gate_proj);
             cur = ggml_mul_mat(ctx0, model.layers[il].ff_down_w, cur);
         }
@@ -732,14 +742,42 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
         embeddings = cur;
     }
 
-    // LlavaMultiModalProjector (with GELU activation)
+    // mistral small 3.1 patch merger
+    // ref: https://github.com/huggingface/transformers/blob/7a3e208892c06a5e278144eaf38c8599a42f53e7/src/transformers/models/mistral3/modeling_mistral3.py#L67
+    if (model.mm_patch_merger_w) {
+        GGML_ASSERT(hparams.spatial_merge_size > 0);
+
+        ggml_tensor * cur = embeddings;
+        cur = ggml_mul(ctx0, ggml_rms_norm(ctx0, cur, eps), model.mm_input_norm_w);
+
+        // reshape image tokens to 2D grid
+        cur = ggml_reshape_3d(ctx0, cur, hidden_size, n_patches_x, n_patches_y);
+        cur = ggml_permute(ctx0, cur, 2, 0, 1, 3); // [x, y, hidden_size]
+        cur = ggml_cont(ctx0, cur);
+
+        // torch.nn.functional.unfold is just an im2col under the hood
+        // we just need a dummy kernel to make it work
+        ggml_tensor * kernel = ggml_view_3d(ctx0, cur, n_merge, n_merge, cur->ne[2], 0, 0, 0);
+        cur = ggml_im2col(ctx0, kernel, cur, n_merge, n_merge, 0, 0, 1, 1, true, inp->type);
+
+        // project to hidden_size
+        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], cur->ne[1] * cur->ne[2]);
+        cur = ggml_mul_mat(ctx0, model.mm_patch_merger_w, cur);
+        embeddings = cur;
+    }
+
+    // LlavaMultiModalProjector (always using GELU activation)
     {
         embeddings = ggml_mul_mat(ctx0, model.mm_1_w, embeddings);
-        embeddings = ggml_add(ctx0, embeddings, model.mm_1_b);
+        if (model.mm_1_b) {
+            embeddings = ggml_add(ctx0, embeddings, model.mm_1_b);
+        }
 
         embeddings = ggml_gelu(ctx0, embeddings);
         embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings);
-        embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
+        if (model.mm_2_b) {
+            embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
+        }
     }
 
     // arrangement of the [IMG_BREAK] token
@@ -749,11 +787,14 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
         // and then concatenate the [IMG_BREAK] token to the end of each row, aka n_patches_per_row dimension
         // after the concatenation, we have a tensor with shape [hidden_size, n_patches_per_row + 1, n_rows]
 
+        const int p_y             = n_merge > 0 ? n_patches_y / n_merge : n_patches_y;
+        const int p_x             = n_merge > 0 ? n_patches_x / n_merge : n_patches_x;
+        const int p_total         = p_x * p_y;
         const int n_embd_text     = embeddings->ne[0];
-        const int n_tokens_output = num_patches + n_patches_y - 1; // one [IMG_BREAK] per row, except the last row
+        const int n_tokens_output = p_total + p_y - 1; // one [IMG_BREAK] per row, except the last row
 
-        ggml_tensor * cur = ggml_reshape_3d(ctx0, embeddings, n_embd_text, n_patches_x, n_patches_y);
-        ggml_tensor * tok = ggml_new_tensor_3d(ctx0, embeddings->type, n_embd_text, 1, n_patches_y);
+        ggml_tensor * cur = ggml_reshape_3d(ctx0, embeddings, n_embd_text, p_x, p_y);
+        ggml_tensor * tok = ggml_new_tensor_3d(ctx0, embeddings->type, n_embd_text, 1, p_y);
         tok = ggml_scale(ctx0, tok, 0.0); // clear the tensor
         tok = ggml_add(ctx0, tok, model.token_embd_img_break);
         cur = ggml_concat(ctx0, cur, tok, 1);
@@ -1734,6 +1775,7 @@ struct clip_model_loader {
                 case PROJECTOR_TYPE_PIXTRAL:
                     {
                         hparams.rope_theta = 10000.0f;
+                        get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.spatial_merge_size, false);
                     } break;
                 case PROJECTOR_TYPE_QWEN25VL:
                     {
@@ -1957,11 +1999,14 @@ struct clip_model_loader {
             case PROJECTOR_TYPE_PIXTRAL:
                 {
                     vision_model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"));
-                    vision_model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"));
+                    vision_model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"), false);
                     vision_model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"));
-                    vision_model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"));
+                    vision_model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"), false);
                     // [IMG_BREAK] token embedding
                     vision_model.token_embd_img_break = get_tensor(TN_TOK_IMG_BREAK);
+                    // for mistral small 3.1
+                    vision_model.mm_input_norm_w   = get_tensor(TN_MM_INP_NORM,     false);
+                    vision_model.mm_patch_merger_w = get_tensor(TN_MM_PATCH_MERGER, false);
                 } break;
             default:
                 GGML_ASSERT(false && "unknown projector type");
@@ -2926,8 +2971,9 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
     } else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
         n_patches /= ctx->vision_model.hparams.proj_scale_factor;
     } else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) {
-        int n_patches_x = img->nx / params.patch_size;
-        int n_patches_y = img->ny / params.patch_size;
+        int n_merge = ctx->vision_model.hparams.spatial_merge_size;
+        int n_patches_x = img->nx / params.patch_size / (n_merge > 0 ? n_merge : 1);
+        int n_patches_y = img->ny / params.patch_size / (n_merge > 0 ? n_merge : 1);
         n_patches = n_patches_y*n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row
     }
 
@@ -3484,7 +3530,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
             return ctx->vision_model.mm_model_peg_0_b->ne[0];
         case PROJECTOR_TYPE_MLP:
         case PROJECTOR_TYPE_PIXTRAL:
-            return ctx->vision_model.mm_2_b->ne[0];
+            return ctx->vision_model.mm_2_w->ne[1];
         case PROJECTOR_TYPE_MLP_NORM:
             return ctx->vision_model.mm_3_b->ne[0];
         case PROJECTOR_TYPE_MINICPMV:
index 4d857ca64e0b497c41ef671389107625f79e5da8..aa52d92cab871ce7b81009acbcc6861197f92e36 100644 (file)
@@ -94,6 +94,7 @@ struct mtmd_cli_context {
             LOG_ERR("Model does not have chat template.\n");
             LOG_ERR("  For old llava models, you may need to use '--chat-template vicuna'\n");
             LOG_ERR("  For MobileVLM models, use '--chat-template deepseek'\n");
+            LOG_ERR("  For Mistral Small 3.1, use '--chat-template mistral-v7'\n");
             exit(1);
         }
 
index 75604315cfeba4e590c94c1ad34b9c61253c661f..4af370064086f94ce8bc390a1ef3f9d5c210b19f 100755 (executable)
@@ -59,6 +59,7 @@ add_test "llama-mtmd-cli"  "ggml-org/Qwen2.5-VL-3B-Instruct-GGUF:Q4_K_M"
 
 # to test the big models, run: ./tests.sh big
 add_test_big "llama-mtmd-cli" "ggml-org/pixtral-12b-GGUF:Q4_K_M"
+add_test_big "llama-mtmd-cli" "ggml-org/Mistral-Small-3.1-24B-Instruct-2503-GGUF" "mistral-v7"
 
 # these models always give the wrong answer, not sure why
 # add_test "llama-mtmd-cli"  "ggml-org/SmolVLM-Instruct-GGUF:Q4_K_M"
index 326ccdb071a79f225945ebc6d0b83ebe278b0d67..a2540bd93fd91ea9fd88f1e1850e38c6fdc958aa 100644 (file)
@@ -231,6 +231,7 @@ class Keys:
         BLOCK_COUNT         = "clip.vision.block_count"
         IMAGE_MEAN          = "clip.vision.image_mean"
         IMAGE_STD           = "clip.vision.image_std"
+        SPATIAL_MERGE_SIZE  = "clip.vision.spatial_merge_size"
         USE_GELU            = "clip.use_gelu"
         USE_SILU            = "clip.use_silu"
 
@@ -491,6 +492,7 @@ class MODEL_TENSOR(IntEnum):
     V_ENC_FFN_DOWN       = auto()
     V_PRE_NORM           = auto()
     V_POST_NORM          = auto()
+    V_MM_INP_NORM        = auto()
     V_MM_INP_PROJ        = auto() # gemma3
     V_MM_SOFT_EMB_NORM   = auto() # gemma3
     V_RESMPL_POS_EMBD_K  = auto() # minicpmv
@@ -505,6 +507,7 @@ class MODEL_TENSOR(IntEnum):
     V_RESMPL_PROJ        = auto() # minicpmv
     V_RESMPL_QUERY       = auto() # minicpmv
     V_TOK_EMBD_IMG_BREAK = auto() # pixtral
+    V_MM_PATCH_MERGER    = auto() # mistral small 3.1
 
 
 MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
@@ -747,6 +750,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
     MODEL_TENSOR.V_PRE_NORM:                "v.pre_ln",
     MODEL_TENSOR.V_POST_NORM:               "v.post_ln",
     MODEL_TENSOR.V_MM_INP_PROJ:             "mm.input_projection",
+    MODEL_TENSOR.V_MM_INP_NORM:             "mm.input_norm",
     MODEL_TENSOR.V_MM_SOFT_EMB_NORM:        "mm.soft_emb_norm",
     MODEL_TENSOR.V_RESMPL_POS_EMBD_K:       "resampler.pos_embd_k",
     MODEL_TENSOR.V_RESMPL_ATTN_Q:           "resampler.attn.q",
@@ -760,6 +764,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
     MODEL_TENSOR.V_RESMPL_PROJ:             "resampler.proj",
     MODEL_TENSOR.V_RESMPL_QUERY:            "resampler.query",
     MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK:      "v.token_embd.img_break", # pixtral
+    MODEL_TENSOR.V_MM_PATCH_MERGER:         "mm.patch_merger", # mistral small 3.1
 }
 
 MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
@@ -783,6 +788,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.V_PRE_NORM,
         MODEL_TENSOR.V_POST_NORM,
         MODEL_TENSOR.V_MM_INP_PROJ,
+        MODEL_TENSOR.V_MM_INP_NORM,
         MODEL_TENSOR.V_MM_SOFT_EMB_NORM,
         MODEL_TENSOR.V_RESMPL_POS_EMBD_K,
         MODEL_TENSOR.V_RESMPL_ATTN_Q,
@@ -796,6 +802,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.V_RESMPL_PROJ,
         MODEL_TENSOR.V_RESMPL_QUERY,
         MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK,
+        MODEL_TENSOR.V_MM_PATCH_MERGER,
     ],
     MODEL_ARCH.LLAMA: [
         MODEL_TENSOR.TOKEN_EMBD,
index f22a6d4a3472be0b13c89fbdc67068041f94071d..a30c49e32b35180a2b9df82b7761370f4dac86b1 100644 (file)
@@ -972,6 +972,9 @@ class GGUFWriter:
     def add_vision_image_std(self, values: Sequence[float]) -> None:
         self.add_array(Keys.ClipVision.IMAGE_STD, values)
 
+    def add_vision_spatial_merge_size(self, value: int) -> None:
+        self.add_uint32(Keys.ClipVision.SPATIAL_MERGE_SIZE, value)
+
     def add_vision_use_gelu(self, value: bool) -> None:
         self.add_bool(Keys.ClipVision.USE_GELU, value)
 
index 311d1ff69c7999d25db65de948adde86e4ee194e..2f6326104ffa70f5aa1e5e7e3933921683f272a0 100644 (file)
@@ -1001,6 +1001,10 @@ class TensorNameMap:
             "multi_modal_projector.mm_input_projection",
         ),
 
+        MODEL_TENSOR.V_MM_INP_NORM: (
+            "multi_modal_projector.norm",
+        ),
+
         MODEL_TENSOR.V_MM_SOFT_EMB_NORM: (
             "multi_modal_projector.mm_soft_emb_norm",
         ),
@@ -1052,6 +1056,10 @@ class TensorNameMap:
         MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK: (
             "v.token_embd.img_break", # for pixtral, this is a generated vector
         ),
+
+        MODEL_TENSOR.V_MM_PATCH_MERGER: (
+            "multi_modal_projector.patch_merger.merging_layer", # mistral small 3.1
+        ),
     }
 
     # architecture-specific block mappings