From: Xuan-Son Nguyen Date: Sat, 10 May 2025 14:26:42 +0000 (+0200) Subject: mtmd : support InternVL 2.5 and 3 (#13422) X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=053367d149f778cdabc356ee3024494e0dd53223;p=pkg%2Fggml%2Fsources%2Fllama.cpp mtmd : support InternVL 2.5 and 3 (#13422) * convert : internvl support * InternVL3-1B working * fix regression * rm mobilevlm from test * fix conversion * add test for internvl * add to list of pre-quant * restore boi/eoi check * add clarify comment for norm eps --- diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index bf6bc683..e5c397fe 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -426,7 +426,11 @@ class ModelBase: logger.warning(f"Failed to load model config from {dir_model}: {e}") logger.warning("Trying to load config.json instead") with open(dir_model / "config.json", "r", encoding="utf-8") as f: - return json.load(f) + config = json.load(f) + if "llm_config" in config: + # rename for InternVL + config["text_config"] = config["llm_config"] + return config @classmethod def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]: @@ -2606,6 +2610,11 @@ class Qwen2Model(TextModel): def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: if self.hf_arch == "Qwen2Model": name = f"model.{name}" # map to Qwen2ForCausalLM tensors + if "language_model." in name: + name = name.replace("language_model.", "") # for InternVL + if name.startswith("mlp") or name.startswith("vision_model"): + # skip visual tensors + return [] yield from super().modify_tensors(data_torch, name, bid) @@ -2709,6 +2718,62 @@ class Qwen2VLVisionModel(VisionModel): return [] # skip other tensors +@ModelBase.register("InternVisionModel") +class InternVisionModel(VisionModel): + def set_gguf_parameters(self): + super().set_gguf_parameters() + hparams = self.hparams + self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.INTERNVL) + self.gguf_writer.add_vision_attention_layernorm_eps(hparams["layer_norm_eps"]) + # hidden_act + if hparams["hidden_act"] == "silu": + self.gguf_writer.add_vision_use_silu(True) + elif hparams["hidden_act"] == "gelu": + self.gguf_writer.add_vision_use_gelu(True) + else: + raise ValueError(f"Unsupported hidden_act: {hparams['hidden_act']}") + # downsample_ratio + downsample_ratio = self.global_config.get("downsample_ratio") + assert downsample_ratio is not None + self.gguf_writer.add_vision_projector_scale_factor(int(1.0 / downsample_ratio)) + + def tensor_force_quant(self, name, new_name, bid, n_dims): + del bid, name, n_dims # unused + if ".patch_embd." in new_name: + return gguf.GGMLQuantizationType.F16 + if ".position_embd." in new_name: + return gguf.GGMLQuantizationType.F32 + return False + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + if name.startswith("vision_model") or name.startswith("mlp"): + # process visual tensors + # correct name + if name.startswith("vision_model"): + name = "vision_tower." + name + if (".ls" in name or "position_embedding" in name) and not name.endswith(".weight"): + name += ".weight" + # split QKV tensors if needed + if ".qkv." 
in name: + if data_torch.ndim == 2: # weight + c3, _ = data_torch.shape + else: # bias + c3 = data_torch.shape[0] + assert c3 % 3 == 0 + c = c3 // 3 + wq = data_torch[:c] + wk = data_torch[c: c * 2] + wv = data_torch[c * 2:] + return [ + (self.map_tensor_name(name.replace("attn.qkv", "self_attn.q_proj")), wq), + (self.map_tensor_name(name.replace("attn.qkv", "self_attn.k_proj")), wk), + (self.map_tensor_name(name.replace("attn.qkv", "self_attn.v_proj")), wv), + ] + return [(self.map_tensor_name(name), data_torch)] + return [] # skip other tensors + + @ModelBase.register("WavTokenizerDec") class WavTokenizerDecModel(TextModel): model_arch = gguf.MODEL_ARCH.WAVTOKENIZER_DEC @@ -3360,6 +3425,11 @@ class InternLM2Model(TextModel): head_dim = n_embd // num_heads num_groups = num_heads // q_per_kv + name = name.replace("language_model.", "") # InternVL + if name.startswith("mlp") or name.startswith("vision_model"): + # skip visual tensors + return [] + if bid is not None and f"model.layers.{bid}.attention.wqkv" in name: qkv = data_torch @@ -3433,6 +3503,10 @@ class InternLM3Model(TextModel): def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: n_head = self.hparams["num_attention_heads"] n_kv_head = self.hparams.get("num_key_value_heads") + name = name.replace("language_model.", "") # InternVL + if name.startswith("mlp") or name.startswith("vision_model"): + # skip visual tensors + return [] if name.endswith(("q_proj.weight", "q_proj.bias")): data_torch = LlamaModel.permute(data_torch, n_head, n_head) if name.endswith(("k_proj.weight", "k_proj.bias")): diff --git a/docs/multimodal.md b/docs/multimodal.md index efed473a..090583f9 100644 --- a/docs/multimodal.md +++ b/docs/multimodal.md @@ -66,4 +66,12 @@ NOTE: some models may require large context window, for example: `-c 8192` # Mistral Small 3.1 24B (IQ2_M quantization) (tool_name) -hf ggml-org/Mistral-Small-3.1-24B-Instruct-2503-GGUF + +# InternVL 2.5 and 3 +(tool_name) -hf ggml-org/InternVL2_5-1B-GGUF +(tool_name) -hf ggml-org/InternVL2_5-2B-GGUF +(tool_name) -hf ggml-org/InternVL3-1B-Instruct-GGUF +(tool_name) -hf ggml-org/InternVL3-2B-Instruct-GGUF +(tool_name) -hf ggml-org/InternVL3-4B-Instruct-GGUF +(tool_name) -hf ggml-org/InternVL3-14B-Instruct-GGUF ``` diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 7dd7bb6d..ae5ce71a 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -491,6 +491,8 @@ class MODEL_TENSOR(IntEnum): V_ENC_FFN_UP = auto() V_ENC_FFN_GATE = auto() V_ENC_FFN_DOWN = auto() + V_LAYER_SCALE_1 = auto() + V_LAYER_SCALE_2 = auto() V_PRE_NORM = auto() V_POST_NORM = auto() V_MM_INP_NORM = auto() @@ -748,6 +750,8 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = { MODEL_TENSOR.V_ENC_FFN_UP: "v.blk.{bid}.ffn_up", MODEL_TENSOR.V_ENC_FFN_GATE: "v.blk.{bid}.ffn_gate", MODEL_TENSOR.V_ENC_FFN_DOWN: "v.blk.{bid}.ffn_down", + MODEL_TENSOR.V_LAYER_SCALE_1: "v.blk.{bid}.ls1", + MODEL_TENSOR.V_LAYER_SCALE_2: "v.blk.{bid}.ls2", MODEL_TENSOR.V_PRE_NORM: "v.pre_ln", MODEL_TENSOR.V_POST_NORM: "v.post_ln", MODEL_TENSOR.V_MM_INP_PROJ: "mm.input_projection", @@ -786,6 +790,8 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.V_ENC_FFN_UP, MODEL_TENSOR.V_ENC_FFN_GATE, MODEL_TENSOR.V_ENC_FFN_DOWN, + MODEL_TENSOR.V_LAYER_SCALE_1, + MODEL_TENSOR.V_LAYER_SCALE_2, MODEL_TENSOR.V_PRE_NORM, MODEL_TENSOR.V_POST_NORM, MODEL_TENSOR.V_MM_INP_PROJ, @@ -2167,6 +2173,7 @@ class VisionProjectorType: PIXTRAL = "pixtral" QWEN2VL = "qwen2vl_merger" QWEN25VL 
= "qwen2.5vl_merger" + INTERNVL = "internvl" # Items here are (block size, type size) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 003b0172..bf7ec325 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -905,6 +905,7 @@ class TensorNameMap: MODEL_TENSOR.V_MMPROJ_MLP: ( "model.mm_projector.mlp.mlp.{bid}", + "mlp1.{bid}", # InternVL ), MODEL_TENSOR.V_MMPROJ_PEG: ( @@ -955,6 +956,7 @@ class TensorNameMap: MODEL_TENSOR.V_ENC_INPUT_NORM: ( "vision_tower.vision_model.encoder.layers.{bid}.layer_norm1", + "vision_tower.vision_model.encoder.layers.{bid}.norm1", # InternVL "vpm.encoder.layers.{bid}.layer_norm1", "model.vision_model.encoder.layers.{bid}.layer_norm1", # SmolVLM "vision_tower.transformer.layers.{bid}.attention_norm", # pixtral @@ -963,6 +965,7 @@ class TensorNameMap: MODEL_TENSOR.V_ENC_OUTPUT: ( "vision_tower.vision_model.encoder.layers.{bid}.self_attn.out_proj", + "vision_tower.vision_model.encoder.layers.{bid}.attn.proj", # InternVL "vpm.encoder.layers.{bid}.self_attn.out_proj", "model.vision_model.encoder.layers.{bid}.self_attn.out_proj", # SmolVLM "vision_tower.transformer.layers.{bid}.attention.o_proj", # pixtral @@ -971,6 +974,7 @@ class TensorNameMap: MODEL_TENSOR.V_ENC_OUTPUT_NORM: ( "vision_tower.vision_model.encoder.layers.{bid}.layer_norm2", + "vision_tower.vision_model.encoder.layers.{bid}.norm2", # InternVL "vpm.encoder.layers.{bid}.layer_norm2", "model.vision_model.encoder.layers.{bid}.layer_norm2", # SmolVLM "vision_tower.transformer.layers.{bid}.ffn_norm", # pixtral @@ -1000,6 +1004,14 @@ class TensorNameMap: "visual.blocks.{bid}.mlp.down_proj", # qwen2.5vl ), + MODEL_TENSOR.V_LAYER_SCALE_1: ( + "vision_tower.vision_model.encoder.layers.{bid}.ls1", # InternVL + ), + + MODEL_TENSOR.V_LAYER_SCALE_2: ( + "vision_tower.vision_model.encoder.layers.{bid}.ls2", # InternVL + ), + MODEL_TENSOR.V_PRE_NORM: ( "vision_tower.vision_model.pre_layrnorm", "vision_tower.ln_pre", # pixtral diff --git a/tools/mtmd/README.md b/tools/mtmd/README.md index 06e1fd09..ab258ea1 100644 --- a/tools/mtmd/README.md +++ b/tools/mtmd/README.md @@ -48,6 +48,7 @@ For the following models, you can use `convert_hf_to_gguf.py`with `--mmproj` fla - [Pixtral 12B](https://huggingface.co/mistral-community/pixtral-12b) - only works with `transformers`-compatible checkpoint - Qwen 2 VL and Qwen 2.5 VL (from [Qwen](https://huggingface.co/Qwen)) - [Mistral Small 3.1 24B](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503) +- InternVL 2.5 and InternVL 3 from [OpenGVLab](https://huggingface.co/OpenGVLab) (note: we don't support conversion of `InternVL3-*-hf` model, only non-HF version is supported ; `InternLM2Model` **text** model is not supported) For older models, please refer to the relevant guide for instructions on how to obtain or create them: diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index fb780e9d..e9c8646e 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -33,9 +33,6 @@ #define KEY_PROJ_TYPE "clip.projector_type" #define KEY_SPATIAL_MERGE_SIZE "clip.vision.spatial_merge_size" -#define KEY_USE_GLU_MLP "clip.use_glu_mlp" // for qwen2.5vl -#define KEY_USE_RMS_NORM "clip.use_rms_norm" // for qwen2.5vl - #define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type" #define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints" #define KEY_IMAGE_CROP_RESOLUTION "clip.vision.image_crop_resolution" @@ -60,8 +57,10 @@ #define TN_FFN_GATE "%s.blk.%d.ffn_gate.%s" #define TN_FFN_UP 
"%s.blk.%d.ffn_up.%s" #define TN_FFN_GATE "%s.blk.%d.ffn_gate.%s" -#define TN_LN_1 "%s.blk.%d.ln1.%s" -#define TN_LN_2 "%s.blk.%d.ln2.%s" +#define TN_LN_1 "%s.blk.%d.ln1.%s" // layer norm +#define TN_LN_2 "%s.blk.%d.ln2.%s" // layer norm +#define TN_LS_1 "%s.blk.%d.ls1.%s" // layer scale +#define TN_LS_2 "%s.blk.%d.ls2.%s" // layer scale #define TN_LN_PRE "%s.pre_ln.%s" #define TN_LN_POST "%s.post_ln.%s" #define TN_LLAVA_PROJ "mm.%d.%s" @@ -105,6 +104,7 @@ enum projector_type { PROJECTOR_TYPE_IDEFICS3, PROJECTOR_TYPE_PIXTRAL, PROJECTOR_TYPE_QWEN25VL, + PROJECTOR_TYPE_INTERNVL, PROJECTOR_TYPE_UNKNOWN, }; @@ -119,6 +119,7 @@ static std::map PROJECTOR_TYPE_NAMES = { { PROJECTOR_TYPE_GEMMA3, "gemma3"}, { PROJECTOR_TYPE_IDEFICS3, "idefics3"}, { PROJECTOR_TYPE_PIXTRAL, "pixtral"}, + { PROJECTOR_TYPE_INTERNVL, "internvl"}, }; static projector_type clip_projector_type_from_string(const std::string & str) { diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 1a81c1fc..dfe7ac91 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -215,6 +215,10 @@ struct clip_layer { // layernorm 2 ggml_tensor * ln_2_w = nullptr; ggml_tensor * ln_2_b = nullptr; + + // layer scale (no bias) + ggml_tensor * ls_1_w = nullptr; + ggml_tensor * ls_2_w = nullptr; }; struct clip_vision_model { @@ -589,6 +593,9 @@ struct clip_graph { // Qwen2VL and Qwen2.5VL use M-RoPE ggml_cgraph * build_qwen2vl() { + GGML_ASSERT(model.patch_bias == nullptr); + GGML_ASSERT(model.class_embedding == nullptr); + const int batch_size = 1; const bool use_window_attn = hparams.n_wa_pattern > 0; const int n_wa_pattern = hparams.n_wa_pattern; @@ -625,10 +632,6 @@ struct clip_graph { n_embd, n_patches_x * n_patches_y, batch_size); } - if (model.patch_bias) { - inp = ggml_add(ctx0, inp, model.patch_bias); - } - ggml_tensor * inpL = inp; ggml_tensor * window_mask = nullptr; ggml_tensor * window_idx = nullptr; @@ -859,6 +862,67 @@ struct clip_graph { return gf; } + ggml_cgraph * build_internvl() { + GGML_ASSERT(model.class_embedding != nullptr); + GGML_ASSERT(model.position_embeddings != nullptr); + + const int n_pos = n_patches + 1; + ggml_tensor * inp = build_inp(); + + // add CLS token + inp = ggml_concat(ctx0, inp, model.class_embedding, 1); + + ggml_tensor * cur = build_vit( + inp, n_pos, + NORM_TYPE_NORMAL, + hparams.ffn_op, + model.position_embeddings, + nullptr); + + // remove CLS token + cur = ggml_view_2d(ctx0, cur, + n_embd, n_patches, + ggml_row_size(cur->type, n_embd), 0); + + // pixel shuffle + { + const int scale_factor = model.hparams.proj_scale_factor; + const int bsz = 1; // batch size, always 1 for now since we don't support batching + const int height = n_patches_y; + const int width = n_patches_x; + GGML_ASSERT(scale_factor > 0); + cur = ggml_reshape_4d(ctx0, cur, n_embd * scale_factor, height / scale_factor, width, bsz); + cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); + cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur), + n_embd * scale_factor * scale_factor, + height / scale_factor, + width / scale_factor, + bsz); + cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); + // flatten to 2D + cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, cur), + n_embd * scale_factor * scale_factor, + cur->ne[1] * cur->ne[2]); + } + + // projector (always using GELU activation) + { + // projector LayerNorm uses pytorch's default eps = 1e-5 + // ref: https://huggingface.co/OpenGVLab/InternVL3-8B-Instruct/blob/a34d3e4e129a5856abfd6aa6de79776484caa14e/modeling_internvl_chat.py#L79 + cur = build_norm(cur, model.mm_0_w, model.mm_0_b, NORM_TYPE_NORMAL, 
1e-5, -1); + cur = ggml_mul_mat(ctx0, model.mm_1_w, cur); + cur = ggml_add(ctx0, cur, model.mm_1_b); + cur = ggml_gelu(ctx0, cur); + cur = ggml_mul_mat(ctx0, model.mm_3_w, cur); + cur = ggml_add(ctx0, cur, model.mm_3_b); + } + + // build the graph + ggml_build_forward_expand(gf, cur); + + return gf; + } + // this graph is used by llava, granite and glm // due to having embedding_stack (used by granite), we cannot reuse build_vit ggml_cgraph * build_llava() { @@ -890,10 +954,6 @@ struct clip_graph { ggml_tensor * inp = build_inp(); - if (model.patch_bias) { - inp = ggml_add(ctx0, inp, model.patch_bias); - } - // concat class_embeddings and patch_embeddings if (model.class_embedding) { inp = ggml_concat(ctx0, inp, model.class_embedding, 1); @@ -1260,11 +1320,6 @@ private: ggml_tensor * learned_pos_embd, std::function add_pos ) { - if (model.patch_bias) { - inp = ggml_add(ctx0, inp, model.patch_bias); - cb(inp, "patch_bias", -1); - } - if (learned_pos_embd) { inp = ggml_add(ctx0, inp, learned_pos_embd); cb(inp, "pos_embed", -1); @@ -1324,6 +1379,11 @@ private: cb(cur, "attn_out", il); } + if (layer.ls_1_w) { + cur = ggml_mul(ctx0, cur, layer.ls_1_w); + cb(cur, "attn_out_scaled", il); + } + // re-add the layer input, e.g., residual cur = ggml_add(ctx0, cur, inpL); @@ -1344,6 +1404,11 @@ private: cb(cur, "ffn_out", il); + if (layer.ls_2_w) { + cur = ggml_mul(ctx0, cur, layer.ls_2_w); + cb(cur, "ffn_out_scaled", il); + } + // residual 2 cur = ggml_add(ctx0, inpL, cur); cb(cur, "layer_out", il); @@ -1365,6 +1430,10 @@ private: ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1); inp = ggml_reshape_2d(ctx0, inp, n_patches, n_embd); inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp)); + if (model.patch_bias) { + inp = ggml_add(ctx0, inp, model.patch_bias); + cb(inp, "patch_bias", -1); + } return inp; } @@ -1627,6 +1696,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 { res = graph.build_minicpmv(); } break; + case PROJECTOR_TYPE_INTERNVL: + { + res = graph.build_internvl(); + } break; default: { res = graph.build_llava(); @@ -1790,6 +1863,7 @@ struct clip_model_loader { } } break; case PROJECTOR_TYPE_IDEFICS3: + case PROJECTOR_TYPE_INTERNVL: { get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor, false); } break; @@ -1897,6 +1971,9 @@ struct clip_model_loader { layer.o_w = get_tensor(string_format(TN_ATTN_OUTPUT, "v", il, "weight")); layer.ln_1_w = get_tensor(string_format(TN_LN_1, "v", il, "weight"), false); layer.ln_2_w = get_tensor(string_format(TN_LN_2, "v", il, "weight"), false); + layer.ls_1_w = get_tensor(string_format(TN_LS_1, "v", il, "weight"), false); // no bias + layer.ls_2_w = get_tensor(string_format(TN_LS_2, "v", il, "weight"), false); // no bias + layer.k_b = get_tensor(string_format(TN_ATTN_K, "v", il, "bias"), false); layer.q_b = get_tensor(string_format(TN_ATTN_Q, "v", il, "bias"), false); layer.v_b = get_tensor(string_format(TN_ATTN_V, "v", il, "bias"), false); @@ -1904,7 +1981,7 @@ struct clip_model_loader { layer.ln_1_b = get_tensor(string_format(TN_LN_1, "v", il, "bias"), false); layer.ln_2_b = get_tensor(string_format(TN_LN_2, "v", il, "bias"), false); - // new naming + // ffn layer.ff_up_w = get_tensor(string_format(TN_FFN_UP, "v", il, "weight")); layer.ff_up_b = get_tensor(string_format(TN_FFN_UP, "v", il, "bias"), false); layer.ff_gate_w = get_tensor(string_format(TN_FFN_GATE, "v", il, "weight"), false); @@ -2052,6 +2129,15 @@ struct clip_model_loader { 
vision_model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM, false); vision_model.mm_patch_merger_w = get_tensor(TN_MM_PATCH_MERGER, false); } break; + case PROJECTOR_TYPE_INTERNVL: + { + vision_model.mm_0_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "weight")); + vision_model.mm_0_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "bias")); + vision_model.mm_1_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "weight")); + vision_model.mm_1_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "bias")); + vision_model.mm_3_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "weight")); + vision_model.mm_3_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "bias")); + } break; default: GGML_ASSERT(false && "unknown projector type"); } @@ -2838,7 +2924,9 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str } else if (ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE || ctx->proj_type == PROJECTOR_TYPE_GEMMA3 - || ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) { + || ctx->proj_type == PROJECTOR_TYPE_IDEFICS3 + || ctx->proj_type == PROJECTOR_TYPE_INTERNVL // TODO @ngxson : support dynamic resolution + ) { clip_image_u8 resized_image; int sz = params.image_size; image_manipulation::resize_and_pad_image(*img, resized_image, {sz, sz}); @@ -2988,9 +3076,13 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im int n_patches = (params.image_size / params.patch_size) * (params.image_size / params.patch_size); - if (ctx->proj_type == PROJECTOR_TYPE_LDP || ctx->proj_type == PROJECTOR_TYPE_LDPV2 || ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) { + if (ctx->proj_type == PROJECTOR_TYPE_LDP + || ctx->proj_type == PROJECTOR_TYPE_LDPV2 + || ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) { n_patches /= 4; - n_patches += 2; // for BOI and EOI token embeddings + if (ctx->vision_model.mm_glm_tok_boi) { + n_patches += 2; // for BOI and EOI token embeddings + } } else if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV) { if (ctx->minicpmv_version == 2) { n_patches = 96; @@ -3013,7 +3105,8 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im int n_per_side = params.image_size / params.patch_size; int n_per_side_2d_pool = n_per_side / params.proj_scale_factor; n_patches = n_per_side_2d_pool * n_per_side_2d_pool; - } else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) { + } else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3 || ctx->proj_type == PROJECTOR_TYPE_INTERNVL) { + // both W and H are divided by proj_scale_factor n_patches /= (params.proj_scale_factor * params.proj_scale_factor); } else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) { int n_merge = params.spatial_merge_size; @@ -3408,6 +3501,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima } break; case PROJECTOR_TYPE_GEMMA3: case PROJECTOR_TYPE_IDEFICS3: + case PROJECTOR_TYPE_INTERNVL: { // do nothing } break; @@ -3434,6 +3528,14 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima // the last node is the embedding tensor ggml_tensor * embeddings = ggml_graph_node(gf, -1); + // sanity check (only support batch size of 1 for now) + const int n_tokens_out = embeddings->ne[1]; + const int expected_n_tokens_out = clip_n_output_tokens(ctx, imgs.entries[0].get()); + if (n_tokens_out != expected_n_tokens_out) { + LOG_ERR("%s: expected %d tokens, got %d\n", __func__, expected_n_tokens_out, n_tokens_out); + GGML_ABORT("Invalid number of output tokens"); + } + // copy the embeddings to the location passed by the user 
     ggml_backend_tensor_get(embeddings, vec, 0, ggml_nbytes(embeddings));
@@ -3604,6 +3706,8 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
             return ctx->vision_model.mm_input_proj_w->ne[0];
         case PROJECTOR_TYPE_IDEFICS3:
             return ctx->vision_model.projection->ne[1];
+        case PROJECTOR_TYPE_INTERNVL:
+            return ctx->vision_model.mm_3_w->ne[1];
         default:
             GGML_ABORT("Unknown projector type");
     }
diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp
index 2fecf08a..f1b95739 100644
--- a/tools/mtmd/mtmd.cpp
+++ b/tools/mtmd/mtmd.cpp
@@ -252,6 +252,13 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
 
     }
 
+    else if (proj_type == PROJECTOR_TYPE_INTERNVL) {
+        // <img> ... (image embeddings) ... </img>
+        marker_modified = "<img>" + ctx->image_marker + "</img>";
+        string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
+
+    }
+
     // llava-1.5, llava-1.6, Yi-VL, Yi-34B, granite: don't need to add prefix and suffix
 
     // for glm-edge, BOI and EOI token's embeddings are not present in the text model
diff --git a/tools/mtmd/tests.sh b/tools/mtmd/tests.sh
index 22c23749..05ac7a04 100755
--- a/tools/mtmd/tests.sh
+++ b/tools/mtmd/tests.sh
@@ -40,7 +40,6 @@ add_test "llama-mtmd-cli" "ggml-org/SmolVLM-500M-Instruct-GGUF:Q8_0"
 add_test "llama-mtmd-cli" "ggml-org/SmolVLM2-2.2B-Instruct-GGUF:Q4_K_M"
 add_test "llama-mtmd-cli" "ggml-org/SmolVLM2-500M-Video-Instruct-GGUF:Q8_0"
 add_test "llama-mtmd-cli" "ggml-org/gemma-3-4b-it-GGUF:Q4_K_M"
-add_test "llama-mtmd-cli" "guinmoon/MobileVLM-3B-GGUF:Q4_K_M" "deepseek"
 add_test "llama-mtmd-cli" "THUDM/glm-edge-v-5b-gguf:Q4_K_M"
 add_test "llama-mtmd-cli" "second-state/Llava-v1.5-7B-GGUF:Q2_K" "vicuna"
 add_test "llama-mtmd-cli" "cjpais/llava-1.6-mistral-7b-gguf:Q3_K" "vicuna"
@@ -50,6 +49,8 @@ add_test "llama-mtmd-cli" "openbmb/MiniCPM-V-2_6-gguf:Q2_K"
 add_test "llama-mtmd-cli" "openbmb/MiniCPM-o-2_6-gguf:Q4_0"
 add_test "llama-mtmd-cli" "bartowski/Qwen2-VL-2B-Instruct-GGUF:Q4_K_M"
 add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-3B-Instruct-GGUF:Q4_K_M"
+add_test "llama-mtmd-cli" "ggml-org/InternVL2_5-1B-GGUF:Q8_0"
+add_test "llama-mtmd-cli" "ggml-org/InternVL3-1B-Instruct-GGUF:Q8_0"
 
 # to test the big models, run: ./tests.sh big
 if [ "$RUN_BIG_TESTS" = true ]; then
@@ -59,6 +60,8 @@ if [ "$RUN_BIG_TESTS" = true ]; then
     add_test "llama-mtmd-cli" "ggml-org/Qwen2-VL-7B-Instruct-GGUF:Q4_K_M"
     add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-3B-Instruct-GGUF:Q4_K_M"
     add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-7B-Instruct-GGUF:Q4_K_M"
+    add_test "llama-mtmd-cli" "ggml-org/InternVL3-8B-Instruct-GGUF:Q4_K_M"
+    add_test "llama-mtmd-cli" "ggml-org/InternVL3-14B-Instruct-GGUF:Q4_K_M"
     # add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-32B-Instruct-GGUF:Q4_K_M" # does not work on my mac M3 Ultra
     # add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-72B-Instruct-GGUF:Q4_K_M" # too big
 fi
@@ -70,6 +73,7 @@ fi
 
 # this model has broken chat template, not usable
 # add_test "llama-mtmd-cli" "cmp-nct/Yi-VL-6B-GGUF:Q5_K"
+# add_test "llama-mtmd-cli" "guinmoon/MobileVLM-3B-GGUF:Q4_K_M" "deepseek"
 
 ###############
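
A note on the pixel-shuffle step introduced by this patch: the conversion script stores proj_scale_factor = int(1.0 / downsample_ratio) (the usual InternVL downsample_ratio of 0.5 becomes a scale factor of 2), build_internvl() regroups each scale x scale block of patch embeddings into one token carrying scale^2 times the channels, and clip_n_output_tokens() divides the patch count by scale^2 accordingly. The standalone Python sketch below only illustrates this shape bookkeeping; it is not the ggml implementation, and the concrete numbers (448x448 input, 14-pixel patches, 1024-dim ViT embeddings) are assumed typical InternVL values rather than taken from this diff.

    import numpy as np

    def pixel_shuffle(x: np.ndarray, scale: int = 2) -> np.ndarray:
        # x: (H, W, C) grid of patch embeddings -> (H*W / scale^2, C * scale^2) tokens
        h, w, c = x.shape
        assert h % scale == 0 and w % scale == 0
        x = x.reshape(h, w // scale, c * scale)                   # fold 'scale' neighbours along W into channels
        x = x.transpose(1, 0, 2)                                  # (W/scale, H, C*scale)
        x = x.reshape(w // scale, h // scale, c * scale * scale)  # fold 'scale' neighbours along H into channels
        x = x.transpose(1, 0, 2)                                  # (H/scale, W/scale, C*scale^2)
        return x.reshape(-1, c * scale * scale)

    # assumed example: 448x448 image with 14x14 patches -> 32x32 grid of 1024-dim embeddings
    grid = np.random.rand(32, 32, 1024).astype(np.float32)
    tokens = pixel_shuffle(grid, scale=2)
    print(tokens.shape)  # (256, 4096): 1024 patches become 256 tokens of width 1024*4,
                         # which then pass through the mlp1 projector (LayerNorm -> Linear -> GELU -> Linear)

The channel ordering produced by this sketch may differ from the exact ggml_reshape/ggml_permute sequence in clip.cpp, but the resulting token count and embedding width are the same, which is what clip_n_output_tokens() and clip_n_mmproj_embd() rely on.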