From: Xuan-Son Nguyen
Date: Tue, 26 Aug 2025 10:54:19 +0000 (+0200)
Subject: mtmd : support Kimi VL model (#15458)
X-Git-Tag: upstream/0.0.6527~242
X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=79a546220c719e6a70627b243a478ab8d84dc9e1;p=pkg%2Fggml%2Fsources%2Fllama.cpp

mtmd : support Kimi VL model (#15458)

* convert : fix tensor naming conflict for llama 4 vision

* convert ok

* support kimi vision model

* clean up

* fix style

* fix calc number of output tokens

* refactor resize_position_embeddings

* add test case

* rename build fn

* correct a small bug
---
diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 9fa35e8b..31a11cbe 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -6254,9 +6254,11 @@ class DeepseekModel(TextModel):
             raise ValueError(f"Unprocessed experts: {experts}")
 
 
-@ModelBase.register("DeepseekV2ForCausalLM")
-@ModelBase.register("DeepseekV3ForCausalLM")
-@ModelBase.register("KimiVLForConditionalGeneration")
+@ModelBase.register(
+    "DeepseekV2ForCausalLM",
+    "DeepseekV3ForCausalLM",
+    "KimiVLForConditionalGeneration",
+)
 class DeepseekV2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.DEEPSEEK2
 
@@ -8507,6 +8509,43 @@ class PixtralModel(LlavaVisionModel):
             return "mm.2.weight"
         return super().map_tensor_name(name, try_suffixes)
 
+
+@ModelBase.register("KimiVLForConditionalGeneration")
+class KimiVLModel(MmprojModel):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        assert self.hparams_vision is not None
+        self.hparams_vision["image_size"] = 64 * 14  # for compatibility
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.KIMIVL)
+        self.gguf_writer.add_vision_use_gelu(True)
+        self.gguf_writer.add_vision_projector_scale_factor(2)
+        # eps is the same as pytorch's default value
+        assert self.hparams_vision is not None
+        self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams_vision.get("layer_norm_eps", 1e-5))
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+        is_vision_tensor = "vision_tower" in name or "multi_modal_projector" in name
+
+        if is_vision_tensor:
+            if "pos_emb.weight" in name:
+                data_torch = data_torch.view(data_torch.shape[0] * data_torch.shape[1], data_torch.shape[2])
+            elif "wqkv" in name:
+                split_dim = 0 if "weight" in name else -1
+                wq, wk, wv = data_torch.chunk(3, dim=split_dim)
+                return [
+                    (self.map_tensor_name(name.replace("wqkv", "wq")), wq),
+                    (self.map_tensor_name(name.replace("wqkv", "wk")), wk),
+                    (self.map_tensor_name(name.replace("wqkv", "wv")), wv)
+                ]
+
+            return [(self.map_tensor_name(name), data_torch)]
+
+        return []  # skip other tensors
+
 
 ###### CONVERSION LOGIC ######
 
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index d03a02c7..b9d1235d 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -2850,6 +2850,7 @@ class VisionProjectorType:
     QWEN25O = "qwen2.5o" # omni
     VOXTRAL = "voxtral"
     LFM2 = "lfm2"
+    KIMIVL = "kimivl"
 
 
 # Items here are (block size, type size)
diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index 38bbd6e3..abb21fa8 100644
--- a/gguf-py/gguf/tensor_mapping.py
+++ b/gguf-py/gguf/tensor_mapping.py
@@ -1122,6 +1122,7 @@ class TensorNameMap:
             "vision_encoder.patch_conv", # pixtral
             "vision_model.patch_embedding.linear", # llama 4
             "visual.patch_embed.proj", # qwen2vl
+            "vision_tower.patch_embed.proj", # kimi-vl
         ),
 
         MODEL_TENSOR.V_ENC_EMBD_POS: (
@@ -1130,6 +1131,7 @@ class TensorNameMap:
             "vpm.embeddings.position_embedding",
             "model.vision_model.embeddings.position_embedding", # SmolVLM
             "vision_model.positional_embedding_vlm", # llama 4
+            "vision_tower.patch_embed.pos_emb", # kimi-vl
         ),
 
         MODEL_TENSOR.V_ENC_ATTN_Q: (
@@ -1141,6 +1143,7 @@ class TensorNameMap:
             "vision_tower.transformer.layers.{bid}.attention.q_proj", # pixtral-hf
             "vision_encoder.transformer.layers.{bid}.attention.wq", # pixtral
             "visual.blocks.{bid}.attn.q", # qwen2vl, generated
+            "vision_tower.encoder.blocks.{bid}.wq", # kimi-vl, generated
         ),
 
         MODEL_TENSOR.V_ENC_ATTN_Q_NORM: (
@@ -1157,6 +1160,7 @@ class TensorNameMap:
             "vision_tower.transformer.layers.{bid}.attention.k_proj", # pixtral-hf
             "vision_encoder.transformer.layers.{bid}.attention.wk", # pixtral
             "visual.blocks.{bid}.attn.k", # qwen2vl, generated
+            "vision_tower.encoder.blocks.{bid}.wk", # kimi-vl, generated
         ),
 
         MODEL_TENSOR.V_ENC_ATTN_K_NORM: (
@@ -1173,6 +1177,7 @@ class TensorNameMap:
             "vision_tower.transformer.layers.{bid}.attention.v_proj", # pixtral-hf
             "vision_encoder.transformer.layers.{bid}.attention.wv", # pixtral
             "visual.blocks.{bid}.attn.v", # qwen2vl, generated
+            "vision_tower.encoder.blocks.{bid}.wv", # kimi-vl, generated
         ),
 
         MODEL_TENSOR.V_ENC_INPUT_NORM: (
@@ -1185,6 +1190,7 @@ class TensorNameMap:
             "vision_encoder.transformer.layers.{bid}.attention_norm", # pixtral
             "vision_model.model.layers.{bid}.input_layernorm", # llama4
             "visual.blocks.{bid}.norm1", # qwen2vl
+            "vision_tower.encoder.blocks.{bid}.norm0", # kimi-vl (norm0/norm1)
         ),
 
         MODEL_TENSOR.V_ENC_ATTN_O: (
@@ -1197,6 +1203,7 @@ class TensorNameMap:
             "vision_tower.transformer.layers.{bid}.attention.o_proj", # pixtral-hf
             "vision_encoder.transformer.layers.{bid}.attention.wo", # pixtral
             "visual.blocks.{bid}.attn.proj", # qwen2vl
+            "vision_tower.encoder.blocks.{bid}.wo", # kimi-vl
         ),
 
         MODEL_TENSOR.V_ENC_POST_ATTN_NORM: (
@@ -1209,6 +1216,7 @@ class TensorNameMap:
             "vision_tower.transformer.layers.{bid}.ffn_norm", # pixtral-hf
             "vision_encoder.transformer.layers.{bid}.ffn_norm", # pixtral
             "visual.blocks.{bid}.norm2", # qwen2vl
+            "vision_tower.encoder.blocks.{bid}.norm1", # kimi-vl (norm0/norm1)
         ),
 
         MODEL_TENSOR.V_ENC_FFN_UP: (
@@ -1221,6 +1229,7 @@ class TensorNameMap:
             "vision_model.model.layers.{bid}.mlp.fc1", # llama4
             "visual.blocks.{bid}.mlp.fc1", # qwen2vl
             "visual.blocks.{bid}.mlp.up_proj", # qwen2.5vl
+            "vision_tower.encoder.blocks.{bid}.mlp.fc0", # kimi-vl (fc0/fc1)
         ),
 
         MODEL_TENSOR.V_ENC_FFN_GATE: (
@@ -1239,6 +1248,7 @@ class TensorNameMap:
             "vision_model.model.layers.{bid}.mlp.fc2", # llama4
             "visual.blocks.{bid}.mlp.fc2", # qwen2vl
             "visual.blocks.{bid}.mlp.down_proj", # qwen2.5vl
+            "vision_tower.encoder.blocks.{bid}.mlp.fc1", # kimi-vl (fc0/fc1)
         ),
 
         MODEL_TENSOR.V_LAYER_SCALE_1: (
@@ -1263,6 +1273,7 @@ class TensorNameMap:
             "model.vision_model.post_layernorm", # SmolVLM
             "vision_model.layernorm_post", # llama4
             "visual.merger.ln_q", # qwen2vl
+            "vision_tower.encoder.final_layernorm", # kimi-vl
         ),
 
         MODEL_TENSOR.V_MM_INP_PROJ: (
@@ -1272,6 +1283,7 @@ class TensorNameMap:
         MODEL_TENSOR.V_MM_INP_NORM: (
             "multi_modal_projector.norm",
             "multi_modal_projector.layer_norm",
+            "multi_modal_projector.pre_norm",
             "pre_mm_projector_norm",
         ),
 
diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h
index 706ed2e3..664b0c9a 100644
--- a/tools/mtmd/clip-impl.h
+++ b/tools/mtmd/clip-impl.h
@@ -135,6 +135,7 @@ enum projector_type {
     PROJECTOR_TYPE_QWEN25O, // will be replaced by QWEN2A or QWEN25VL depending on clip_ctx
     PROJECTOR_TYPE_VOXTRAL,
     PROJECTOR_TYPE_LFM2,
+    PROJECTOR_TYPE_KIMIVL,
    PROJECTOR_TYPE_UNKNOWN,
 };
 
@@ -156,6 +157,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
     { PROJECTOR_TYPE_QWEN25O,   "qwen2.5o"},
     { PROJECTOR_TYPE_VOXTRAL,   "voxtral"},
     { PROJECTOR_TYPE_LFM2,      "lfm2"},
+    { PROJECTOR_TYPE_KIMIVL,    "kimivl"},
 };
 
 static projector_type clip_projector_type_from_string(const std::string & str) {
diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp
index 0e76b9c5..e7c516d2 100644
--- a/tools/mtmd/clip.cpp
+++ b/tools/mtmd/clip.cpp
@@ -526,57 +526,16 @@ struct clip_graph {
                 cur);
 
         } else if (ctx->proj_type() == PROJECTOR_TYPE_IDEFICS3) {
+            // pixel_shuffle
             // https://github.com/huggingface/transformers/blob/0a950e0bbe1ed58d5401a6b547af19f15f0c195e/src/transformers/models/idefics3/modeling_idefics3.py#L578
             const int scale_factor = model.hparams.proj_scale_factor;
-            const int n_embd = cur->ne[0];
-            const int seq    = cur->ne[1];
-            const int bsz = 1; // batch size, always 1 for now since we don't support batching
-            const int height = std::sqrt(seq);
-            const int width  = std::sqrt(seq);
-            GGML_ASSERT(scale_factor != 0);
-            cur = ggml_reshape_4d(ctx0, cur, n_embd * scale_factor, width / scale_factor, height, bsz);
-            cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
-            cur = ggml_cont_4d(ctx0, cur,
-                n_embd * scale_factor * scale_factor,
-                height / scale_factor,
-                width / scale_factor,
-                bsz);
-            cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
-            cur = ggml_cont_3d(ctx0, cur,
-                n_embd * scale_factor * scale_factor,
-                seq / (scale_factor * scale_factor),
-                bsz);
-
+            cur = build_patch_merge_permute(cur, scale_factor);
             cur = ggml_mul_mat(ctx0, model.projection, cur);
+
         } else if (ctx->proj_type() == PROJECTOR_TYPE_LFM2) {
             // pixel unshuffle block
             const int scale_factor = model.hparams.proj_scale_factor;
-            GGML_ASSERT(scale_factor > 1);
-
-            const int n_embd = cur->ne[0];
-            int width  = img.nx / patch_size;
-            int height = img.ny / patch_size;
-
-            // pad width and height to factor
-            const int64_t pad_width  = CLIP_ALIGN(width,  scale_factor) - width;
-            const int64_t pad_height = CLIP_ALIGN(height, scale_factor) - height;
-            cur = ggml_reshape_3d(ctx0, cur, n_embd, width, height);
-            if (pad_width || pad_height) {
-                cur = ggml_pad(ctx0, cur, 0, pad_width, pad_height, 0);
-                width  += pad_width;
-                height += pad_height;
-            }
-
-            // unshuffle h
-            cur = ggml_reshape_3d(ctx0, cur, n_embd * scale_factor, width / scale_factor, height);
-            cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
-
-            // unshuffle w
-            cur = ggml_cont_3d(ctx0, cur, n_embd * scale_factor * scale_factor, height / scale_factor, width / scale_factor);
-            cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
-
-            cur = ggml_cont_2d(ctx0, cur, cur->ne[0], cur->ne[1] * cur->ne[2]);
+            cur = build_patch_merge_permute(cur, scale_factor);
 
             // projection
             cur = ggml_norm(ctx0, cur, 1e-5); // default nn.LayerNorm
@@ -1086,7 +1045,7 @@ struct clip_graph {
             n_patches_x / scale_factor,
             n_patches_y / scale_factor,
             bsz);
-        cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+        //cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
         // flatten to 2D
         cur = ggml_cont_2d(ctx0, cur,
             n_embd * scale_factor * scale_factor,
@@ -1113,6 +1072,67 @@ struct clip_graph {
         return gf;
     }
 
+    ggml_cgraph * build_kimivl() {
+        // 2D input positions
+        ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches);
+        ggml_set_name(pos_h, "pos_h");
+        ggml_set_input(pos_h);
+
+        ggml_tensor * pos_w = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches);
+        ggml_set_name(pos_w, "pos_w");
+        ggml_set_input(pos_w);
+
+        ggml_tensor * learned_pos_embd = resize_position_embeddings();
+
+        // build ViT with 2D position embeddings
+        auto add_pos = [&](ggml_tensor * cur, const clip_layer &) {
+            // first half is X axis and second half is Y axis
+            return build_rope_2d(ctx0, cur, pos_w, pos_h, hparams.rope_theta, false);
+        };
+
+        ggml_tensor * inp = build_inp();
+        ggml_tensor * cur = build_vit(
+                                inp, n_patches,
+                                NORM_TYPE_NORMAL,
+                                hparams.ffn_op,
+                                learned_pos_embd,
+                                add_pos);
+
+        cb(cur, "vit_out", -1);
+
+        {
+            // patch_merger
+            const int scale_factor = model.hparams.proj_scale_factor;
+            cur = build_patch_merge_permute(cur, scale_factor);
+
+            // projection norm
+            int proj_inp_dim = cur->ne[0];
+            cur = ggml_view_2d(ctx0, cur,
+                n_embd, cur->ne[1] * scale_factor * scale_factor,
+                ggml_row_size(cur->type, n_embd), 0);
+            cur = ggml_norm(ctx0, cur, 1e-5); // default nn.LayerNorm
+            cur = ggml_mul(ctx0, cur, model.mm_input_norm_w);
+            cur = ggml_add(ctx0, cur, model.mm_input_norm_b);
+            cur = ggml_view_2d(ctx0, cur,
+                proj_inp_dim, cur->ne[1] / scale_factor / scale_factor,
+                ggml_row_size(cur->type, proj_inp_dim), 0);
+            cb(cur, "proj_inp_normed", -1);
+
+            // projection mlp
+            cur = ggml_mul_mat(ctx0, model.mm_1_w, cur);
+            cur = ggml_add(ctx0, cur, model.mm_1_b);
+            cur = ggml_gelu(ctx0, cur);
+            cur = ggml_mul_mat(ctx0, model.mm_2_w, cur);
+            cur = ggml_add(ctx0, cur, model.mm_2_b);
+            cb(cur, "proj_out", -1);
+        }
+
+        // build the graph
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
     // this graph is used by llava, granite and glm
     // due to having embedding_stack (used by granite), we cannot reuse build_vit
     ggml_cgraph * build_llava() {
@@ -1611,18 +1631,20 @@ private:
         ggml_tensor * pos_embd = model.position_embeddings;
         const int height = img.ny / patch_size;
         const int width  = img.nx / patch_size;
+        const uint32_t mode = GGML_SCALE_MODE_BILINEAR;
+        const int n_per_side = (int)std::sqrt(pos_embd->ne[1]);
+
+        GGML_ASSERT(pos_embd);
 
-        if (!pos_embd || height * width == pos_embd->ne[1]) {
+        if (height == n_per_side && width == n_per_side) {
             return pos_embd;
         }
 
-        const int n_pos_embd = std::sqrt(pos_embd->ne[1]);
-        pos_embd = ggml_reshape_3d(ctx0, pos_embd, n_embd, n_pos_embd, n_pos_embd); // -> (n_embd, n_pos_embd, n_pos_embd)
-        pos_embd = ggml_permute(ctx0, pos_embd, 2, 0, 1, 3);                        // -> (n_pos_embd, n_pos_embd, n_embd)
-        pos_embd = ggml_interpolate(ctx0, pos_embd, width, height, n_embd, 1, 1);   // -> (width, height, n_embd)
-        pos_embd = ggml_reshape_2d(ctx0, pos_embd, height * width, n_embd);         // -> (height * width, n_embd)
-        pos_embd = ggml_transpose(ctx0, pos_embd);                                  // -> (n_embd, height * width)
-        pos_embd = ggml_cont(ctx0, pos_embd);
+        pos_embd = ggml_reshape_3d(ctx0, pos_embd, n_embd, n_per_side, n_per_side);  // -> (n_embd, n_per_side, n_per_side)
+        pos_embd = ggml_permute(ctx0, pos_embd, 2, 0, 1, 3);                         // -> (n_per_side, n_per_side, n_embd)
+        pos_embd = ggml_interpolate(ctx0, pos_embd, width, height, n_embd, 1, mode); // -> (width, height, n_embd)
+        pos_embd = ggml_permute(ctx0, pos_embd, 1, 2, 0, 3);                         // -> (n_embd, width, height)
+        pos_embd = ggml_cont_2d(ctx0, pos_embd, n_embd, width * height);             // -> (n_embd, width * height)
 
         return pos_embd;
     }
@@ -2021,6 +2043,39 @@ private:
         return cur;
     }
 
+    // aka pixel_shuffle / pixel_unshuffle / patch_merger (Kimi-VL)
+    // support dynamic resolution
+    ggml_tensor * build_patch_merge_permute(ggml_tensor * cur, int scale_factor) {
+        GGML_ASSERT(scale_factor > 1);
+
+        const int n_embd = cur->ne[0];
+        int width  = img.nx / patch_size;
+        int height = img.ny / patch_size;
+
+        // pad width and height to factor
+        const int64_t pad_width  = CLIP_ALIGN(width,  scale_factor) - width;
+        const int64_t pad_height = CLIP_ALIGN(height, scale_factor) - height;
+        cur = ggml_reshape_3d(ctx0, cur, n_embd, width, height);
+        if (pad_width || pad_height) {
+            cur = ggml_pad(ctx0, cur, 0, pad_width, pad_height, 0);
+            width  += pad_width;
+            height += pad_height;
+        }
+
+        // unshuffle h
+        cur = ggml_reshape_3d(ctx0, cur, n_embd * scale_factor, width / scale_factor, height);
+        cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+
+        // unshuffle w
+        cur = ggml_cont_3d(ctx0, cur, n_embd * scale_factor * scale_factor, height / scale_factor, width / scale_factor);
+        cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+
+        cur = ggml_cont_2d(ctx0, cur, cur->ne[0], cur->ne[1] * cur->ne[2]);
+        cb(cur, "pixel_shuffle", -1);
+
+        return cur;
+    }
+
 };
 
 static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch & imgs) {
@@ -2063,6 +2118,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
             {
                 res = graph.build_whisper_enc();
             } break;
+        case PROJECTOR_TYPE_KIMIVL:
+            {
+                res = graph.build_kimivl();
+            } break;
         default:
             {
                 res = graph.build_llava();
@@ -2313,6 +2372,12 @@ struct clip_model_loader {
                     hparams.image_size = 1024;
                     get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.spatial_merge_size, false);
                 } break;
+            case PROJECTOR_TYPE_KIMIVL:
+                {
+                    hparams.rope_theta = 10000.0f;
+                    hparams.warmup_image_size = hparams.patch_size * 8;
+                    get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor, false);
+                } break;
             case PROJECTOR_TYPE_GEMMA3:
                 {
                     // default value (used by all model sizes in gemma 3 family)
@@ -2477,7 +2542,20 @@ struct clip_model_loader {
             // some models already exported with legacy (incorrect) naming which is quite messy, let's fix it here
             // note: Qwen model converted from the old surgery script has n_ff = 0, so we cannot use n_ff to check!
-            if (layer.ff_up_w && layer.ff_down_w && layer.ff_down_w->ne[0] == hparams.n_embd) {
+            bool is_ffn_swapped = (
+                    // only old models need this fix
+                    model.proj_type == PROJECTOR_TYPE_MLP
+                    || model.proj_type == PROJECTOR_TYPE_MLP_NORM
+                    || model.proj_type == PROJECTOR_TYPE_LDP
+                    || model.proj_type == PROJECTOR_TYPE_LDPV2
+                    || model.proj_type == PROJECTOR_TYPE_QWEN2VL
+                    || model.proj_type == PROJECTOR_TYPE_QWEN25VL
+                    || model.proj_type == PROJECTOR_TYPE_GLM_EDGE
+                    || model.proj_type == PROJECTOR_TYPE_GEMMA3
+                    || model.proj_type == PROJECTOR_TYPE_IDEFICS3
+                    || model.proj_type == PROJECTOR_TYPE_MINICPMV
+                ) && layer.ff_up_w && layer.ff_down_w && layer.ff_down_w->ne[0] == hparams.n_embd;
+            if (is_ffn_swapped) {
                 // swap up and down weights
                 ggml_tensor * tmp = layer.ff_up_w;
                 layer.ff_up_w = layer.ff_down_w;
@@ -2486,6 +2564,9 @@ struct clip_model_loader {
                 tmp = layer.ff_up_b;
                 layer.ff_up_b = layer.ff_down_b;
                 layer.ff_down_b = tmp;
+                if (il == 0) {
+                    LOG_WRN("%s: ffn up/down are swapped\n", __func__);
+                }
             }
         }
 
@@ -2604,6 +2685,7 @@ struct clip_model_loader {
                     model.projection = get_tensor(TN_MM_PROJECTOR);
                 } break;
             case PROJECTOR_TYPE_LFM2:
+            case PROJECTOR_TYPE_KIMIVL:
                 {
                     model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM);
                     model.mm_input_norm_b = get_tensor(TN_MM_INP_NORM_B);
@@ -3507,7 +3589,9 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
         res_imgs->grid_y = inst.grid_size.height;
         return true;
 
-    } else if (ctx->proj_type() == PROJECTOR_TYPE_LFM2) {
+    } else if ( ctx->proj_type() == PROJECTOR_TYPE_LFM2
+             || ctx->proj_type() == PROJECTOR_TYPE_KIMIVL
+    ) {
         GGML_ASSERT(params.proj_scale_factor);
 
         // smart resize
@@ -3708,12 +3792,21 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
         case PROJECTOR_TYPE_IDEFICS3:
         case PROJECTOR_TYPE_INTERNVL:
         case PROJECTOR_TYPE_LLAMA4:
-        case PROJECTOR_TYPE_LFM2:
             {
-                // both W and H are divided by proj_scale_factor
+                // both X and Y are downscaled by the scale factor
                 int scale_factor = ctx->model.hparams.proj_scale_factor;
                 n_patches /= (scale_factor * scale_factor);
             } break;
+        case PROJECTOR_TYPE_LFM2:
+        case PROJECTOR_TYPE_KIMIVL:
+            {
+                // dynamic size
+                int scale_factor = ctx->model.hparams.proj_scale_factor;
+                int out_patch_size = params.patch_size * scale_factor;
+                int x_patch = CLIP_ALIGN(img->nx, out_patch_size) / out_patch_size;
+                int y_patch = CLIP_ALIGN(img->ny, out_patch_size) / out_patch_size;
+                n_patches = x_patch * y_patch;
+            } break;
         case PROJECTOR_TYPE_PIXTRAL:
             {
                 // dynamic size
@@ -4096,6 +4189,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
                 set_input_i32("positions", positions);
             } break;
         case PROJECTOR_TYPE_PIXTRAL:
+        case PROJECTOR_TYPE_KIMIVL:
            {
                 // set the 2D positions
                 int n_patches_per_col = image_size_width / patch_size;
@@ -4250,6 +4344,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
         case PROJECTOR_TYPE_QWEN2A:
             return ctx->model.mm_fc_w->ne[1];
         case PROJECTOR_TYPE_LFM2:
+        case PROJECTOR_TYPE_KIMIVL:
             return ctx->model.mm_2_w->ne[1];
         default:
             GGML_ABORT("Unknown projector type");
diff --git a/tools/mtmd/tests.sh b/tools/mtmd/tests.sh
index 6f8a5f86..c64be036 100755
--- a/tools/mtmd/tests.sh
+++ b/tools/mtmd/tests.sh
@@ -86,6 +86,7 @@ if [ "$RUN_BIG_TESTS" = true ]; then
     add_test_vision "ggml-org/InternVL3-14B-Instruct-GGUF:Q4_K_M"
     add_test_vision "ggml-org/Qwen2.5-Omni-7B-GGUF:Q4_K_M"
     # add_test_vision "ggml-org/Qwen2.5-VL-32B-Instruct-GGUF:Q4_K_M" # does not work on my mac M3 Ultra
+    add_test_vision "ggml-org/Kimi-VL-A3B-Thinking-2506-GGUF:Q4_K_M"
 
     add_test_audio  "ggml-org/ultravox-v0_5-llama-3_1-8b-GGUF:Q4_K_M"
     add_test_audio  "ggml-org/Qwen2.5-Omni-7B-GGUF:Q4_K_M"