]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
mtmd : support InternVL 2.5 and 3 (#13422)
authorXuan-Son Nguyen <redacted>
Sat, 10 May 2025 14:26:42 +0000 (16:26 +0200)
committerGitHub <redacted>
Sat, 10 May 2025 14:26:42 +0000 (16:26 +0200)
* convert : internvl support

* InternVL3-1B working

* fix regression

* rm mobilevlm from test

* fix conversion

* add test for internvl

* add to list of pre-quant

* restore boi/eoi check

* add clarify comment for norm eps

convert_hf_to_gguf.py
docs/multimodal.md
gguf-py/gguf/constants.py
gguf-py/gguf/tensor_mapping.py
tools/mtmd/README.md
tools/mtmd/clip-impl.h
tools/mtmd/clip.cpp
tools/mtmd/mtmd.cpp
tools/mtmd/tests.sh

index bf6bc68380b19a5c70ede7fb0f2730b021b5f21d..e5c397fee2c308c2f26412d160c66895e667a624 100755 (executable)
@@ -426,7 +426,11 @@ class ModelBase:
             logger.warning(f"Failed to load model config from {dir_model}: {e}")
             logger.warning("Trying to load config.json instead")
             with open(dir_model / "config.json", "r", encoding="utf-8") as f:
-                return json.load(f)
+                config = json.load(f)
+                if "llm_config" in config:
+                    # rename for InternVL
+                    config["text_config"] = config["llm_config"]
+                return config
 
     @classmethod
     def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]:
@@ -2606,6 +2610,11 @@ class Qwen2Model(TextModel):
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         if self.hf_arch == "Qwen2Model":
             name = f"model.{name}"  # map to Qwen2ForCausalLM tensors
+        if "language_model." in name:
+            name = name.replace("language_model.", "") # for InternVL
+        if name.startswith("mlp") or name.startswith("vision_model"):
+            # skip visual tensors
+            return []
         yield from super().modify_tensors(data_torch, name, bid)
 
 
@@ -2709,6 +2718,62 @@ class Qwen2VLVisionModel(VisionModel):
         return [] # skip other tensors
 
 
+@ModelBase.register("InternVisionModel")
+class InternVisionModel(VisionModel):
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        hparams = self.hparams
+        self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.INTERNVL)
+        self.gguf_writer.add_vision_attention_layernorm_eps(hparams["layer_norm_eps"])
+        # hidden_act
+        if hparams["hidden_act"] == "silu":
+            self.gguf_writer.add_vision_use_silu(True)
+        elif hparams["hidden_act"] == "gelu":
+            self.gguf_writer.add_vision_use_gelu(True)
+        else:
+            raise ValueError(f"Unsupported hidden_act: {hparams['hidden_act']}")
+        # downsample_ratio
+        downsample_ratio = self.global_config.get("downsample_ratio")
+        assert downsample_ratio is not None
+        self.gguf_writer.add_vision_projector_scale_factor(int(1.0 / downsample_ratio))
+
+    def tensor_force_quant(self, name, new_name, bid, n_dims):
+        del bid, name, n_dims  # unused
+        if ".patch_embd." in new_name:
+            return gguf.GGMLQuantizationType.F16
+        if ".position_embd." in new_name:
+            return gguf.GGMLQuantizationType.F32
+        return False
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+        if name.startswith("vision_model") or name.startswith("mlp"):
+            # process visual tensors
+            # correct name
+            if name.startswith("vision_model"):
+                name = "vision_tower." + name
+            if (".ls" in name or "position_embedding" in name) and not name.endswith(".weight"):
+                name += ".weight"
+            # split QKV tensors if needed
+            if ".qkv." in name:
+                if data_torch.ndim == 2: # weight
+                    c3, _ = data_torch.shape
+                else: # bias
+                    c3 = data_torch.shape[0]
+                assert c3 % 3 == 0
+                c = c3 // 3
+                wq = data_torch[:c]
+                wk = data_torch[c: c * 2]
+                wv = data_torch[c * 2:]
+                return [
+                    (self.map_tensor_name(name.replace("attn.qkv", "self_attn.q_proj")), wq),
+                    (self.map_tensor_name(name.replace("attn.qkv", "self_attn.k_proj")), wk),
+                    (self.map_tensor_name(name.replace("attn.qkv", "self_attn.v_proj")), wv),
+                ]
+            return [(self.map_tensor_name(name), data_torch)]
+        return [] # skip other tensors
+
+
 @ModelBase.register("WavTokenizerDec")
 class WavTokenizerDecModel(TextModel):
     model_arch = gguf.MODEL_ARCH.WAVTOKENIZER_DEC
@@ -3360,6 +3425,11 @@ class InternLM2Model(TextModel):
         head_dim = n_embd // num_heads
         num_groups = num_heads // q_per_kv
 
+        name = name.replace("language_model.", "") # InternVL
+        if name.startswith("mlp") or name.startswith("vision_model"):
+            # skip visual tensors
+            return []
+
         if bid is not None and f"model.layers.{bid}.attention.wqkv" in name:
             qkv = data_torch
 
@@ -3433,6 +3503,10 @@ class InternLM3Model(TextModel):
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         n_head = self.hparams["num_attention_heads"]
         n_kv_head = self.hparams.get("num_key_value_heads")
+        name = name.replace("language_model.", "") # InternVL
+        if name.startswith("mlp") or name.startswith("vision_model"):
+            # skip visual tensors
+            return []
         if name.endswith(("q_proj.weight", "q_proj.bias")):
             data_torch = LlamaModel.permute(data_torch, n_head, n_head)
         if name.endswith(("k_proj.weight", "k_proj.bias")):
index efed473a3cd0707b46b2a239905a1b9edcc2b2e0..090583f9797581651fb3c769e187924865235666 100644 (file)
@@ -66,4 +66,12 @@ NOTE: some models may require large context window, for example: `-c 8192`
 
 # Mistral Small 3.1 24B (IQ2_M quantization)
 (tool_name) -hf ggml-org/Mistral-Small-3.1-24B-Instruct-2503-GGUF
+
+# InternVL 2.5 and 3
+(tool_name) -hf ggml-org/InternVL2_5-1B-GGUF
+(tool_name) -hf ggml-org/InternVL2_5-2B-GGUF
+(tool_name) -hf ggml-org/InternVL3-1B-Instruct-GGUF
+(tool_name) -hf ggml-org/InternVL3-2B-Instruct-GGUF
+(tool_name) -hf ggml-org/InternVL3-4B-Instruct-GGUF
+(tool_name) -hf ggml-org/InternVL3-14B-Instruct-GGUF
 ```
index 7dd7bb6d1b5d9392d4155dc7da0a8039bc193a01..ae5ce71aef939bc327c0c55ffe2716a74d72d1fb 100644 (file)
@@ -491,6 +491,8 @@ class MODEL_TENSOR(IntEnum):
     V_ENC_FFN_UP         = auto()
     V_ENC_FFN_GATE       = auto()
     V_ENC_FFN_DOWN       = auto()
+    V_LAYER_SCALE_1      = auto()
+    V_LAYER_SCALE_2      = auto()
     V_PRE_NORM           = auto()
     V_POST_NORM          = auto()
     V_MM_INP_NORM        = auto()
@@ -748,6 +750,8 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
     MODEL_TENSOR.V_ENC_FFN_UP:              "v.blk.{bid}.ffn_up",
     MODEL_TENSOR.V_ENC_FFN_GATE:            "v.blk.{bid}.ffn_gate",
     MODEL_TENSOR.V_ENC_FFN_DOWN:            "v.blk.{bid}.ffn_down",
+    MODEL_TENSOR.V_LAYER_SCALE_1:           "v.blk.{bid}.ls1",
+    MODEL_TENSOR.V_LAYER_SCALE_2:           "v.blk.{bid}.ls2",
     MODEL_TENSOR.V_PRE_NORM:                "v.pre_ln",
     MODEL_TENSOR.V_POST_NORM:               "v.post_ln",
     MODEL_TENSOR.V_MM_INP_PROJ:             "mm.input_projection",
@@ -786,6 +790,8 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.V_ENC_FFN_UP,
         MODEL_TENSOR.V_ENC_FFN_GATE,
         MODEL_TENSOR.V_ENC_FFN_DOWN,
+        MODEL_TENSOR.V_LAYER_SCALE_1,
+        MODEL_TENSOR.V_LAYER_SCALE_2,
         MODEL_TENSOR.V_PRE_NORM,
         MODEL_TENSOR.V_POST_NORM,
         MODEL_TENSOR.V_MM_INP_PROJ,
@@ -2167,6 +2173,7 @@ class VisionProjectorType:
     PIXTRAL = "pixtral"
     QWEN2VL = "qwen2vl_merger"
     QWEN25VL = "qwen2.5vl_merger"
+    INTERNVL = "internvl"
 
 
 # Items here are (block size, type size)
index 003b0172c77b07a76d4f9de04e11abce16d6fc79..bf7ec325772ce055bf9338dc6ab21edd665eb809 100644 (file)
@@ -905,6 +905,7 @@ class TensorNameMap:
 
         MODEL_TENSOR.V_MMPROJ_MLP: (
             "model.mm_projector.mlp.mlp.{bid}",
+            "mlp1.{bid}", # InternVL
         ),
 
         MODEL_TENSOR.V_MMPROJ_PEG: (
@@ -955,6 +956,7 @@ class TensorNameMap:
 
         MODEL_TENSOR.V_ENC_INPUT_NORM: (
             "vision_tower.vision_model.encoder.layers.{bid}.layer_norm1",
+            "vision_tower.vision_model.encoder.layers.{bid}.norm1", # InternVL
             "vpm.encoder.layers.{bid}.layer_norm1",
             "model.vision_model.encoder.layers.{bid}.layer_norm1", # SmolVLM
             "vision_tower.transformer.layers.{bid}.attention_norm", # pixtral
@@ -963,6 +965,7 @@ class TensorNameMap:
 
         MODEL_TENSOR.V_ENC_OUTPUT: (
             "vision_tower.vision_model.encoder.layers.{bid}.self_attn.out_proj",
+            "vision_tower.vision_model.encoder.layers.{bid}.attn.proj", # InternVL
             "vpm.encoder.layers.{bid}.self_attn.out_proj",
             "model.vision_model.encoder.layers.{bid}.self_attn.out_proj", # SmolVLM
             "vision_tower.transformer.layers.{bid}.attention.o_proj", # pixtral
@@ -971,6 +974,7 @@ class TensorNameMap:
 
         MODEL_TENSOR.V_ENC_OUTPUT_NORM: (
             "vision_tower.vision_model.encoder.layers.{bid}.layer_norm2",
+            "vision_tower.vision_model.encoder.layers.{bid}.norm2", # InternVL
             "vpm.encoder.layers.{bid}.layer_norm2",
             "model.vision_model.encoder.layers.{bid}.layer_norm2", # SmolVLM
             "vision_tower.transformer.layers.{bid}.ffn_norm", # pixtral
@@ -1000,6 +1004,14 @@ class TensorNameMap:
             "visual.blocks.{bid}.mlp.down_proj", # qwen2.5vl
         ),
 
+        MODEL_TENSOR.V_LAYER_SCALE_1: (
+            "vision_tower.vision_model.encoder.layers.{bid}.ls1", # InternVL
+        ),
+
+        MODEL_TENSOR.V_LAYER_SCALE_2: (
+            "vision_tower.vision_model.encoder.layers.{bid}.ls2", # InternVL
+        ),
+
         MODEL_TENSOR.V_PRE_NORM: (
             "vision_tower.vision_model.pre_layrnorm",
             "vision_tower.ln_pre", # pixtral
index 06e1fd097423a9ba54fb929f9cf0fe5a443f599c..ab258ea174e84c23950b60962023e2491fff0ef3 100644 (file)
@@ -48,6 +48,7 @@ For the following models, you can use `convert_hf_to_gguf.py`with `--mmproj` fla
 - [Pixtral 12B](https://huggingface.co/mistral-community/pixtral-12b) - only works with `transformers`-compatible checkpoint
 - Qwen 2 VL and Qwen 2.5 VL (from [Qwen](https://huggingface.co/Qwen))
 - [Mistral Small 3.1 24B](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503)
+- InternVL 2.5 and InternVL 3 from [OpenGVLab](https://huggingface.co/OpenGVLab) (note: we don't support conversion of `InternVL3-*-hf` model, only non-HF version is supported ; `InternLM2Model` **text** model is not supported)
 
 For older models, please refer to the relevant guide for instructions on how to obtain or create them:
 
index fb780e9deac7ee679bf4069d0c8b6d821a45f464..e9c8646e1b4491c1827cd4f6ca0b18560588f8da 100644 (file)
@@ -33,9 +33,6 @@
 #define KEY_PROJ_TYPE           "clip.projector_type"
 #define KEY_SPATIAL_MERGE_SIZE  "clip.vision.spatial_merge_size"
 
-#define KEY_USE_GLU_MLP         "clip.use_glu_mlp"  // for qwen2.5vl
-#define KEY_USE_RMS_NORM        "clip.use_rms_norm" // for qwen2.5vl
-
 #define KEY_MM_PATCH_MERGE_TYPE   "clip.vision.mm_patch_merge_type"
 #define KEY_IMAGE_GRID_PINPOINTS  "clip.vision.image_grid_pinpoints"
 #define KEY_IMAGE_CROP_RESOLUTION "clip.vision.image_crop_resolution"
 #define TN_FFN_GATE        "%s.blk.%d.ffn_gate.%s"
 #define TN_FFN_UP          "%s.blk.%d.ffn_up.%s"
 #define TN_FFN_GATE        "%s.blk.%d.ffn_gate.%s"
-#define TN_LN_1            "%s.blk.%d.ln1.%s"
-#define TN_LN_2            "%s.blk.%d.ln2.%s"
+#define TN_LN_1            "%s.blk.%d.ln1.%s" // layer norm
+#define TN_LN_2            "%s.blk.%d.ln2.%s" // layer norm
+#define TN_LS_1            "%s.blk.%d.ls1.%s" // layer scale
+#define TN_LS_2            "%s.blk.%d.ls2.%s" // layer scale
 #define TN_LN_PRE          "%s.pre_ln.%s"
 #define TN_LN_POST         "%s.post_ln.%s"
 #define TN_LLAVA_PROJ      "mm.%d.%s"
@@ -105,6 +104,7 @@ enum projector_type {
     PROJECTOR_TYPE_IDEFICS3,
     PROJECTOR_TYPE_PIXTRAL,
     PROJECTOR_TYPE_QWEN25VL,
+    PROJECTOR_TYPE_INTERNVL,
     PROJECTOR_TYPE_UNKNOWN,
 };
 
@@ -119,6 +119,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
     { PROJECTOR_TYPE_GEMMA3,    "gemma3"},
     { PROJECTOR_TYPE_IDEFICS3,  "idefics3"},
     { PROJECTOR_TYPE_PIXTRAL,   "pixtral"},
+    { PROJECTOR_TYPE_INTERNVL,  "internvl"},
 };
 
 static projector_type clip_projector_type_from_string(const std::string & str) {
index 1a81c1fcdf60e12c1daba87756ae9d50eb90dc77..dfe7ac91c48b5e5df2873300c9a0b8dae05a732a 100644 (file)
@@ -215,6 +215,10 @@ struct clip_layer {
     // layernorm 2
     ggml_tensor * ln_2_w = nullptr;
     ggml_tensor * ln_2_b = nullptr;
+
+    // layer scale (no bias)
+    ggml_tensor * ls_1_w = nullptr;
+    ggml_tensor * ls_2_w = nullptr;
 };
 
 struct clip_vision_model {
@@ -589,6 +593,9 @@ struct clip_graph {
 
     // Qwen2VL and Qwen2.5VL use M-RoPE
     ggml_cgraph * build_qwen2vl() {
+        GGML_ASSERT(model.patch_bias == nullptr);
+        GGML_ASSERT(model.class_embedding == nullptr);
+
         const int batch_size       = 1;
         const bool use_window_attn = hparams.n_wa_pattern > 0;
         const int n_wa_pattern     = hparams.n_wa_pattern;
@@ -625,10 +632,6 @@ struct clip_graph {
                 n_embd, n_patches_x * n_patches_y, batch_size);
         }
 
-        if (model.patch_bias) {
-            inp = ggml_add(ctx0, inp, model.patch_bias);
-        }
-
         ggml_tensor * inpL           = inp;
         ggml_tensor * window_mask    = nullptr;
         ggml_tensor * window_idx     = nullptr;
@@ -859,6 +862,67 @@ struct clip_graph {
         return gf;
     }
 
+    ggml_cgraph * build_internvl() {
+        GGML_ASSERT(model.class_embedding != nullptr);
+        GGML_ASSERT(model.position_embeddings != nullptr);
+
+        const int n_pos = n_patches + 1;
+        ggml_tensor * inp = build_inp();
+
+        // add CLS token
+        inp = ggml_concat(ctx0, inp, model.class_embedding, 1);
+
+        ggml_tensor * cur = build_vit(
+                                inp, n_pos,
+                                NORM_TYPE_NORMAL,
+                                hparams.ffn_op,
+                                model.position_embeddings,
+                                nullptr);
+
+        // remove CLS token
+        cur = ggml_view_2d(ctx0, cur,
+            n_embd, n_patches,
+            ggml_row_size(cur->type, n_embd), 0);
+
+        // pixel shuffle
+        {
+            const int scale_factor = model.hparams.proj_scale_factor;
+            const int bsz    = 1; // batch size, always 1 for now since we don't support batching
+            const int height = n_patches_y;
+            const int width  = n_patches_x;
+            GGML_ASSERT(scale_factor > 0);
+            cur = ggml_reshape_4d(ctx0, cur, n_embd * scale_factor, height / scale_factor, width, bsz);
+            cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+            cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur),
+                n_embd * scale_factor * scale_factor,
+                height / scale_factor,
+                width / scale_factor,
+                bsz);
+            cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+            // flatten to 2D
+            cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, cur),
+                n_embd * scale_factor * scale_factor,
+                cur->ne[1] * cur->ne[2]);
+        }
+
+        // projector (always using GELU activation)
+        {
+            // projector LayerNorm uses pytorch's default eps = 1e-5
+            // ref: https://huggingface.co/OpenGVLab/InternVL3-8B-Instruct/blob/a34d3e4e129a5856abfd6aa6de79776484caa14e/modeling_internvl_chat.py#L79
+            cur = build_norm(cur, model.mm_0_w, model.mm_0_b, NORM_TYPE_NORMAL, 1e-5, -1);
+            cur = ggml_mul_mat(ctx0, model.mm_1_w, cur);
+            cur = ggml_add(ctx0, cur, model.mm_1_b);
+            cur = ggml_gelu(ctx0, cur);
+            cur = ggml_mul_mat(ctx0, model.mm_3_w, cur);
+            cur = ggml_add(ctx0, cur, model.mm_3_b);
+        }
+
+        // build the graph
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
     // this graph is used by llava, granite and glm
     // due to having embedding_stack (used by granite), we cannot reuse build_vit
     ggml_cgraph * build_llava() {
@@ -890,10 +954,6 @@ struct clip_graph {
 
         ggml_tensor * inp = build_inp();
 
-        if (model.patch_bias) {
-            inp = ggml_add(ctx0, inp, model.patch_bias);
-        }
-
         // concat class_embeddings and patch_embeddings
         if (model.class_embedding) {
             inp = ggml_concat(ctx0, inp, model.class_embedding, 1);
@@ -1260,11 +1320,6 @@ private:
                 ggml_tensor * learned_pos_embd,
                 std::function<ggml_tensor *(ggml_tensor *, const clip_layer &)> add_pos
             ) {
-        if (model.patch_bias) {
-            inp = ggml_add(ctx0, inp, model.patch_bias);
-            cb(inp, "patch_bias", -1);
-        }
-
         if (learned_pos_embd) {
             inp = ggml_add(ctx0, inp, learned_pos_embd);
             cb(inp, "pos_embed", -1);
@@ -1324,6 +1379,11 @@ private:
                 cb(cur, "attn_out", il);
             }
 
+            if (layer.ls_1_w) {
+                cur = ggml_mul(ctx0, cur, layer.ls_1_w);
+                cb(cur, "attn_out_scaled", il);
+            }
+
             // re-add the layer input, e.g., residual
             cur = ggml_add(ctx0, cur, inpL);
 
@@ -1344,6 +1404,11 @@ private:
 
             cb(cur, "ffn_out", il);
 
+            if (layer.ls_2_w) {
+                cur = ggml_mul(ctx0, cur, layer.ls_2_w);
+                cb(cur, "ffn_out_scaled", il);
+            }
+
             // residual 2
             cur = ggml_add(ctx0, inpL, cur);
             cb(cur, "layer_out", il);
@@ -1365,6 +1430,10 @@ private:
         ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
         inp = ggml_reshape_2d(ctx0, inp, n_patches, n_embd);
         inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
+        if (model.patch_bias) {
+            inp = ggml_add(ctx0, inp, model.patch_bias);
+            cb(inp, "patch_bias", -1);
+        }
         return inp;
     }
 
@@ -1627,6 +1696,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
             {
                 res = graph.build_minicpmv();
             } break;
+        case PROJECTOR_TYPE_INTERNVL:
+            {
+                res = graph.build_internvl();
+            } break;
         default:
             {
                 res = graph.build_llava();
@@ -1790,6 +1863,7 @@ struct clip_model_loader {
                         }
                     } break;
                 case PROJECTOR_TYPE_IDEFICS3:
+                case PROJECTOR_TYPE_INTERNVL:
                     {
                         get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor, false);
                     } break;
@@ -1897,6 +1971,9 @@ struct clip_model_loader {
             layer.o_w    = get_tensor(string_format(TN_ATTN_OUTPUT, "v", il, "weight"));
             layer.ln_1_w = get_tensor(string_format(TN_LN_1,        "v", il, "weight"), false);
             layer.ln_2_w = get_tensor(string_format(TN_LN_2,        "v", il, "weight"), false);
+            layer.ls_1_w = get_tensor(string_format(TN_LS_1,        "v", il, "weight"), false); // no bias
+            layer.ls_2_w = get_tensor(string_format(TN_LS_2,        "v", il, "weight"), false); // no bias
+
             layer.k_b    = get_tensor(string_format(TN_ATTN_K,      "v", il, "bias"), false);
             layer.q_b    = get_tensor(string_format(TN_ATTN_Q,      "v", il, "bias"), false);
             layer.v_b    = get_tensor(string_format(TN_ATTN_V,      "v", il, "bias"), false);
@@ -1904,7 +1981,7 @@ struct clip_model_loader {
             layer.ln_1_b = get_tensor(string_format(TN_LN_1,        "v", il, "bias"), false);
             layer.ln_2_b = get_tensor(string_format(TN_LN_2,        "v", il, "bias"), false);
 
-            // new naming
+            // ffn
             layer.ff_up_w   = get_tensor(string_format(TN_FFN_UP,   "v", il, "weight"));
             layer.ff_up_b   = get_tensor(string_format(TN_FFN_UP,   "v", il, "bias"),   false);
             layer.ff_gate_w = get_tensor(string_format(TN_FFN_GATE, "v", il, "weight"), false);
@@ -2052,6 +2129,15 @@ struct clip_model_loader {
                     vision_model.mm_input_norm_w   = get_tensor(TN_MM_INP_NORM,     false);
                     vision_model.mm_patch_merger_w = get_tensor(TN_MM_PATCH_MERGER, false);
                 } break;
+            case PROJECTOR_TYPE_INTERNVL:
+                {
+                    vision_model.mm_0_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "weight"));
+                    vision_model.mm_0_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "bias"));
+                    vision_model.mm_1_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "weight"));
+                    vision_model.mm_1_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "bias"));
+                    vision_model.mm_3_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "weight"));
+                    vision_model.mm_3_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "bias"));
+                } break;
             default:
                 GGML_ASSERT(false && "unknown projector type");
         }
@@ -2838,7 +2924,9 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
     }
     else if (ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE
             || ctx->proj_type == PROJECTOR_TYPE_GEMMA3
-            || ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
+            || ctx->proj_type == PROJECTOR_TYPE_IDEFICS3
+            || ctx->proj_type == PROJECTOR_TYPE_INTERNVL // TODO @ngxson : support dynamic resolution
+    ) {
         clip_image_u8 resized_image;
         int sz = params.image_size;
         image_manipulation::resize_and_pad_image(*img, resized_image, {sz, sz});
@@ -2988,9 +3076,13 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
 
     int n_patches = (params.image_size / params.patch_size) * (params.image_size / params.patch_size);
 
-    if (ctx->proj_type == PROJECTOR_TYPE_LDP || ctx->proj_type == PROJECTOR_TYPE_LDPV2 || ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) {
+    if (ctx->proj_type == PROJECTOR_TYPE_LDP
+            || ctx->proj_type == PROJECTOR_TYPE_LDPV2
+            || ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) {
         n_patches /= 4;
-        n_patches += 2; // for BOI and EOI token embeddings
+        if (ctx->vision_model.mm_glm_tok_boi) {
+            n_patches += 2; // for BOI and EOI token embeddings
+        }
     } else if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV) {
         if (ctx->minicpmv_version == 2) {
             n_patches = 96;
@@ -3013,7 +3105,8 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
         int n_per_side = params.image_size / params.patch_size;
         int n_per_side_2d_pool = n_per_side / params.proj_scale_factor;
         n_patches = n_per_side_2d_pool * n_per_side_2d_pool;
-    } else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
+    } else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3 || ctx->proj_type == PROJECTOR_TYPE_INTERNVL) {
+        // both W and H are divided by proj_scale_factor
         n_patches /= (params.proj_scale_factor * params.proj_scale_factor);
     } else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) {
         int n_merge = params.spatial_merge_size;
@@ -3408,6 +3501,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
             } break;
         case PROJECTOR_TYPE_GEMMA3:
         case PROJECTOR_TYPE_IDEFICS3:
+        case PROJECTOR_TYPE_INTERNVL:
             {
                 // do nothing
             } break;
@@ -3434,6 +3528,14 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
     // the last node is the embedding tensor
     ggml_tensor * embeddings = ggml_graph_node(gf, -1);
 
+    // sanity check (only support batch size of 1 for now)
+    const int n_tokens_out = embeddings->ne[1];
+    const int expected_n_tokens_out = clip_n_output_tokens(ctx, imgs.entries[0].get());
+    if (n_tokens_out != expected_n_tokens_out) {
+        LOG_ERR("%s: expected %d tokens, got %d\n", __func__, expected_n_tokens_out, n_tokens_out);
+        GGML_ABORT("Invalid number of output tokens");
+    }
+
     // copy the embeddings to the location passed by the user
     ggml_backend_tensor_get(embeddings, vec, 0, ggml_nbytes(embeddings));
 
@@ -3604,6 +3706,8 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
             return ctx->vision_model.mm_input_proj_w->ne[0];
         case PROJECTOR_TYPE_IDEFICS3:
             return ctx->vision_model.projection->ne[1];
+        case PROJECTOR_TYPE_INTERNVL:
+            return ctx->vision_model.mm_3_w->ne[1];
         default:
             GGML_ABORT("Unknown projector type");
     }
index 2fecf08a44e94b7ed1dd53c7547cf6e670232b8c..f1b957394767148e8a959314e50473b0029459a8 100644 (file)
@@ -252,6 +252,13 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
 
     }
 
+    else if (proj_type == PROJECTOR_TYPE_INTERNVL) {
+        // <img> ... (image embeddings) ... </img>
+        marker_modified = "<img>" + ctx->image_marker + "</img>";
+        string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
+
+    }
+
     // llava-1.5, llava-1.6, Yi-VL, Yi-34B, granite: don't need to add prefix and suffix
     // for glm-edge, BOI and EOI token's embeddings are not present in the text model
 
index 22c23749713316d14e13fd206ce4d60183935740..05ac7a04d8fce36780747baa2e1bb7d3e8475b17 100755 (executable)
@@ -40,7 +40,6 @@ add_test "llama-mtmd-cli"  "ggml-org/SmolVLM-500M-Instruct-GGUF:Q8_0"
 add_test "llama-mtmd-cli"  "ggml-org/SmolVLM2-2.2B-Instruct-GGUF:Q4_K_M"
 add_test "llama-mtmd-cli"  "ggml-org/SmolVLM2-500M-Video-Instruct-GGUF:Q8_0"
 add_test "llama-mtmd-cli"  "ggml-org/gemma-3-4b-it-GGUF:Q4_K_M"
-add_test "llama-mtmd-cli"  "guinmoon/MobileVLM-3B-GGUF:Q4_K_M"               "deepseek"
 add_test "llama-mtmd-cli"  "THUDM/glm-edge-v-5b-gguf:Q4_K_M"
 add_test "llama-mtmd-cli"  "second-state/Llava-v1.5-7B-GGUF:Q2_K"            "vicuna"
 add_test "llama-mtmd-cli"  "cjpais/llava-1.6-mistral-7b-gguf:Q3_K"           "vicuna"
@@ -50,6 +49,8 @@ add_test "llama-mtmd-cli"  "openbmb/MiniCPM-V-2_6-gguf:Q2_K"
 add_test "llama-mtmd-cli"  "openbmb/MiniCPM-o-2_6-gguf:Q4_0"
 add_test "llama-mtmd-cli"  "bartowski/Qwen2-VL-2B-Instruct-GGUF:Q4_K_M"
 add_test "llama-mtmd-cli"  "ggml-org/Qwen2.5-VL-3B-Instruct-GGUF:Q4_K_M"
+add_test "llama-mtmd-cli"  "ggml-org/InternVL2_5-1B-GGUF:Q8_0"
+add_test "llama-mtmd-cli"  "ggml-org/InternVL3-1B-Instruct-GGUF:Q8_0"
 
 # to test the big models, run: ./tests.sh big
 if [ "$RUN_BIG_TESTS" = true ]; then
@@ -59,6 +60,8 @@ if [ "$RUN_BIG_TESTS" = true ]; then
     add_test "llama-mtmd-cli" "ggml-org/Qwen2-VL-7B-Instruct-GGUF:Q4_K_M"
     add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-3B-Instruct-GGUF:Q4_K_M"
     add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-7B-Instruct-GGUF:Q4_K_M"
+    add_test "llama-mtmd-cli"  "ggml-org/InternVL3-8B-Instruct-GGUF:Q4_K_M"
+    add_test "llama-mtmd-cli"  "ggml-org/InternVL3-14B-Instruct-GGUF:Q4_K_M"
     # add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-32B-Instruct-GGUF:Q4_K_M" # does not work on my mac M3 Ultra
     # add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-72B-Instruct-GGUF:Q4_K_M" # too big
 fi
@@ -70,6 +73,7 @@ fi
 
 # this model has broken chat template, not usable
 # add_test "llama-mtmd-cli"  "cmp-nct/Yi-VL-6B-GGUF:Q5_K"
+# add_test "llama-mtmd-cli"  "guinmoon/MobileVLM-3B-GGUF:Q4_K_M" "deepseek"
 
 ###############