From: Xuan-Son Nguyen
Date: Thu, 1 May 2025 15:05:42 +0000 (+0200)
Subject: mtmd : add vision support for Mistral Small 3.1 (#13231)
X-Git-Tag: upstream/0.0.5318~75
X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=8936784f7a1ec4f91637d04b77fdc90ec36ebac9;p=pkg%2Fggml%2Fsources%2Fllama.cpp

mtmd : add vision support for Mistral Small 3.1 (#13231)

* convert ok

* load ok, missing patch merger

* ah sheet it works

* update llava/readme

* add test

* fix test
---

diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 123df801..04ca646b 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -1899,7 +1899,10 @@ class LlamaModel(TextModel):
             raise ValueError(f"Unprocessed experts: {experts}")
 
 
-@ModelBase.register("LlavaForConditionalGeneration")
+@ModelBase.register(
+    "LlavaForConditionalGeneration", # pixtral
+    "Mistral3ForConditionalGeneration", # mistral small 3.1
+)
 class LlavaVisionModel(VisionModel):
     img_break_tok_id = -1
 
@@ -1908,17 +1911,38 @@ class LlavaVisionModel(VisionModel):
         if self.hparams["model_type"] == "pixtral":
             # layer_norm_eps is not in config.json, it is hard-coded in modeling_pixtral.py
             self.hparams["layer_norm_eps"] = self.hparams.get("layer_norm_eps", 1e-5)
-            self.img_break_tok_id = 12 # see tokenizer_config.json
+            self.img_break_tok_id = self.get_token_id("[IMG_BREAK]")
+            logger.info(f"Image break token id: {self.img_break_tok_id}")
         else:
             raise ValueError(f"Unsupported model type: {self.hparams['model_type']}")
 
+    def get_token_id(self, token: str) -> int:
+        tokenizer_config_file = self.dir_model / 'tokenizer_config.json'
+        with open(tokenizer_config_file, "r", encoding="utf-8") as f:
+            added_tokens_decoder = json.load(f)['added_tokens_decoder']
+            for id_, token_data in added_tokens_decoder.items():
+                if token_data["content"] == token:
+                    return int(id_)
+        raise ValueError(f"Token '{token}' not found in tokenizer config.")
+
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
         hparams = self.hparams
         if hparams["model_type"] == "pixtral":
             self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.PIXTRAL)
             self.gguf_writer.add_vision_attention_layernorm_eps(hparams["layer_norm_eps"])
-            self.gguf_writer.add_vision_use_silu(True)
+
+            # hidden_act
+            if hparams["hidden_act"] == "silu":
+                self.gguf_writer.add_vision_use_silu(True)
+            elif hparams["hidden_act"] == "gelu":
+                self.gguf_writer.add_vision_use_gelu(True)
+            else:
+                raise ValueError(f"Unsupported hidden_act: {hparams['hidden_act']}")
+
+            # spatial_merge_size
+            if "spatial_merge_size" in self.global_config:
+                self.gguf_writer.add_vision_spatial_merge_size(self.global_config["spatial_merge_size"])
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         del bid # unused
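Note: the `get_token_id` helper added above replaces the previous hard-coded token id by scanning `added_tokens_decoder` in `tokenizer_config.json` for the entry whose `content` matches the requested special token. A minimal standalone sketch of the same lookup (the inline config fragment is illustrative, not from a real checkpoint):

```python
import json

# illustrative fragment of a tokenizer_config.json (not from a real checkpoint)
SAMPLE_CONFIG = json.dumps({
    "added_tokens_decoder": {
        "10": {"content": "[IMG]"},
        "12": {"content": "[IMG_BREAK]"},
        "13": {"content": "[IMG_END]"},
    }
})

def get_token_id(config_json: str, token: str) -> int:
    # same scan as LlavaVisionModel.get_token_id: match on the "content" field
    added_tokens_decoder = json.loads(config_json)["added_tokens_decoder"]
    for id_, token_data in added_tokens_decoder.items():
        if token_data["content"] == token:
            return int(id_)
    raise ValueError(f"Token '{token}' not found in tokenizer config.")

print(get_token_id(SAMPLE_CONFIG, "[IMG_BREAK]"))  # -> 12
```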
diff --git a/examples/llava/README.md b/examples/llava/README.md
index f58d9de7..3b62627c 100644
--- a/examples/llava/README.md
+++ b/examples/llava/README.md
@@ -34,6 +34,9 @@ llama-mtmd-cli -hf ggml-org/SmolVLM2-500M-Video-Instruct-GGUF
 
 # Pixtral 12B
 llama-mtmd-cli -hf ggml-org/pixtral-12b-GGUF
+
+# Mistral Small 3.1 24B (IQ2_M quantization)
+llama-mtmd-cli -hf ggml-org/Mistral-Small-3.1-24B-Instruct-2503-GGUF --chat-template mistral-v7
 ```
 
 ## How it works and what is `mmproj`?
@@ -73,3 +76,4 @@ For the following models, you can use `convert_hf_to_gguf.py`with `--mmproj` fla
 - SmolVLM (from [HuggingFaceTB](https://huggingface.co/HuggingFaceTB))
 - SmolVLM2 (from [HuggingFaceTB](https://huggingface.co/HuggingFaceTB))
 - [Pixtral 12B](https://huggingface.co/mistral-community/pixtral-12b) - only works with `transformers`-compatible checkpoint
+- [Mistral Small 3.1 24B](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503)

diff --git a/examples/llava/clip-impl.h b/examples/llava/clip-impl.h
index 66cb21ef..b575ca4d 100644
--- a/examples/llava/clip-impl.h
+++ b/examples/llava/clip-impl.h
@@ -31,6 +31,7 @@
 #define KEY_FEATURE_LAYER       "clip.vision.feature_layer"
 #define KEY_PROJ_SCALE_FACTOR   "clip.vision.projector.scale_factor"
 #define KEY_PROJ_TYPE           "clip.projector_type"
+#define KEY_SPATIAL_MERGE_SIZE  "clip.vision.spatial_merge_size"
 
 #define KEY_USE_GLU_MLP         "clip.use_glu_mlp"  // for qwen2.5vl
 #define KEY_USE_RMS_NORM        "clip.use_rms_norm" // for qwen2.5vl
@@ -68,9 +69,11 @@
 #define TN_MVLM_PROJ_BLOCK  "mm.model.mb_block.%d.block.%d.%s"
 #define TN_MVLM_PROJ_PEG    "mm.model.peg.%d.%s"
 #define TN_IMAGE_NEWLINE    "model.image_newline"
+#define TN_MM_INP_NORM      "mm.input_norm.weight"
 #define TN_MM_INP_PROJ      "mm.input_projection.weight" // gemma3
 #define TN_MM_SOFT_EMB_N    "mm.soft_emb_norm.weight"    // gemma3
 #define TN_MM_PROJECTOR     "mm.model.fc.weight"         // idefics3
+#define TN_MM_PATCH_MERGER  "mm.patch_merger.weight"     // mistral small 3.1
 #define TN_TOK_IMG_BREAK    "v.token_embd.img_break"     // pixtral
 
 // mimicpmv
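Note: the `clip.cpp` changes below implement the Mistral 3.1 patch merger: the image tokens are RMS-normalized, each `n_merge x n_merge` block of the patch grid is gathered with `ggml_im2col` (the ggml counterpart of `torch.nn.functional.unfold`), and a linear merging layer projects each gathered block back to `hidden_size`. A rough PyTorch sketch of the gather-and-project step, with illustrative sizes (the real merger applies the RMS norm first):

```python
import torch
import torch.nn.functional as F

hidden_size, n_patches_x, n_patches_y, n_merge = 64, 4, 4, 2

# [n_patches, hidden_size] image tokens, laid out row-major on the patch grid
x = torch.randn(n_patches_x * n_patches_y, hidden_size)

# back to a 2D grid: [1, hidden_size, y, x]
grid = x.T.reshape(1, hidden_size, n_patches_y, n_patches_x)

# unfold gathers each n_merge x n_merge block into one column:
# [1, hidden_size * n_merge^2, (y / n_merge) * (x / n_merge)]
cols = F.unfold(grid, kernel_size=n_merge, stride=n_merge)

# the merging layer projects each gathered block back to hidden_size
merging_layer = torch.nn.Linear(hidden_size * n_merge * n_merge, hidden_size, bias=False)
merged = merging_layer(cols.squeeze(0).T)
print(merged.shape)  # torch.Size([4, 64]) -- 4x fewer tokens
```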
diff --git a/examples/llava/clip.cpp b/examples/llava/clip.cpp
index ad3e7df1..984e300e 100644
--- a/examples/llava/clip.cpp
+++ b/examples/llava/clip.cpp
@@ -172,6 +172,7 @@ struct clip_hparams {
     std::unordered_set<int32_t> vision_feature_layer;
     int32_t attn_window_size = 0;
     int32_t n_wa_pattern = 0;
+    int32_t spatial_merge_size = 0;
 };
 
 struct clip_layer {
@@ -232,6 +233,7 @@ struct clip_vision_model {
     struct ggml_tensor * projection;
 
     // LLaVA projection
+    struct ggml_tensor * mm_input_norm_w = nullptr;
     struct ggml_tensor * mm_0_w = nullptr;
     struct ggml_tensor * mm_0_b = nullptr;
     struct ggml_tensor * mm_2_w = nullptr;
@@ -311,6 +313,7 @@ struct clip_vision_model {
 
     // pixtral
     struct ggml_tensor * token_embd_img_break = nullptr;
+    struct ggml_tensor * mm_patch_merger_w = nullptr;
 };
 
 struct clip_ctx {
@@ -637,6 +640,7 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
     const int d_head = hidden_size / n_head;
     const int n_layer = hparams.n_layer;
     const float eps = hparams.eps;
+    const int n_merge = hparams.spatial_merge_size;
 
     struct ggml_init_params params = {
         /*.mem_size =*/ ctx->buf_compute_meta.size(),
@@ -721,7 +725,13 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
         {
             ggml_tensor * gate_proj = ggml_mul_mat(ctx0, model.layers[il].ff_gate_w, cur);
             ggml_tensor * up_proj   = ggml_mul_mat(ctx0, model.layers[il].ff_up_w,   cur);
-            gate_proj = ggml_silu(ctx0, gate_proj); // pixtral uses silu
+            if (ctx->use_silu) {
+                gate_proj = ggml_silu(ctx0, gate_proj);
+            } else if (ctx->use_gelu) {
+                gate_proj = ggml_gelu(ctx0, gate_proj);
+            } else {
+                GGML_ABORT("Pixtral: Unsupported activation");
+            }
             cur = ggml_mul(ctx0, up_proj, gate_proj);
             cur = ggml_mul_mat(ctx0, model.layers[il].ff_down_w, cur);
         }
@@ -732,14 +742,42 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
         embeddings = cur;
     }
 
-    // LlavaMultiModalProjector (with GELU activation)
+    // mistral small 3.1 patch merger
+    // ref: https://github.com/huggingface/transformers/blob/7a3e208892c06a5e278144eaf38c8599a42f53e7/src/transformers/models/mistral3/modeling_mistral3.py#L67
+    if (model.mm_patch_merger_w) {
+        GGML_ASSERT(hparams.spatial_merge_size > 0);
+
+        ggml_tensor * cur = embeddings;
+        cur = ggml_mul(ctx0, ggml_rms_norm(ctx0, cur, eps), model.mm_input_norm_w);
+
+        // reshape image tokens to 2D grid
+        cur = ggml_reshape_3d(ctx0, cur, hidden_size, n_patches_x, n_patches_y);
+        cur = ggml_permute(ctx0, cur, 2, 0, 1, 3); // [x, y, hidden_size]
+        cur = ggml_cont(ctx0, cur);
+
+        // torch.nn.functional.unfold is just an im2col under the hood
+        // we just need a dummy kernel to make it work
+        ggml_tensor * kernel = ggml_view_3d(ctx0, cur, n_merge, n_merge, cur->ne[2], 0, 0, 0);
+        cur = ggml_im2col(ctx0, kernel, cur, n_merge, n_merge, 0, 0, 1, 1, true, inp->type);
+
+        // project to hidden_size
+        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], cur->ne[1] * cur->ne[2]);
+        cur = ggml_mul_mat(ctx0, model.mm_patch_merger_w, cur);
+        embeddings = cur;
+    }
+
+    // LlavaMultiModalProjector (always using GELU activation)
     {
         embeddings = ggml_mul_mat(ctx0, model.mm_1_w, embeddings);
-        embeddings = ggml_add(ctx0, embeddings, model.mm_1_b);
+        if (model.mm_1_b) {
+            embeddings = ggml_add(ctx0, embeddings, model.mm_1_b);
+        }
 
         embeddings = ggml_gelu(ctx0, embeddings);
         embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings);
-        embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
+        if (model.mm_2_b) {
+            embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
+        }
     }
 
     // arrangement of the [IMG_BREAK] token
@@ -749,11 +787,14 @@
         // and then concatenate the [IMG_BREAK] token to the end of each row, aka n_patches_per_row dimension
        // after the concatenation, we have a tensor with shape [hidden_size, n_patches_per_row + 1, n_rows]
 
+        const int p_y = n_merge > 0 ? n_patches_y / n_merge : n_patches_y;
+        const int p_x = n_merge > 0 ? n_patches_x / n_merge : n_patches_x;
+        const int p_total = p_x * p_y;
         const int n_embd_text = embeddings->ne[0];
-        const int n_tokens_output = num_patches + n_patches_y - 1; // one [IMG_BREAK] per row, except the last row
+        const int n_tokens_output = p_total + p_y - 1; // one [IMG_BREAK] per row, except the last row
 
-        ggml_tensor * cur = ggml_reshape_3d(ctx0, embeddings, n_embd_text, n_patches_x, n_patches_y);
-        ggml_tensor * tok = ggml_new_tensor_3d(ctx0, embeddings->type, n_embd_text, 1, n_patches_y);
+        ggml_tensor * cur = ggml_reshape_3d(ctx0, embeddings, n_embd_text, p_x, p_y);
+        ggml_tensor * tok = ggml_new_tensor_3d(ctx0, embeddings->type, n_embd_text, 1, p_y);
         tok = ggml_scale(ctx0, tok, 0.0); // clear the tensor
         tok = ggml_add(ctx0, tok, model.token_embd_img_break);
         cur = ggml_concat(ctx0, cur, tok, 1);
@@ -1734,6 +1775,7 @@ struct clip_model_loader {
             case PROJECTOR_TYPE_PIXTRAL:
                 {
                     hparams.rope_theta = 10000.0f;
+                    get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.spatial_merge_size, false);
                 } break;
             case PROJECTOR_TYPE_QWEN25VL:
                 {
@@ -1957,11 +1999,14 @@ struct clip_model_loader {
             case PROJECTOR_TYPE_PIXTRAL:
                 {
                     vision_model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"));
-                    vision_model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"));
+                    vision_model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"), false);
                     vision_model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"));
-                    vision_model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"));
+                    vision_model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"), false);
                     // [IMG_BREAK] token embedding
                     vision_model.token_embd_img_break = get_tensor(TN_TOK_IMG_BREAK);
+                    // for mistral small 3.1
+                    vision_model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM, false);
+                    vision_model.mm_patch_merger_w = get_tensor(TN_MM_PATCH_MERGER, false);
                 } break;
             default:
                 GGML_ASSERT(false && "unknown projector type");
@@ -2926,8 +2971,9 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
     } else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
         n_patches /= ctx->vision_model.hparams.proj_scale_factor;
     } else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) {
-        int n_patches_x = img->nx / params.patch_size;
-        int n_patches_y = img->ny / params.patch_size;
+        int n_merge = ctx->vision_model.hparams.spatial_merge_size;
+        int n_patches_x = img->nx / params.patch_size / (n_merge > 0 ? n_merge : 1);
+        int n_patches_y = img->ny / params.patch_size / (n_merge > 0 ? n_merge : 1);
         n_patches = n_patches_y*n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row
     }
 
@@ -3484,7 +3530,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
             return ctx->vision_model.mm_model_peg_0_b->ne[0];
         case PROJECTOR_TYPE_MLP:
         case PROJECTOR_TYPE_PIXTRAL:
-            return ctx->vision_model.mm_2_b->ne[0];
+            return ctx->vision_model.mm_2_w->ne[1];
         case PROJECTOR_TYPE_MLP_NORM:
             return ctx->vision_model.mm_3_b->ne[0];
         case PROJECTOR_TYPE_MINICPMV:
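Note: a quick sanity check of the token accounting in `clip_n_output_tokens` above. The image, patch, and merge sizes below are illustrative; the real values come from the checkpoint's config:

```python
def n_output_tokens(nx: int, ny: int, patch_size: int, n_merge: int) -> int:
    # mirrors the PROJECTOR_TYPE_PIXTRAL branch: merged patches per row/column,
    # plus one [IMG_BREAK] token per row except the last
    n_patches_x = nx // patch_size // (n_merge if n_merge > 0 else 1)
    n_patches_y = ny // patch_size // (n_merge if n_merge > 0 else 1)
    return n_patches_y * n_patches_x + n_patches_y - 1

# e.g. a 448x448 input with patch_size 14: 32x32 raw patches;
# spatial_merge_size 2 leaves 16x16 merged tokens -> 16*16 + 16 - 1
print(n_output_tokens(448, 448, 14, 2))  # 271
print(n_output_tokens(448, 448, 14, 0))  # 1055, no merging (plain pixtral)
```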
diff --git a/examples/llava/mtmd-cli.cpp b/examples/llava/mtmd-cli.cpp
index 4d857ca6..aa52d92c 100644
--- a/examples/llava/mtmd-cli.cpp
+++ b/examples/llava/mtmd-cli.cpp
@@ -94,6 +94,7 @@ struct mtmd_cli_context {
             LOG_ERR("Model does not have chat template.\n");
             LOG_ERR("  For old llava models, you may need to use '--chat-template vicuna'\n");
             LOG_ERR("  For MobileVLM models, use '--chat-template deepseek'\n");
+            LOG_ERR("  For Mistral Small 3.1, use '--chat-template mistral-v7'\n");
             exit(1);
         }
 
diff --git a/examples/llava/tests.sh b/examples/llava/tests.sh
index 75604315..4af37006 100755
--- a/examples/llava/tests.sh
+++ b/examples/llava/tests.sh
@@ -59,6 +59,7 @@ add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-3B-Instruct-GGUF:Q4_K_M"
 
 # to test the big models, run: ./tests.sh big
 add_test_big "llama-mtmd-cli" "ggml-org/pixtral-12b-GGUF:Q4_K_M"
+add_test_big "llama-mtmd-cli" "ggml-org/Mistral-Small-3.1-24B-Instruct-2503-GGUF" "mistral-v7"
 
 # these models always give the wrong answer, not sure why
 # add_test "llama-mtmd-cli" "ggml-org/SmolVLM-Instruct-GGUF:Q4_K_M"
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 326ccdb0..a2540bd9 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -231,6 +231,7 @@ class Keys:
         BLOCK_COUNT         = "clip.vision.block_count"
         IMAGE_MEAN          = "clip.vision.image_mean"
         IMAGE_STD           = "clip.vision.image_std"
+        SPATIAL_MERGE_SIZE  = "clip.vision.spatial_merge_size"
         USE_GELU            = "clip.use_gelu"
         USE_SILU            = "clip.use_silu"
 
@@ -491,6 +492,7 @@ class MODEL_TENSOR(IntEnum):
     V_ENC_FFN_DOWN       = auto()
     V_PRE_NORM           = auto()
     V_POST_NORM          = auto()
+    V_MM_INP_NORM        = auto()
     V_MM_INP_PROJ        = auto() # gemma3
     V_MM_SOFT_EMB_NORM   = auto() # gemma3
     V_RESMPL_POS_EMBD_K  = auto() # minicpmv
@@ -505,6 +507,7 @@ class MODEL_TENSOR(IntEnum):
     V_RESMPL_PROJ        = auto() # minicpmv
     V_RESMPL_QUERY       = auto() # minicpmv
     V_TOK_EMBD_IMG_BREAK = auto() # pixtral
+    V_MM_PATCH_MERGER    = auto() # mistral small 3.1
 
 
 MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
@@ -747,6 +750,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
     MODEL_TENSOR.V_PRE_NORM:             "v.pre_ln",
     MODEL_TENSOR.V_POST_NORM:            "v.post_ln",
     MODEL_TENSOR.V_MM_INP_PROJ:          "mm.input_projection",
+    MODEL_TENSOR.V_MM_INP_NORM:          "mm.input_norm",
     MODEL_TENSOR.V_MM_SOFT_EMB_NORM:     "mm.soft_emb_norm",
     MODEL_TENSOR.V_RESMPL_POS_EMBD_K:    "resampler.pos_embd_k",
     MODEL_TENSOR.V_RESMPL_ATTN_Q:        "resampler.attn.q",
@@ -760,6 +764,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
     MODEL_TENSOR.V_RESMPL_PROJ:          "resampler.proj",
     MODEL_TENSOR.V_RESMPL_QUERY:         "resampler.query",
     MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK:   "v.token_embd.img_break", # pixtral
+    MODEL_TENSOR.V_MM_PATCH_MERGER:      "mm.patch_merger", # mistral small 3.1
 }
 
 MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
@@ -783,6 +788,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.V_PRE_NORM,
         MODEL_TENSOR.V_POST_NORM,
         MODEL_TENSOR.V_MM_INP_PROJ,
+        MODEL_TENSOR.V_MM_INP_NORM,
         MODEL_TENSOR.V_MM_SOFT_EMB_NORM,
         MODEL_TENSOR.V_RESMPL_POS_EMBD_K,
         MODEL_TENSOR.V_RESMPL_ATTN_Q,
@@ -796,6 +802,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.V_RESMPL_PROJ,
         MODEL_TENSOR.V_RESMPL_QUERY,
         MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK,
+        MODEL_TENSOR.V_MM_PATCH_MERGER,
     ],
     MODEL_ARCH.LLAMA: [
         MODEL_TENSOR.TOKEN_EMBD,
 
diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index f22a6d4a..a30c49e3 100644
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -972,6 +972,9 @@ class GGUFWriter:
     def add_vision_image_std(self, values: Sequence[float]) -> None:
         self.add_array(Keys.ClipVision.IMAGE_STD, values)
 
+    def add_vision_spatial_merge_size(self, value: int) -> None:
+        self.add_uint32(Keys.ClipVision.SPATIAL_MERGE_SIZE, value)
+
     def add_vision_use_gelu(self, value: bool) -> None:
         self.add_bool(Keys.ClipVision.USE_GELU, value)
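Note: with the writer method above, a converter emits the new key like any other `clip.vision.*` field. A minimal sketch (the output file name is a placeholder; `"clip"` is assumed here as the arch string used for mmproj files):

```python
from gguf import GGUFWriter

writer = GGUFWriter("mmproj-example.gguf", "clip")
writer.add_vision_spatial_merge_size(2)  # -> clip.vision.spatial_merge_size (u32)
writer.write_header_to_file()
writer.write_kv_data_to_file()
writer.close()
```

On the C++ side this is the value read back by `get_u32(KEY_SPATIAL_MERGE_SIZE, ...)` in `clip_model_loader`; the key is optional, so plain Pixtral files without it keep `spatial_merge_size == 0`.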
diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index 311d1ff6..2f632610 100644
--- a/gguf-py/gguf/tensor_mapping.py
+++ b/gguf-py/gguf/tensor_mapping.py
@@ -1001,6 +1001,10 @@ class TensorNameMap:
             "multi_modal_projector.mm_input_projection",
         ),
 
+        MODEL_TENSOR.V_MM_INP_NORM: (
+            "multi_modal_projector.norm",
+        ),
+
         MODEL_TENSOR.V_MM_SOFT_EMB_NORM: (
             "multi_modal_projector.mm_soft_emb_norm",
         ),
@@ -1052,6 +1056,10 @@ class TensorNameMap:
         MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK: (
             "v.token_embd.img_break", # for pixtral, this is a generated vector
         ),
+
+        MODEL_TENSOR.V_MM_PATCH_MERGER: (
+            "multi_modal_projector.patch_merger.merging_layer", # mistral small 3.1
+        ),
     }
 
     # architecture-specific block mappings
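Note: with the mapping above, the HF tensor `multi_modal_projector.patch_merger.merging_layer.weight` is renamed to `mm.patch_merger.weight` during conversion, matching `TN_MM_PATCH_MERGER` in `clip-impl.h`. A quick check via gguf-py (assuming the vision tensors are registered under `MODEL_ARCH.CLIP_VISION`, as in gguf-py at the time):

```python
from gguf.constants import MODEL_ARCH
from gguf.tensor_mapping import get_tensor_name_map

# block count is irrelevant for the non-blocked mm.* tensors
name_map = get_tensor_name_map(MODEL_ARCH.CLIP_VISION, n_blocks=0)
name = name_map.get_name(
    "multi_modal_projector.patch_merger.merging_layer.weight",
    try_suffixes=(".weight", ".bias"),
)
print(name)  # expected: mm.patch_merger.weight
```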