]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
mtmd : support Kimi VL model (#15458)
authorXuan-Son Nguyen <redacted>
Tue, 26 Aug 2025 10:54:19 +0000 (12:54 +0200)
committerGitHub <redacted>
Tue, 26 Aug 2025 10:54:19 +0000 (12:54 +0200)
* convert : fix tensor naming conflict for llama 4 vision

* convert ok

* support kimi vision model

* clean up

* fix style

* fix calc number of output tokens

* refactor resize_position_embeddings

* add test case

* rename build fn

* correct a small bug

convert_hf_to_gguf.py
gguf-py/gguf/constants.py
gguf-py/gguf/tensor_mapping.py
tools/mtmd/clip-impl.h
tools/mtmd/clip.cpp
tools/mtmd/tests.sh

index 9fa35e8b1157321c1043ac10f925b290b5d33379..31a11cbec0baafcc69a9a6a9d2660cbce4d4bd01 100755 (executable)
@@ -6254,9 +6254,11 @@ class DeepseekModel(TextModel):
                 raise ValueError(f"Unprocessed experts: {experts}")
 
 
-@ModelBase.register("DeepseekV2ForCausalLM")
-@ModelBase.register("DeepseekV3ForCausalLM")
-@ModelBase.register("KimiVLForConditionalGeneration")
+@ModelBase.register(
+    "DeepseekV2ForCausalLM",
+    "DeepseekV3ForCausalLM",
+    "KimiVLForConditionalGeneration",
+)
 class DeepseekV2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.DEEPSEEK2
 
@@ -8507,6 +8509,43 @@ class PixtralModel(LlavaVisionModel):
             return "mm.2.weight"
         return super().map_tensor_name(name, try_suffixes)
 
+
+@ModelBase.register("KimiVLForConditionalGeneration")
+class KimiVLModel(MmprojModel):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        assert self.hparams_vision is not None
+        self.hparams_vision["image_size"] = 64 * 14 # for compatibility
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.KIMIVL)
+        self.gguf_writer.add_vision_use_gelu(True)
+        self.gguf_writer.add_vision_projector_scale_factor(2)
+        # eps is the same as pytorch's default value
+        assert self.hparams_vision is not None
+        self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams_vision.get("layer_norm_eps", 1e-5))
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+        is_vision_tensor = "vision_tower" in name or "multi_modal_projector" in name
+
+        if is_vision_tensor:
+            if "pos_emb.weight" in name:
+                data_torch = data_torch.view(data_torch.shape[0] * data_torch.shape[1], data_torch.shape[2])
+            elif "wqkv" in name:
+                split_dim = 0 if "weight" in name else -1
+                wq, wk, wv = data_torch.chunk(3, dim=split_dim)
+                return [
+                    (self.map_tensor_name(name.replace("wqkv", "wq")), wq),
+                    (self.map_tensor_name(name.replace("wqkv", "wk")), wk),
+                    (self.map_tensor_name(name.replace("wqkv", "wv")), wv)
+                ]
+
+            return [(self.map_tensor_name(name), data_torch)]
+
+        return [] # skip other tensors
+
 ###### CONVERSION LOGIC ######
 
 
index d03a02c7bf9218568a1dc626d533f9c364df6d45..b9d1235d1706d1985737081ef53b03cc168198e4 100644 (file)
@@ -2850,6 +2850,7 @@ class VisionProjectorType:
     QWEN25O = "qwen2.5o" # omni
     VOXTRAL = "voxtral"
     LFM2 = "lfm2"
+    KIMIVL = "kimivl"
 
 
 # Items here are (block size, type size)
index 38bbd6e3102be358013c075e823bb4e1f6f83d86..abb21fa8219d312ff0b2675a67db77c7431e38f5 100644 (file)
@@ -1122,6 +1122,7 @@ class TensorNameMap:
             "vision_encoder.patch_conv", # pixtral
             "vision_model.patch_embedding.linear", # llama 4
             "visual.patch_embed.proj", # qwen2vl
+            "vision_tower.patch_embed.proj", # kimi-vl
         ),
 
         MODEL_TENSOR.V_ENC_EMBD_POS: (
@@ -1130,6 +1131,7 @@ class TensorNameMap:
             "vpm.embeddings.position_embedding",
             "model.vision_model.embeddings.position_embedding", # SmolVLM
             "vision_model.positional_embedding_vlm", # llama 4
+            "vision_tower.patch_embed.pos_emb", # kimi-vl
         ),
 
         MODEL_TENSOR.V_ENC_ATTN_Q: (
@@ -1141,6 +1143,7 @@ class TensorNameMap:
             "vision_tower.transformer.layers.{bid}.attention.q_proj", # pixtral-hf
             "vision_encoder.transformer.layers.{bid}.attention.wq", # pixtral
             "visual.blocks.{bid}.attn.q", # qwen2vl, generated
+            "vision_tower.encoder.blocks.{bid}.wq", # kimi-vl, generated
         ),
 
         MODEL_TENSOR.V_ENC_ATTN_Q_NORM: (
@@ -1157,6 +1160,7 @@ class TensorNameMap:
             "vision_tower.transformer.layers.{bid}.attention.k_proj", # pixtral-hf
             "vision_encoder.transformer.layers.{bid}.attention.wk", # pixtral
             "visual.blocks.{bid}.attn.k", # qwen2vl, generated
+            "vision_tower.encoder.blocks.{bid}.wk", # kimi-vl, generated
         ),
 
         MODEL_TENSOR.V_ENC_ATTN_K_NORM: (
@@ -1173,6 +1177,7 @@ class TensorNameMap:
             "vision_tower.transformer.layers.{bid}.attention.v_proj", # pixtral-hf
             "vision_encoder.transformer.layers.{bid}.attention.wv", # pixtral
             "visual.blocks.{bid}.attn.v", # qwen2vl, generated
+            "vision_tower.encoder.blocks.{bid}.wv", # kimi-vl, generated
         ),
 
         MODEL_TENSOR.V_ENC_INPUT_NORM: (
@@ -1185,6 +1190,7 @@ class TensorNameMap:
             "vision_encoder.transformer.layers.{bid}.attention_norm", # pixtral
             "vision_model.model.layers.{bid}.input_layernorm", # llama4
             "visual.blocks.{bid}.norm1", # qwen2vl
+            "vision_tower.encoder.blocks.{bid}.norm0", # kimi-vl (norm0/norm1)
         ),
 
         MODEL_TENSOR.V_ENC_ATTN_O: (
@@ -1197,6 +1203,7 @@ class TensorNameMap:
             "vision_tower.transformer.layers.{bid}.attention.o_proj", # pixtral-hf
             "vision_encoder.transformer.layers.{bid}.attention.wo", # pixtral
             "visual.blocks.{bid}.attn.proj", # qwen2vl
+            "vision_tower.encoder.blocks.{bid}.wo", # kimi-vl
         ),
 
         MODEL_TENSOR.V_ENC_POST_ATTN_NORM: (
@@ -1209,6 +1216,7 @@ class TensorNameMap:
             "vision_tower.transformer.layers.{bid}.ffn_norm", # pixtral-hf
             "vision_encoder.transformer.layers.{bid}.ffn_norm", # pixtral
             "visual.blocks.{bid}.norm2", # qwen2vl
+            "vision_tower.encoder.blocks.{bid}.norm1", # kimi-vl (norm0/norm1)
         ),
 
         MODEL_TENSOR.V_ENC_FFN_UP: (
@@ -1221,6 +1229,7 @@ class TensorNameMap:
             "vision_model.model.layers.{bid}.mlp.fc1", # llama4
             "visual.blocks.{bid}.mlp.fc1", # qwen2vl
             "visual.blocks.{bid}.mlp.up_proj", # qwen2.5vl
+            "vision_tower.encoder.blocks.{bid}.mlp.fc0", # kimi-vl (fc0/fc1)
         ),
 
         MODEL_TENSOR.V_ENC_FFN_GATE: (
@@ -1239,6 +1248,7 @@ class TensorNameMap:
             "vision_model.model.layers.{bid}.mlp.fc2", # llama4
             "visual.blocks.{bid}.mlp.fc2", # qwen2vl
             "visual.blocks.{bid}.mlp.down_proj", # qwen2.5vl
+            "vision_tower.encoder.blocks.{bid}.mlp.fc1", # kimi-vl (fc0/fc1)
         ),
 
         MODEL_TENSOR.V_LAYER_SCALE_1: (
@@ -1263,6 +1273,7 @@ class TensorNameMap:
             "model.vision_model.post_layernorm", # SmolVLM
             "vision_model.layernorm_post", # llama4
             "visual.merger.ln_q", # qwen2vl
+            "vision_tower.encoder.final_layernorm", # kimi-vl
         ),
 
         MODEL_TENSOR.V_MM_INP_PROJ: (
@@ -1272,6 +1283,7 @@ class TensorNameMap:
         MODEL_TENSOR.V_MM_INP_NORM: (
             "multi_modal_projector.norm",
             "multi_modal_projector.layer_norm",
+            "multi_modal_projector.pre_norm",
             "pre_mm_projector_norm",
         ),
 
index 706ed2e3b5e21b72625ae7c7e27ba9b289bfdaa0..664b0c9ac6e36c6548335b9a041815ea8a066fa4 100644 (file)
@@ -135,6 +135,7 @@ enum projector_type {
     PROJECTOR_TYPE_QWEN25O, // will be replaced by QWEN2A or QWEN25VL depending on clip_ctx
     PROJECTOR_TYPE_VOXTRAL,
     PROJECTOR_TYPE_LFM2,
+    PROJECTOR_TYPE_KIMIVL,
     PROJECTOR_TYPE_UNKNOWN,
 };
 
@@ -156,6 +157,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
     { PROJECTOR_TYPE_QWEN25O,   "qwen2.5o"},
     { PROJECTOR_TYPE_VOXTRAL,   "voxtral"},
     { PROJECTOR_TYPE_LFM2,      "lfm2"},
+    { PROJECTOR_TYPE_KIMIVL,    "kimivl"},
 };
 
 static projector_type clip_projector_type_from_string(const std::string & str) {
index 0e76b9c590bcef959a051135776c0873a90a871f..e7c516d2de8d121c046fde2b32b595f7714e8e7d 100644 (file)
@@ -526,57 +526,16 @@ struct clip_graph {
                 cur);
 
         } else if (ctx->proj_type() == PROJECTOR_TYPE_IDEFICS3) {
+            // pixel_shuffle
             // https://github.com/huggingface/transformers/blob/0a950e0bbe1ed58d5401a6b547af19f15f0c195e/src/transformers/models/idefics3/modeling_idefics3.py#L578
-
             const int scale_factor = model.hparams.proj_scale_factor;
-            const int n_embd = cur->ne[0];
-            const int seq    = cur->ne[1];
-            const int bsz    = 1; // batch size, always 1 for now since we don't support batching
-            const int height = std::sqrt(seq);
-            const int width  = std::sqrt(seq);
-            GGML_ASSERT(scale_factor != 0);
-            cur = ggml_reshape_4d(ctx0, cur, n_embd * scale_factor, width / scale_factor, height, bsz);
-            cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
-            cur = ggml_cont_4d(ctx0, cur,
-                n_embd * scale_factor * scale_factor,
-                height / scale_factor,
-                width / scale_factor,
-                bsz);
-            cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
-            cur = ggml_cont_3d(ctx0, cur,
-                n_embd * scale_factor * scale_factor,
-                seq / (scale_factor * scale_factor),
-                bsz);
-
+            cur = build_patch_merge_permute(cur, scale_factor);
             cur = ggml_mul_mat(ctx0, model.projection, cur);
+
         } else if (ctx->proj_type() == PROJECTOR_TYPE_LFM2) {
             // pixel unshuffle block
             const int scale_factor = model.hparams.proj_scale_factor;
-            GGML_ASSERT(scale_factor > 1);
-
-            const int n_embd = cur->ne[0];
-            int width  = img.nx / patch_size;
-            int height = img.ny / patch_size;
-
-            // pad width and height to factor
-            const int64_t pad_width = CLIP_ALIGN(width, scale_factor) - width;
-            const int64_t pad_height = CLIP_ALIGN(height, scale_factor) - height;
-            cur = ggml_reshape_3d(ctx0, cur, n_embd, width, height);
-            if (pad_width || pad_height) {
-                cur     = ggml_pad(ctx0, cur, 0, pad_width, pad_height, 0);
-                width  += pad_width;
-                height += pad_height;
-            }
-
-            // unshuffle h
-            cur = ggml_reshape_3d(ctx0, cur, n_embd * scale_factor, width / scale_factor, height);
-            cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
-
-            // unshuffle w
-            cur = ggml_cont_3d(ctx0, cur, n_embd * scale_factor * scale_factor, height / scale_factor, width / scale_factor);
-            cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
-
-            cur = ggml_cont_2d(ctx0, cur, cur->ne[0], cur->ne[1] * cur->ne[2]);
+            cur = build_patch_merge_permute(cur, scale_factor);
 
             // projection
             cur = ggml_norm(ctx0, cur, 1e-5); // default nn.LayerNorm
@@ -1086,7 +1045,7 @@ struct clip_graph {
                 n_patches_x / scale_factor,
                 n_patches_y / scale_factor,
                 bsz);
-            cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+            //cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
             // flatten to 2D
             cur = ggml_cont_2d(ctx0, cur,
                 n_embd * scale_factor * scale_factor,
@@ -1113,6 +1072,67 @@ struct clip_graph {
         return gf;
     }
 
+    ggml_cgraph * build_kimivl() {
+        // 2D input positions
+        ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches);
+        ggml_set_name(pos_h, "pos_h");
+        ggml_set_input(pos_h);
+
+        ggml_tensor * pos_w = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches);
+        ggml_set_name(pos_w, "pos_w");
+        ggml_set_input(pos_w);
+
+        ggml_tensor * learned_pos_embd = resize_position_embeddings();
+
+        // build ViT with 2D position embeddings
+        auto add_pos = [&](ggml_tensor * cur, const clip_layer &) {
+            // first half is X axis and second half is Y axis
+            return build_rope_2d(ctx0, cur, pos_w, pos_h, hparams.rope_theta, false);
+        };
+
+        ggml_tensor * inp = build_inp();
+        ggml_tensor * cur = build_vit(
+                                inp, n_patches,
+                                NORM_TYPE_NORMAL,
+                                hparams.ffn_op,
+                                learned_pos_embd,
+                                add_pos);
+
+        cb(cur, "vit_out", -1);
+
+        {
+            // patch_merger
+            const int scale_factor = model.hparams.proj_scale_factor;
+            cur = build_patch_merge_permute(cur, scale_factor);
+
+            // projection norm
+            int proj_inp_dim = cur->ne[0];
+            cur = ggml_view_2d(ctx0, cur,
+                n_embd, cur->ne[1] * scale_factor * scale_factor,
+                ggml_row_size(cur->type, n_embd), 0);
+            cur = ggml_norm(ctx0, cur, 1e-5); // default nn.LayerNorm
+            cur = ggml_mul(ctx0, cur, model.mm_input_norm_w);
+            cur = ggml_add(ctx0, cur, model.mm_input_norm_b);
+            cur = ggml_view_2d(ctx0, cur,
+                proj_inp_dim, cur->ne[1] / scale_factor / scale_factor,
+                ggml_row_size(cur->type, proj_inp_dim), 0);
+            cb(cur, "proj_inp_normed", -1);
+
+            // projection mlp
+            cur = ggml_mul_mat(ctx0, model.mm_1_w, cur);
+            cur = ggml_add(ctx0, cur, model.mm_1_b);
+            cur = ggml_gelu(ctx0, cur);
+            cur = ggml_mul_mat(ctx0, model.mm_2_w, cur);
+            cur = ggml_add(ctx0, cur, model.mm_2_b);
+            cb(cur, "proj_out", -1);
+        }
+
+        // build the graph
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
     // this graph is used by llava, granite and glm
     // due to having embedding_stack (used by granite), we cannot reuse build_vit
     ggml_cgraph * build_llava() {
@@ -1611,18 +1631,20 @@ private:
         ggml_tensor * pos_embd = model.position_embeddings;
         const int height       = img.ny / patch_size;
         const int width        = img.nx / patch_size;
+        const uint32_t mode    = GGML_SCALE_MODE_BILINEAR;
+        const int n_per_side   = (int)std::sqrt(pos_embd->ne[1]);
+
+        GGML_ASSERT(pos_embd);
 
-        if (!pos_embd || height * width == pos_embd->ne[1]) {
+        if (height == n_per_side && width == n_per_side) {
             return pos_embd;
         }
 
-        const int n_pos_embd = std::sqrt(pos_embd->ne[1]);
-        pos_embd = ggml_reshape_3d(ctx0, pos_embd, n_embd, n_pos_embd, n_pos_embd);  // -> (n_embd, n_pos_embd, n_pos_embd)
-        pos_embd = ggml_permute(ctx0, pos_embd, 2, 0, 1, 3);                         // -> (n_pos_embd, n_pos_embd, n_embd)
-        pos_embd = ggml_interpolate(ctx0, pos_embd, width, height, n_embd, 1, 1);    // -> (width, height, n_embd)
-        pos_embd = ggml_reshape_2d(ctx0, pos_embd, height * width, n_embd);          // -> (height * width, n_embd)
-        pos_embd = ggml_transpose(ctx0, pos_embd);                                   // -> (n_embd, height * width)
-        pos_embd = ggml_cont(ctx0, pos_embd);
+        pos_embd = ggml_reshape_3d(ctx0, pos_embd, n_embd, n_per_side, n_per_side);  // -> (n_embd, n_per_side, n_per_side)
+        pos_embd = ggml_permute(ctx0, pos_embd, 2, 0, 1, 3);                         // -> (n_per_side, n_per_side, n_embd)
+        pos_embd = ggml_interpolate(ctx0, pos_embd, width, height, n_embd, 1, mode); // -> (width, height, n_embd)
+        pos_embd = ggml_permute(ctx0, pos_embd, 1, 2, 0, 3);                         // -> (n_embd, width, height)
+        pos_embd = ggml_cont_2d(ctx0, pos_embd, n_embd, width * height);             // -> (n_embd, width * height)
 
         return pos_embd;
     }
@@ -2021,6 +2043,39 @@ private:
         return cur;
     }
 
+    // aka pixel_shuffle / pixel_unshuffle / patch_merger (Kimi-VL)
+    // support dynamic resolution
+    ggml_tensor * build_patch_merge_permute(ggml_tensor * cur, int scale_factor) {
+        GGML_ASSERT(scale_factor > 1);
+
+        const int n_embd = cur->ne[0];
+        int width  = img.nx / patch_size;
+        int height = img.ny / patch_size;
+
+        // pad width and height to factor
+        const int64_t pad_width  = CLIP_ALIGN(width,  scale_factor) - width;
+        const int64_t pad_height = CLIP_ALIGN(height, scale_factor) - height;
+        cur = ggml_reshape_3d(ctx0, cur, n_embd, width, height);
+        if (pad_width || pad_height) {
+            cur     = ggml_pad(ctx0, cur, 0, pad_width, pad_height, 0);
+            width  += pad_width;
+            height += pad_height;
+        }
+
+        // unshuffle h
+        cur = ggml_reshape_3d(ctx0, cur, n_embd * scale_factor, width / scale_factor, height);
+        cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+
+        // unshuffle w
+        cur = ggml_cont_3d(ctx0, cur, n_embd * scale_factor * scale_factor, height / scale_factor, width / scale_factor);
+        cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+
+        cur = ggml_cont_2d(ctx0, cur, cur->ne[0], cur->ne[1] * cur->ne[2]);
+        cb(cur, "pixel_shuffle", -1);
+
+        return cur;
+    }
+
 };
 
 static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch & imgs) {
@@ -2063,6 +2118,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
             {
                 res = graph.build_whisper_enc();
             } break;
+        case PROJECTOR_TYPE_KIMIVL:
+            {
+                res = graph.build_kimivl();
+            } break;
         default:
             {
                 res = graph.build_llava();
@@ -2313,6 +2372,12 @@ struct clip_model_loader {
                         hparams.image_size = 1024;
                         get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.spatial_merge_size, false);
                     } break;
+                case PROJECTOR_TYPE_KIMIVL:
+                    {
+                        hparams.rope_theta = 10000.0f;
+                        hparams.warmup_image_size = hparams.patch_size * 8;
+                        get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor, false);
+                    } break;
                 case PROJECTOR_TYPE_GEMMA3:
                     {
                         // default value (used by all model sizes in gemma 3 family)
@@ -2477,7 +2542,20 @@ struct clip_model_loader {
 
             // some models already exported with legacy (incorrect) naming which is quite messy, let's fix it here
             // note: Qwen model converted from the old surgery script has n_ff = 0, so we cannot use n_ff to check!
-            if (layer.ff_up_w && layer.ff_down_w && layer.ff_down_w->ne[0] == hparams.n_embd) {
+            bool is_ffn_swapped = (
+                    // only old models need this fix
+                    model.proj_type == PROJECTOR_TYPE_MLP
+                    || model.proj_type == PROJECTOR_TYPE_MLP_NORM
+                    || model.proj_type == PROJECTOR_TYPE_LDP
+                    || model.proj_type == PROJECTOR_TYPE_LDPV2
+                    || model.proj_type == PROJECTOR_TYPE_QWEN2VL
+                    || model.proj_type == PROJECTOR_TYPE_QWEN25VL
+                    || model.proj_type == PROJECTOR_TYPE_GLM_EDGE
+                    || model.proj_type == PROJECTOR_TYPE_GEMMA3
+                    || model.proj_type == PROJECTOR_TYPE_IDEFICS3
+                    || model.proj_type == PROJECTOR_TYPE_MINICPMV
+                ) && layer.ff_up_w && layer.ff_down_w && layer.ff_down_w->ne[0] == hparams.n_embd;
+            if (is_ffn_swapped) {
                 // swap up and down weights
                 ggml_tensor * tmp = layer.ff_up_w;
                 layer.ff_up_w = layer.ff_down_w;
@@ -2486,6 +2564,9 @@ struct clip_model_loader {
                 tmp = layer.ff_up_b;
                 layer.ff_up_b = layer.ff_down_b;
                 layer.ff_down_b = tmp;
+                if (il == 0) {
+                    LOG_WRN("%s: ffn up/down are swapped\n", __func__);
+                }
             }
         }
 
@@ -2604,6 +2685,7 @@ struct clip_model_loader {
                     model.projection = get_tensor(TN_MM_PROJECTOR);
                 } break;
             case PROJECTOR_TYPE_LFM2:
+            case PROJECTOR_TYPE_KIMIVL:
                 {
                     model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM);
                     model.mm_input_norm_b = get_tensor(TN_MM_INP_NORM_B);
@@ -3507,7 +3589,9 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
         res_imgs->grid_y = inst.grid_size.height;
         return true;
 
-    } else if (ctx->proj_type() == PROJECTOR_TYPE_LFM2) {
+    } else if ( ctx->proj_type() == PROJECTOR_TYPE_LFM2
+             || ctx->proj_type() == PROJECTOR_TYPE_KIMIVL
+    ) {
         GGML_ASSERT(params.proj_scale_factor);
 
         // smart resize
@@ -3708,12 +3792,21 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
         case PROJECTOR_TYPE_IDEFICS3:
         case PROJECTOR_TYPE_INTERNVL:
         case PROJECTOR_TYPE_LLAMA4:
-        case PROJECTOR_TYPE_LFM2:
             {
-                // both W and H are divided by proj_scale_factor
+                // both X and Y are downscaled by the scale factor
                 int scale_factor = ctx->model.hparams.proj_scale_factor;
                 n_patches /= (scale_factor * scale_factor);
             } break;
+        case PROJECTOR_TYPE_LFM2:
+        case PROJECTOR_TYPE_KIMIVL:
+            {
+                // dynamic size
+                int scale_factor = ctx->model.hparams.proj_scale_factor;
+                int out_patch_size = params.patch_size * scale_factor;
+                int x_patch = CLIP_ALIGN(img->nx, out_patch_size) / out_patch_size;
+                int y_patch = CLIP_ALIGN(img->ny, out_patch_size) / out_patch_size;
+                n_patches = x_patch * y_patch;
+            } break;
         case PROJECTOR_TYPE_PIXTRAL:
             {
                 // dynamic size
@@ -4096,6 +4189,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
                 set_input_i32("positions", positions);
             } break;
         case PROJECTOR_TYPE_PIXTRAL:
+        case PROJECTOR_TYPE_KIMIVL:
             {
                 // set the 2D positions
                 int n_patches_per_col = image_size_width / patch_size;
@@ -4250,6 +4344,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
         case PROJECTOR_TYPE_QWEN2A:
             return ctx->model.mm_fc_w->ne[1];
         case PROJECTOR_TYPE_LFM2:
+        case PROJECTOR_TYPE_KIMIVL:
             return ctx->model.mm_2_w->ne[1];
         default:
             GGML_ABORT("Unknown projector type");
index 6f8a5f86ac5b2254721fd6a79843c53992442805..c64be03630a5636c1e98935438ddbcfee708eb78 100755 (executable)
@@ -86,6 +86,7 @@ if [ "$RUN_BIG_TESTS" = true ]; then
     add_test_vision "ggml-org/InternVL3-14B-Instruct-GGUF:Q4_K_M"
     add_test_vision "ggml-org/Qwen2.5-Omni-7B-GGUF:Q4_K_M"
     # add_test_vision "ggml-org/Qwen2.5-VL-32B-Instruct-GGUF:Q4_K_M" # does not work on my mac M3 Ultra
+    add_test_vision "ggml-org/Kimi-VL-A3B-Thinking-2506-GGUF:Q4_K_M"
 
     add_test_audio  "ggml-org/ultravox-v0_5-llama-3_1-8b-GGUF:Q4_K_M"
     add_test_audio  "ggml-org/Qwen2.5-Omni-7B-GGUF:Q4_K_M"