"vision_encoder.patch_conv", # pixtral
"vision_model.patch_embedding.linear", # llama 4
"visual.patch_embed.proj", # qwen2vl
+ "vision_tower.patch_embed.proj", # kimi-vl
),
MODEL_TENSOR.V_ENC_EMBD_POS: (
"vpm.embeddings.position_embedding",
"model.vision_model.embeddings.position_embedding", # SmolVLM
"vision_model.positional_embedding_vlm", # llama 4
+ "vision_tower.patch_embed.pos_emb", # kimi-vl
),
MODEL_TENSOR.V_ENC_ATTN_Q: (
"vision_tower.transformer.layers.{bid}.attention.q_proj", # pixtral-hf
"vision_encoder.transformer.layers.{bid}.attention.wq", # pixtral
"visual.blocks.{bid}.attn.q", # qwen2vl, generated
+ "vision_tower.encoder.blocks.{bid}.wq", # kimi-vl, generated
),
MODEL_TENSOR.V_ENC_ATTN_K: (
"vision_tower.transformer.layers.{bid}.attention.k_proj", # pixtral-hf
"vision_encoder.transformer.layers.{bid}.attention.wk", # pixtral
"visual.blocks.{bid}.attn.k", # qwen2vl, generated
+ "vision_tower.encoder.blocks.{bid}.wk", # kimi-vl, generated
),
MODEL_TENSOR.V_ENC_ATTN_V: (
"vision_tower.transformer.layers.{bid}.attention.v_proj", # pixtral-hf
"vision_encoder.transformer.layers.{bid}.attention.wv", # pixtral
"visual.blocks.{bid}.attn.v", # qwen2vl, generated
+ "vision_tower.encoder.blocks.{bid}.wv", # kimi-vl, generated
),
MODEL_TENSOR.V_ENC_INPUT_NORM: (
"vision_encoder.transformer.layers.{bid}.attention_norm", # pixtral
"vision_model.model.layers.{bid}.input_layernorm", # llama4
"visual.blocks.{bid}.norm1", # qwen2vl
+ "vision_tower.encoder.blocks.{bid}.norm0", # kimi-vl (norm0/norm1)
),
MODEL_TENSOR.V_ENC_ATTN_O: (
"vision_tower.transformer.layers.{bid}.attention.o_proj", # pixtral-hf
"vision_encoder.transformer.layers.{bid}.attention.wo", # pixtral
"visual.blocks.{bid}.attn.proj", # qwen2vl
+ "vision_tower.encoder.blocks.{bid}.wo", # kimi-vl
),
MODEL_TENSOR.V_ENC_POST_ATTN_NORM: (
"vision_tower.transformer.layers.{bid}.ffn_norm", # pixtral-hf
"vision_encoder.transformer.layers.{bid}.ffn_norm", # pixtral
"visual.blocks.{bid}.norm2", # qwen2vl
+ "vision_tower.encoder.blocks.{bid}.norm1", # kimi-vl (norm0/norm1)
),
MODEL_TENSOR.V_ENC_FFN_UP: (
"vision_model.model.layers.{bid}.mlp.fc1", # llama4
"visual.blocks.{bid}.mlp.fc1", # qwen2vl
"visual.blocks.{bid}.mlp.up_proj", # qwen2.5vl
+ "vision_tower.encoder.blocks.{bid}.mlp.fc0", # kimi-vl (fc0/fc1)
),
MODEL_TENSOR.V_ENC_FFN_DOWN: (
"vision_model.model.layers.{bid}.mlp.fc2", # llama4
"visual.blocks.{bid}.mlp.fc2", # qwen2vl
"visual.blocks.{bid}.mlp.down_proj", # qwen2.5vl
+ "vision_tower.encoder.blocks.{bid}.mlp.fc1", # kimi-vl (fc0/fc1)
),
MODEL_TENSOR.V_POST_NORM: (
"model.vision_model.post_layernorm", # SmolVLM
"vision_model.layernorm_post", # llama4
"visual.merger.ln_q", # qwen2vl
+ "vision_tower.encoder.final_layernorm", # kimi-vl
),
MODEL_TENSOR.V_MM_INP_NORM: (
"multi_modal_projector.norm",
"multi_modal_projector.layer_norm",
+ "multi_modal_projector.pre_norm",
"pre_mm_projector_norm",
),
cur);
} else if (ctx->proj_type() == PROJECTOR_TYPE_IDEFICS3) {
+ // pixel_shuffle
// https://github.com/huggingface/transformers/blob/0a950e0bbe1ed58d5401a6b547af19f15f0c195e/src/transformers/models/idefics3/modeling_idefics3.py#L578
-
const int scale_factor = model.hparams.proj_scale_factor;
- const int n_embd = cur->ne[0];
- const int seq = cur->ne[1];
- const int bsz = 1; // batch size, always 1 for now since we don't support batching
- const int height = std::sqrt(seq);
- const int width = std::sqrt(seq);
- GGML_ASSERT(scale_factor != 0);
- cur = ggml_reshape_4d(ctx0, cur, n_embd * scale_factor, width / scale_factor, height, bsz);
- cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
- cur = ggml_cont_4d(ctx0, cur,
- n_embd * scale_factor * scale_factor,
- height / scale_factor,
- width / scale_factor,
- bsz);
- cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
- cur = ggml_cont_3d(ctx0, cur,
- n_embd * scale_factor * scale_factor,
- seq / (scale_factor * scale_factor),
- bsz);
-
+ cur = build_patch_merge_permute(cur, scale_factor);
cur = ggml_mul_mat(ctx0, model.projection, cur);
+
} else if (ctx->proj_type() == PROJECTOR_TYPE_LFM2) {
// pixel unshuffle block
const int scale_factor = model.hparams.proj_scale_factor;
- GGML_ASSERT(scale_factor > 1);
-
- const int n_embd = cur->ne[0];
- int width = img.nx / patch_size;
- int height = img.ny / patch_size;
-
- // pad width and height to factor
- const int64_t pad_width = CLIP_ALIGN(width, scale_factor) - width;
- const int64_t pad_height = CLIP_ALIGN(height, scale_factor) - height;
- cur = ggml_reshape_3d(ctx0, cur, n_embd, width, height);
- if (pad_width || pad_height) {
- cur = ggml_pad(ctx0, cur, 0, pad_width, pad_height, 0);
- width += pad_width;
- height += pad_height;
- }
-
- // unshuffle h
- cur = ggml_reshape_3d(ctx0, cur, n_embd * scale_factor, width / scale_factor, height);
- cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
-
- // unshuffle w
- cur = ggml_cont_3d(ctx0, cur, n_embd * scale_factor * scale_factor, height / scale_factor, width / scale_factor);
- cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
-
- cur = ggml_cont_2d(ctx0, cur, cur->ne[0], cur->ne[1] * cur->ne[2]);
+ cur = build_patch_merge_permute(cur, scale_factor);
// projection
cur = ggml_norm(ctx0, cur, 1e-5); // default nn.LayerNorm
n_patches_x / scale_factor,
n_patches_y / scale_factor,
bsz);
- cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+ //cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
// flatten to 2D
cur = ggml_cont_2d(ctx0, cur,
n_embd * scale_factor * scale_factor,
return gf;
}
+ ggml_cgraph * build_kimivl() {
+ // 2D input positions
+ ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches);
+ ggml_set_name(pos_h, "pos_h");
+ ggml_set_input(pos_h);
+
+ ggml_tensor * pos_w = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches);
+ ggml_set_name(pos_w, "pos_w");
+ ggml_set_input(pos_w);
+
+ ggml_tensor * learned_pos_embd = resize_position_embeddings();
+
+ // build ViT with 2D position embeddings
+ auto add_pos = [&](ggml_tensor * cur, const clip_layer &) {
+ // first half is X axis and second half is Y axis
+ return build_rope_2d(ctx0, cur, pos_w, pos_h, hparams.rope_theta, false);
+ };
+
+ ggml_tensor * inp = build_inp();
+ ggml_tensor * cur = build_vit(
+ inp, n_patches,
+ NORM_TYPE_NORMAL,
+ hparams.ffn_op,
+ learned_pos_embd,
+ add_pos);
+
+ cb(cur, "vit_out", -1);
+
+ {
+ // patch_merger
+ const int scale_factor = model.hparams.proj_scale_factor;
+ cur = build_patch_merge_permute(cur, scale_factor);
+
+ // projection norm
+ int proj_inp_dim = cur->ne[0];
+ cur = ggml_view_2d(ctx0, cur,
+ n_embd, cur->ne[1] * scale_factor * scale_factor,
+ ggml_row_size(cur->type, n_embd), 0);
+ cur = ggml_norm(ctx0, cur, 1e-5); // default nn.LayerNorm
+ cur = ggml_mul(ctx0, cur, model.mm_input_norm_w);
+ cur = ggml_add(ctx0, cur, model.mm_input_norm_b);
+ cur = ggml_view_2d(ctx0, cur,
+ proj_inp_dim, cur->ne[1] / scale_factor / scale_factor,
+ ggml_row_size(cur->type, proj_inp_dim), 0);
+ cb(cur, "proj_inp_normed", -1);
+
+ // projection mlp
+ cur = ggml_mul_mat(ctx0, model.mm_1_w, cur);
+ cur = ggml_add(ctx0, cur, model.mm_1_b);
+ cur = ggml_gelu(ctx0, cur);
+ cur = ggml_mul_mat(ctx0, model.mm_2_w, cur);
+ cur = ggml_add(ctx0, cur, model.mm_2_b);
+ cb(cur, "proj_out", -1);
+ }
+
+ // build the graph
+ ggml_build_forward_expand(gf, cur);
+
+ return gf;
+ }
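+ // shape walk-through of the projector above (illustrative sizes, not model constants):
+ // with n_embd = 1024, a 28x28 patch grid and scale_factor = 2:
+ //   vit_out              : [1024, 784]
+ //   patch_merge_permute  : [4096, 196]   (each 2x2 patch block concatenated along n_embd)
+ //   proj_inp_normed      : [4096, 196]   (LayerNorm over each n_embd-sized slice via the views)
+ //   mm_1 -> GELU -> mm_2 : [n_embd_text, 196]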
+
// this graph is used by llava, granite and glm
// due to having embedding_stack (used by granite), we cannot reuse build_vit
ggml_cgraph * build_llava() {
ggml_tensor * pos_embd = model.position_embeddings;
const int height = img.ny / patch_size;
const int width = img.nx / patch_size;
+ GGML_ASSERT(pos_embd);
+
+ const uint32_t mode = GGML_SCALE_MODE_BILINEAR;
+ const int n_per_side = (int)std::sqrt(pos_embd->ne[1]);
- if (!pos_embd || height * width == pos_embd->ne[1]) {
+ if (height == n_per_side && width == n_per_side) {
return pos_embd;
}
- const int n_pos_embd = std::sqrt(pos_embd->ne[1]);
- pos_embd = ggml_reshape_3d(ctx0, pos_embd, n_embd, n_pos_embd, n_pos_embd); // -> (n_embd, n_pos_embd, n_pos_embd)
- pos_embd = ggml_permute(ctx0, pos_embd, 2, 0, 1, 3); // -> (n_pos_embd, n_pos_embd, n_embd)
- pos_embd = ggml_interpolate(ctx0, pos_embd, width, height, n_embd, 1, 1); // -> (width, height, n_embd)
- pos_embd = ggml_reshape_2d(ctx0, pos_embd, height * width, n_embd); // -> (height * width, n_embd)
- pos_embd = ggml_transpose(ctx0, pos_embd); // -> (n_embd, height * width)
- pos_embd = ggml_cont(ctx0, pos_embd);
+ pos_embd = ggml_reshape_3d(ctx0, pos_embd, n_embd, n_per_side, n_per_side); // -> (n_embd, n_per_side, n_per_side)
+ pos_embd = ggml_permute(ctx0, pos_embd, 2, 0, 1, 3); // -> (n_per_side, n_per_side, n_embd)
+ pos_embd = ggml_interpolate(ctx0, pos_embd, width, height, n_embd, 1, mode); // -> (width, height, n_embd)
+ pos_embd = ggml_permute(ctx0, pos_embd, 1, 2, 0, 3); // -> (n_embd, width, height)
+ pos_embd = ggml_cont_2d(ctx0, pos_embd, n_embd, width * height); // -> (n_embd, width * height)
return pos_embd;
}
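// example of the resize above (illustrative sizes, not from any particular checkpoint):
// a 24x24 grid of learned position embeddings (pos_embd->ne = [n_embd, 576]) and an
// input of 40x28 patches:
//   [n_embd, 576] -> [n_embd, 24, 24] -> [24, 24, n_embd]
//   -> bilinear interpolate to [40, 28, n_embd]
//   -> [n_embd, 40, 28] -> [n_embd, 1120], i.e. one embedding per patch position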
return cur;
}
+ // aka pixel_shuffle / pixel_unshuffle / patch_merger (Kimi-VL)
+ // support dynamic resolution
+ ggml_tensor * build_patch_merge_permute(ggml_tensor * cur, int scale_factor) {
+ GGML_ASSERT(scale_factor > 1);
+
+ const int n_embd = cur->ne[0];
+ int width = img.nx / patch_size;
+ int height = img.ny / patch_size;
+
+ // pad width and height to factor
+ const int64_t pad_width = CLIP_ALIGN(width, scale_factor) - width;
+ const int64_t pad_height = CLIP_ALIGN(height, scale_factor) - height;
+ cur = ggml_reshape_3d(ctx0, cur, n_embd, width, height);
+ if (pad_width || pad_height) {
+ cur = ggml_pad(ctx0, cur, 0, pad_width, pad_height, 0);
+ width += pad_width;
+ height += pad_height;
+ }
+
+ // unshuffle h
+ cur = ggml_reshape_3d(ctx0, cur, n_embd * scale_factor, width / scale_factor, height);
+ cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+
+ // unshuffle w
+ cur = ggml_cont_3d(ctx0, cur, n_embd * scale_factor * scale_factor, height / scale_factor, width / scale_factor);
+ cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+
+ cur = ggml_cont_2d(ctx0, cur, cur->ne[0], cur->ne[1] * cur->ne[2]);
+ cb(cur, "pixel_shuffle", -1);
+
+ return cur;
+ }
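+ // worked example for build_patch_merge_permute (illustrative sizes): scale_factor = 2
+ // and a 15x11 patch grid; CLIP_ALIGN pads the grid to 16x12, then:
+ //   input       : [n_embd,   15*11] -> reshape + pad -> [n_embd, 16, 12]
+ //   unshuffle h : [n_embd*2,  8, 12] -> permute
+ //   unshuffle w : [n_embd*4,  6,  8] -> permute
+ //   output      : [n_embd*4, 48]     one row per 2x2 block of (padded) patches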
+
};
static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch & imgs) {
{
res = graph.build_whisper_enc();
} break;
+ case PROJECTOR_TYPE_KIMIVL:
+ {
+ res = graph.build_kimivl();
+ } break;
default:
{
res = graph.build_llava();
hparams.image_size = 1024;
get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.spatial_merge_size, false);
} break;
+ case PROJECTOR_TYPE_KIMIVL:
+ {
+ hparams.rope_theta = 10000.0f;
+ hparams.warmup_image_size = hparams.patch_size * 8;
+ get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor, false);
+ } break;
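+ // e.g. assuming patch_size = 14 (illustrative), the warmup image above is 112x112,
+ // i.e. an 8x8 patch grid; rope_theta = 10000.0f is the base used by the 2D RoPE
+ // in build_kimivl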
case PROJECTOR_TYPE_GEMMA3:
{
// default value (used by all model sizes in gemma 3 family)
// some models already exported with legacy (incorrect) naming which is quite messy, let's fix it here
// note: Qwen model converted from the old surgery script has n_ff = 0, so we cannot use n_ff to check!
- if (layer.ff_up_w && layer.ff_down_w && layer.ff_down_w->ne[0] == hparams.n_embd) {
+ bool is_ffn_swapped = (
+ // only old models need this fix
+ model.proj_type == PROJECTOR_TYPE_MLP
+ || model.proj_type == PROJECTOR_TYPE_MLP_NORM
+ || model.proj_type == PROJECTOR_TYPE_LDP
+ || model.proj_type == PROJECTOR_TYPE_LDPV2
+ || model.proj_type == PROJECTOR_TYPE_QWEN2VL
+ || model.proj_type == PROJECTOR_TYPE_QWEN25VL
+ || model.proj_type == PROJECTOR_TYPE_GLM_EDGE
+ || model.proj_type == PROJECTOR_TYPE_GEMMA3
+ || model.proj_type == PROJECTOR_TYPE_IDEFICS3
+ || model.proj_type == PROJECTOR_TYPE_MINICPMV
+ ) && layer.ff_up_w && layer.ff_down_w && layer.ff_down_w->ne[0] == hparams.n_embd;
+ if (is_ffn_swapped) {
// swap up and down weights
ggml_tensor * tmp = layer.ff_up_w;
layer.ff_up_w = layer.ff_down_w;
layer.ff_down_w = tmp;
tmp = layer.ff_up_b;
layer.ff_up_b = layer.ff_down_b;
layer.ff_down_b = tmp;
+ if (il == 0) {
+ LOG_WRN("%s: ffn up/down are swapped\n", __func__);
+ }
}
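// e.g. (illustrative sizes) with n_embd = 1152 and n_ff = 4304, a correctly named ff_down_w
// has ne = [4304, 1152]; if the converter wrote the tensors under swapped names, the tensor
// stored as ff_down_w is really the up projection with ne = [1152, 4304], so ne[0] == n_embd,
// which is what the check above detects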
}
model.projection = get_tensor(TN_MM_PROJECTOR);
} break;
case PROJECTOR_TYPE_LFM2:
+ case PROJECTOR_TYPE_KIMIVL:
{
model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM);
model.mm_input_norm_b = get_tensor(TN_MM_INP_NORM_B);
res_imgs->grid_y = inst.grid_size.height;
return true;
- } else if (ctx->proj_type() == PROJECTOR_TYPE_LFM2) {
+ } else if ( ctx->proj_type() == PROJECTOR_TYPE_LFM2
+ || ctx->proj_type() == PROJECTOR_TYPE_KIMIVL
+ ) {
GGML_ASSERT(params.proj_scale_factor);
// smart resize
case PROJECTOR_TYPE_IDEFICS3:
case PROJECTOR_TYPE_INTERNVL:
case PROJECTOR_TYPE_LLAMA4:
- case PROJECTOR_TYPE_LFM2:
{
- // both W and H are divided by proj_scale_factor
+ // both X and Y are downscaled by the scale factor
int scale_factor = ctx->model.hparams.proj_scale_factor;
n_patches /= (scale_factor * scale_factor);
} break;
+ case PROJECTOR_TYPE_LFM2:
+ case PROJECTOR_TYPE_KIMIVL:
+ {
+ // dynamic size
+ int scale_factor = ctx->model.hparams.proj_scale_factor;
+ int out_patch_size = params.patch_size * scale_factor;
+ int x_patch = CLIP_ALIGN(img->nx, out_patch_size) / out_patch_size;
+ int y_patch = CLIP_ALIGN(img->ny, out_patch_size) / out_patch_size;
+ n_patches = x_patch * y_patch;
+ } break;
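+ // e.g. for the LFM2 / Kimi-VL case above, assuming CLIP_ALIGN rounds up to a multiple of its
+ // second argument (as its use in build_patch_merge_permute implies): patch_size = 14 and
+ // scale_factor = 2 give out_patch_size = 28, so a 400x300 image yields
+ // ceil(400/28) * ceil(300/28) = 15 * 11 = 165 output embeddings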
case PROJECTOR_TYPE_PIXTRAL:
{
// dynamic size
set_input_i32("positions", positions);
} break;
case PROJECTOR_TYPE_PIXTRAL:
+ case PROJECTOR_TYPE_KIMIVL:
{
// set the 2D positions
int n_patches_per_col = image_size_width / patch_size;
case PROJECTOR_TYPE_QWEN2A:
return ctx->model.mm_fc_w->ne[1];
case PROJECTOR_TYPE_LFM2:
+ case PROJECTOR_TYPE_KIMIVL:
return ctx->model.mm_2_w->ne[1];
default:
GGML_ABORT("Unknown projector type");