]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
clip : improve projector naming (#13118)
authorXuan-Son Nguyen <redacted>
Sat, 26 Apr 2025 20:39:47 +0000 (22:39 +0200)
committerGitHub <redacted>
Sat, 26 Apr 2025 20:39:47 +0000 (22:39 +0200)
* clip : improve projector naming

* no more kv has_llava_projector

* rm unused kv

* rm more unused

examples/llava/clip-impl.h
examples/llava/clip.cpp
examples/llava/clip.h

index 53ac38130476557a34a63d3f8e611b4a316d9cd3..16d0a8efc56ae0c0cf88ca0bae89c6521c09b90d 100644 (file)
 #define KEY_FTYPE               "general.file_type"
 #define KEY_NAME                "general.name"
 #define KEY_DESCRIPTION         "general.description"
-#define KEY_HAS_TEXT_ENC        "clip.has_text_encoder"
-#define KEY_HAS_VIS_ENC         "clip.has_vision_encoder"
-#define KEY_HAS_LLAVA_PROJ      "clip.has_llava_projector"
-#define KEY_HAS_MINICPMV_PROJ   "clip.has_minicpmv_projector"
-#define KEY_HAS_GLM_PROJ        "clip.has_glm_projector"
 #define KEY_MINICPMV_VERSION    "clip.minicpmv_version"
-#define KEY_HAS_QWEN2VL_MERGER  "clip.has_qwen2vl_merger"
 #define KEY_USE_GELU            "clip.use_gelu"
 #define KEY_USE_SILU            "clip.use_silu"
-#define KEY_N_EMBD              "clip.%s.embedding_length"
-#define KEY_N_FF                "clip.%s.feed_forward_length"
-#define KEY_N_BLOCK             "clip.%s.block_count"
-#define KEY_N_HEAD              "clip.%s.attention.head_count"
-#define KEY_LAYER_NORM_EPS      "clip.%s.attention.layer_norm_epsilon"
-#define KEY_PROJ_DIM            "clip.%s.projection_dim"
-#define KEY_TOKENS              "tokenizer.ggml.tokens"
+#define KEY_N_EMBD              "clip.vision.embedding_length"
+#define KEY_N_FF                "clip.vision.feed_forward_length"
+#define KEY_N_BLOCK             "clip.vision.block_count"
+#define KEY_N_HEAD              "clip.vision.attention.head_count"
+#define KEY_LAYER_NORM_EPS      "clip.vision.attention.layer_norm_epsilon"
+#define KEY_PROJ_DIM            "clip.vision.projection_dim"
 #define KEY_IMAGE_SIZE          "clip.vision.image_size"
 #define KEY_PATCH_SIZE          "clip.vision.patch_size"
 #define KEY_IMAGE_MEAN          "clip.vision.image_mean"
@@ -96,9 +89,9 @@ enum projector_type {
     PROJECTOR_TYPE_MLP_NORM,
     PROJECTOR_TYPE_LDP,
     PROJECTOR_TYPE_LDPV2,
-    PROJECTOR_TYPE_RESAMPLER,
+    PROJECTOR_TYPE_MINICPMV,
     PROJECTOR_TYPE_GLM_EDGE,
-    PROJECTOR_TYPE_MERGER,
+    PROJECTOR_TYPE_QWEN2VL,
     PROJECTOR_TYPE_GEMMA3,
     PROJECTOR_TYPE_IDEFICS3,
     PROJECTOR_TYPE_PIXTRAL,
@@ -109,9 +102,9 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
     { PROJECTOR_TYPE_MLP,       "mlp" },
     { PROJECTOR_TYPE_LDP,       "ldp" },
     { PROJECTOR_TYPE_LDPV2,     "ldpv2"},
-    { PROJECTOR_TYPE_RESAMPLER, "resampler"},
+    { PROJECTOR_TYPE_MINICPMV,  "resampler"},
     { PROJECTOR_TYPE_GLM_EDGE,  "adapter"},
-    { PROJECTOR_TYPE_MERGER,    "qwen2vl_merger"},
+    { PROJECTOR_TYPE_QWEN2VL,   "qwen2vl_merger"},
     { PROJECTOR_TYPE_GEMMA3,    "gemma3"},
     { PROJECTOR_TYPE_IDEFICS3,  "idefics3"},
     { PROJECTOR_TYPE_PIXTRAL,   "pixtral"},
index da8a590f0e5638169078142b264ac340d793af78..e8c01c68a9779e8a6ac85420b8be5a54d6d7246d 100644 (file)
@@ -308,13 +308,8 @@ struct clip_vision_model {
 };
 
 struct clip_ctx {
-    bool has_text_encoder    = false;
-    bool has_vision_encoder  = false;
     bool has_llava_projector = false;
-    bool has_minicpmv_projector = false;
-    bool has_glm_projector = false;
-    bool has_qwen2vl_merger = false;
-    int minicpmv_version = 2;
+    int minicpmv_version = 0;
 
     struct clip_vision_model vision_model;
     projector_type proj_type = PROJECTOR_TYPE_MLP;
@@ -373,23 +368,20 @@ struct clip_ctx {
     }
 };
 
-static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_image_f32_batch & imgs) {
+static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_image_f32 & img) {
     const auto & model = ctx->vision_model;
     const auto & hparams = model.hparams;
 
-    const int image_size = hparams.image_size;
-    int image_size_width  = image_size;
-    int image_size_height = image_size;
+    int image_size_width  = img.nx;
+    int image_size_height = img.ny;
 
-    const int patch_size           = hparams.patch_size;
-    const int num_patches          = ((image_size_width / patch_size) * (image_size_height / patch_size));
-    const int hidden_size          = hparams.hidden_size;
-    const int n_head               = hparams.n_head;
-    const int d_head               = hidden_size / n_head;
-    const int n_layer              = hparams.n_layer;
-    const float eps                = hparams.eps;
-
-    GGML_ASSERT(imgs.entries.size() == 1); // batch_size == 1
+    const int patch_size  = hparams.patch_size;
+    const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size));
+    const int hidden_size = hparams.hidden_size;
+    const int n_head      = hparams.n_head;
+    const int d_head      = hidden_size / n_head;
+    const int n_layer     = hparams.n_layer;
+    const float eps       = hparams.eps;
 
     struct ggml_init_params params = {
         /*.mem_size   =*/ ctx->buf_compute_meta.size(),
@@ -621,15 +613,14 @@ static ggml_tensor * build_rope_2d(
     return cur;
 }
 
-static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_image_f32_batch & imgs) {
+static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_image_f32 & img) {
     const auto & model = ctx->vision_model;
     const auto & hparams = model.hparams;
 
     GGML_ASSERT(ctx->proj_type == PROJECTOR_TYPE_PIXTRAL);
-    GGML_ASSERT(imgs.entries.size() == 1); // batch_size == 1
 
-    int image_size_width  = imgs.entries[0]->nx;
-    int image_size_height = imgs.entries[0]->ny;
+    int image_size_width  = img.nx;
+    int image_size_height = img.ny;
 
     const int patch_size  = hparams.patch_size;
     const int n_patches_x = image_size_width  / patch_size;
@@ -772,18 +763,14 @@ static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_i
 }
 
 static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_image_f32_batch & imgs, struct clip_image_size load_image_size, bool is_inf = false) {
-    if (!ctx->has_vision_encoder) {
-        LOG_ERR("This gguf file seems to have no vision encoder\n");
-        return nullptr;
-    }
-
     const auto & model = ctx->vision_model;
     const auto & hparams = model.hparams;
 
     const int image_size = hparams.image_size;
     int image_size_width  = image_size;
     int image_size_height = image_size;
-    if (ctx->has_minicpmv_projector) {
+
+    if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV) {
         LOG_DBG("%s: %d %d\n", __func__, load_image_size.width, load_image_size.height);
         image_size_width  = load_image_size.width;
         image_size_height = load_image_size.height;
@@ -792,7 +779,8 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
             image_size_height = imgs.entries[0]->ny;
         }
     }
-    else if (ctx->has_qwen2vl_merger) {
+
+    else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) {
         // use the image's native resolution when image is avaible
         if (is_inf) {
         // if (imgs->data->nx && imgs->data->ny) {
@@ -800,12 +788,13 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
             image_size_height = imgs.entries[0]->ny;
         }
     }
+
     const int patch_size           = hparams.patch_size;
     const int num_patches          = ((image_size_width / patch_size) * (image_size_height / patch_size));
     const int patches_w            = image_size_width / patch_size;
     const int patches_h            = image_size_height / patch_size;
     const int num_positions        = num_patches + (model.class_embedding ? 1 : 0);
-    const int num_position_ids     = ctx->has_qwen2vl_merger ? num_positions * 4 : num_positions;
+    const int num_position_ids     = ctx->proj_type == PROJECTOR_TYPE_QWEN2VL ? num_positions * 4 : num_positions;
     const int hidden_size          = hparams.hidden_size;
     const int n_head               = hparams.n_head;
     const int d_head               = hidden_size / n_head;
@@ -814,7 +803,9 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
 
     const int batch_size = imgs.entries.size();
 
-    if (ctx->has_llava_projector || ctx->has_minicpmv_projector || ctx->has_glm_projector) {
+    if (ctx->has_llava_projector
+            || ctx->proj_type == PROJECTOR_TYPE_MINICPMV
+            || ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) {
         GGML_ASSERT(batch_size == 1);
     }
 
@@ -835,8 +826,8 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
 
     struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
 
-    if (ctx->has_qwen2vl_merger) {
-        GGML_ASSERT(image_size_width % (patch_size * 2) == 0);
+    if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) {
+        GGML_ASSERT(image_size_width  % (patch_size * 2) == 0);
         GGML_ASSERT(image_size_height % (patch_size * 2) == 0);
 
         auto inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
@@ -865,29 +856,26 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
     struct ggml_tensor * embeddings = inp;
     struct ggml_tensor * pos_embed = nullptr;
 
-    if (ctx->has_llava_projector) {
-        // concat class_embeddings and patch_embeddings
-        if (model.class_embedding) {
-            embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size);
-            ggml_set_name(embeddings, "embeddings");
-            ggml_set_input(embeddings);
-            embeddings = ggml_acc(ctx0, embeddings, model.class_embedding,
-                    embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], 0);
-            embeddings = ggml_acc(ctx0, embeddings, inp,
-                    embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], model.class_embedding->nb[1]);
-        }
+    // concat class_embeddings and patch_embeddings
+    if (model.class_embedding) {
+        embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size);
+        embeddings = ggml_scale(ctx0, embeddings, 0.0f); // set to all zeros
+        embeddings = ggml_acc(ctx0, embeddings, model.class_embedding,
+                embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], 0);
+        embeddings = ggml_acc(ctx0, embeddings, inp,
+                embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], model.class_embedding->nb[1]);
     }
 
     struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_position_ids);
     ggml_set_name(positions, "positions");
     ggml_set_input(positions);
 
-    if (!ctx->has_qwen2vl_merger) { // qwen2vl use rope position embedding
+    if (ctx->proj_type != PROJECTOR_TYPE_QWEN2VL) { // qwen2vl does NOT use learned position embeddings
         embeddings =
             ggml_add(ctx0, embeddings, ggml_get_rows(ctx0, model.position_embeddings, positions));
     }
 
-    if (ctx->has_minicpmv_projector) {
+    if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV) {
         int pos_w = image_size_width/patch_size;
         int pos_h = image_size_height/patch_size;
         if (ctx->minicpmv_version == 2) {
@@ -941,7 +929,7 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
                 ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].q_w, cur), model.layers[il].q_b);
 
             Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_positions, batch_size);
-            if (ctx->has_qwen2vl_merger) {
+            if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) {
                 Q = ggml_rope_multi(
                     ctx0, Q, positions, nullptr,
                     d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
@@ -953,7 +941,7 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
                 ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].k_w, cur), model.layers[il].k_b);
 
             K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size);
-            if (ctx->has_qwen2vl_merger) {
+            if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) {
                 K = ggml_rope_multi(
                     ctx0, K, positions, nullptr,
                     d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
@@ -1218,106 +1206,98 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
         }
     }
     // minicpmv projector
-    else if (ctx->has_minicpmv_projector)
-    {
-        if (ctx->proj_type == PROJECTOR_TYPE_RESAMPLER) {
-            struct ggml_tensor * q = model.mm_model_query;
-            { // layernorm
-                q = ggml_norm(ctx0, q, eps);
-                q = ggml_add(ctx0, ggml_mul(ctx0, q, model.mm_model_ln_q_w), model.mm_model_ln_q_b);
+    else if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV) {
+        struct ggml_tensor * q = model.mm_model_query;
+        { // layernorm
+            q = ggml_norm(ctx0, q, eps);
+            q = ggml_add(ctx0, ggml_mul(ctx0, q, model.mm_model_ln_q_w), model.mm_model_ln_q_b);
+        }
+        struct ggml_tensor * v = ggml_mul_mat(ctx0, model.mm_model_kv_proj, embeddings);
+        { // layernorm
+            v = ggml_norm(ctx0, v, eps);
+            v = ggml_add(ctx0, ggml_mul(ctx0, v, model.mm_model_ln_kv_w), model.mm_model_ln_kv_b);
+        }
+        struct ggml_tensor * k;
+        { // position
+            // q = ggml_add(ctx0, q, model.mm_model_pos_embed);
+            k = ggml_add(ctx0, v, pos_embed);
+        }
+
+        { // attention
+            int hidden_size = 4096;
+            const int d_head = 128;
+            int n_head = hidden_size/d_head;
+            int num_query = 96;
+            if (ctx->minicpmv_version == 2) {
+                hidden_size = 4096;
+                n_head = hidden_size/d_head;
+                num_query = 96;
             }
-            struct ggml_tensor * v = ggml_mul_mat(ctx0, model.mm_model_kv_proj, embeddings);
-            { // layernorm
-                v = ggml_norm(ctx0, v, eps);
-                v = ggml_add(ctx0, ggml_mul(ctx0, v, model.mm_model_ln_kv_w), model.mm_model_ln_kv_b);
+            else if (ctx->minicpmv_version == 3) {
+                hidden_size = 3584;
+                n_head = hidden_size/d_head;
+                num_query = 64;
             }
-            struct ggml_tensor * k;
-            { // position
-                // q = ggml_add(ctx0, q, model.mm_model_pos_embed);
-                k = ggml_add(ctx0, v, pos_embed);
+            else if (ctx->minicpmv_version == 4) {
+                hidden_size = 3584;
+                n_head = hidden_size/d_head;
+                num_query = 64;
             }
 
-            { // attention
-                int hidden_size = 4096;
-                const int d_head = 128;
-                int n_head = hidden_size/d_head;
-                int num_query = 96;
-                if (ctx->minicpmv_version == 2) {
-                    hidden_size = 4096;
-                    n_head = hidden_size/d_head;
-                    num_query = 96;
-                }
-                else if (ctx->minicpmv_version == 3) {
-                    hidden_size = 3584;
-                    n_head = hidden_size/d_head;
-                    num_query = 64;
-                }
-                else if (ctx->minicpmv_version == 4) {
-                    hidden_size = 3584;
-                    n_head = hidden_size/d_head;
-                    num_query = 64;
-                }
+            struct ggml_tensor * Q = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_q_w, q), model.mm_model_attn_q_b);
+            struct ggml_tensor * K = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_k_w, k), model.mm_model_attn_k_b);
+            struct ggml_tensor * V = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_v_w, v), model.mm_model_attn_v_b);
+            // permute
+            Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_query, batch_size);
+            Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
+            Q = ggml_reshape_3d(ctx0, Q, d_head, num_query, n_head * batch_size);
+            K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size);
+            K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
+            K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size);
+            V = ggml_reshape_4d(ctx0, V, d_head, n_head, num_positions, batch_size);
+            V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
+            V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size);
+            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+            KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f);
+            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
+            KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_query, n_head, batch_size);
+            KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+            KQV = ggml_cont_3d(ctx0, KQV, hidden_size, num_query, batch_size);
 
-                struct ggml_tensor * Q = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_q_w, q), model.mm_model_attn_q_b);
-                struct ggml_tensor * K = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_k_w, k), model.mm_model_attn_k_b);
-                struct ggml_tensor * V = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_v_w, v), model.mm_model_attn_v_b);
-                // permute
-                Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_query, batch_size);
-                Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
-                Q = ggml_reshape_3d(ctx0, Q, d_head, num_query, n_head * batch_size);
-                K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size);
-                K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
-                K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size);
-                V = ggml_reshape_4d(ctx0, V, d_head, n_head, num_positions, batch_size);
-                V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
-                V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size);
-                struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
-                KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f);
-                struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
-                KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_query, n_head, batch_size);
-                KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
-                KQV = ggml_cont_3d(ctx0, KQV, hidden_size, num_query, batch_size);
-
-                embeddings = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_o_w, KQV), model.mm_model_attn_o_b);
-            }
-            { // layernorm
-                embeddings = ggml_norm(ctx0, embeddings, eps);
-                embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_model_ln_post_w), model.mm_model_ln_post_b);
-            }
-            embeddings = ggml_mul_mat(ctx0, model.mm_model_proj, embeddings);
+            embeddings = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_o_w, KQV), model.mm_model_attn_o_b);
         }
-        else {
-            GGML_ASSERT(false);
+        { // layernorm
+            embeddings = ggml_norm(ctx0, embeddings, eps);
+            embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_model_ln_post_w), model.mm_model_ln_post_b);
         }
+        embeddings = ggml_mul_mat(ctx0, model.mm_model_proj, embeddings);
     }
+
     // glm projector
-    else if (ctx->has_glm_projector) {
-        if (ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) {
-            size_t gridsz = (size_t)sqrt(embeddings->ne[1]);
-            embeddings = ggml_cont(ctx0, ggml_permute(ctx0,embeddings,1,0,2,3));
-            embeddings = ggml_reshape_3d(ctx0, embeddings, gridsz, gridsz, embeddings->ne[1]);
-            embeddings = ggml_conv_2d(ctx0, model.mm_model_adapter_conv_w, embeddings, 2, 2, 0, 0, 1, 1);
-            embeddings = ggml_reshape_3d(ctx0, embeddings,embeddings->ne[0]*embeddings->ne[1] , embeddings->ne[2], batch_size);
-            embeddings = ggml_cont(ctx0, ggml_permute(ctx0,embeddings, 1, 0, 2, 3));
-            embeddings = ggml_add(ctx0, embeddings, model.mm_model_adapter_conv_b);
-            //GLU
-            {
-                embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_0_w, embeddings);
-                embeddings = ggml_norm(ctx0, embeddings, eps);
-                embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_model_ln_q_w), model.mm_model_ln_q_b);
-                embeddings = ggml_gelu_inplace(ctx0, embeddings);
-                struct ggml_tensor * x = embeddings;
-                embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, embeddings);
-                x = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w,x);
-                embeddings = ggml_silu_inplace(ctx0, embeddings);
-                embeddings = ggml_mul(ctx0, embeddings,x);
-                embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_3_w, embeddings);
-            }
-        } else {
-            GGML_ABORT("fatal error");
+    else if (ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) {
+        size_t gridsz = (size_t)sqrt(embeddings->ne[1]);
+        embeddings = ggml_cont(ctx0, ggml_permute(ctx0,embeddings,1,0,2,3));
+        embeddings = ggml_reshape_3d(ctx0, embeddings, gridsz, gridsz, embeddings->ne[1]);
+        embeddings = ggml_conv_2d(ctx0, model.mm_model_adapter_conv_w, embeddings, 2, 2, 0, 0, 1, 1);
+        embeddings = ggml_reshape_3d(ctx0, embeddings,embeddings->ne[0]*embeddings->ne[1] , embeddings->ne[2], batch_size);
+        embeddings = ggml_cont(ctx0, ggml_permute(ctx0,embeddings, 1, 0, 2, 3));
+        embeddings = ggml_add(ctx0, embeddings, model.mm_model_adapter_conv_b);
+        // GLU
+        {
+            embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_0_w, embeddings);
+            embeddings = ggml_norm(ctx0, embeddings, eps);
+            embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_model_ln_q_w), model.mm_model_ln_q_b);
+            embeddings = ggml_gelu_inplace(ctx0, embeddings);
+            struct ggml_tensor * x = embeddings;
+            embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, embeddings);
+            x = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w,x);
+            embeddings = ggml_silu_inplace(ctx0, embeddings);
+            embeddings = ggml_mul(ctx0, embeddings,x);
+            embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_3_w, embeddings);
         }
     }
-    else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
+
+    else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) {
         embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size * 4, num_positions / 4, batch_size);
 
         embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
@@ -1343,11 +1323,13 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
         case PROJECTOR_TYPE_GEMMA3:
         case PROJECTOR_TYPE_IDEFICS3:
             {
-                res = clip_image_build_graph_siglip(ctx, imgs);
+                GGML_ASSERT(imgs.entries.size() == 1);
+                res = clip_image_build_graph_siglip(ctx, *imgs.entries[0]);
             } break;
         case PROJECTOR_TYPE_PIXTRAL:
             {
-                res = clip_image_build_graph_pixtral(ctx, imgs);
+                GGML_ASSERT(imgs.entries.size() == 1);
+                res = clip_image_build_graph_pixtral(ctx, *imgs.entries[0]);
             } break;
         default:
             {
@@ -1419,8 +1401,8 @@ struct clip_model_loader {
         auto & hparams = ctx_clip.vision_model.hparams;
 
         // projector type
+        std::string proj_type;
         {
-            std::string proj_type;
             get_string(KEY_PROJ_TYPE, proj_type, false);
             if (!proj_type.empty()) {
                 ctx_clip.proj_type = clip_projector_type_from_string(proj_type);
@@ -1432,33 +1414,27 @@ struct clip_model_loader {
 
         // other hparams
         {
-            get_bool(KEY_HAS_TEXT_ENC, ctx_clip.has_text_encoder, false);
-            get_bool(KEY_HAS_VIS_ENC, ctx_clip.has_vision_encoder, false);
-            GGML_ASSERT(ctx_clip.has_vision_encoder);
-            GGML_ASSERT(!ctx_clip.has_text_encoder);
-
-            // legacy keys, use KEY_PROJ_TYPE instead
-            get_bool(KEY_HAS_LLAVA_PROJ, ctx_clip.has_llava_projector, false);
-            get_bool(KEY_HAS_MINICPMV_PROJ, ctx_clip.has_minicpmv_projector, false);
             get_i32(KEY_MINICPMV_VERSION, ctx_clip.minicpmv_version, false);
-            get_bool(KEY_HAS_GLM_PROJ, ctx_clip.has_glm_projector, false);
-            get_bool(KEY_HAS_QWEN2VL_MERGER, ctx_clip.has_qwen2vl_merger, false);
-            // !!! do NOT extend the list above, use KEY_PROJ_TYPE instead
 
             get_bool(KEY_USE_GELU, ctx_clip.use_gelu, false);
             get_bool(KEY_USE_SILU, ctx_clip.use_silu, false);
 
-            get_u32(string_format(KEY_N_EMBD,         "vision"), hparams.hidden_size);
-            get_u32(string_format(KEY_N_HEAD,         "vision"), hparams.n_head);
-            get_u32(string_format(KEY_N_FF,           "vision"), hparams.n_intermediate);
-            get_u32(string_format(KEY_N_BLOCK,        "vision"), hparams.n_layer);
-            get_u32(string_format(KEY_PROJ_DIM,       "vision"), hparams.projection_dim);
-            get_f32(string_format(KEY_LAYER_NORM_EPS, "vision"), hparams.eps);
-            get_u32(KEY_IMAGE_SIZE, hparams.image_size);
-            get_u32(KEY_PATCH_SIZE, hparams.patch_size);
-            get_u32(KEY_IMAGE_CROP_RESOLUTION, hparams.image_crop_resolution, false);
+            get_u32(KEY_N_EMBD,         hparams.hidden_size);
+            get_u32(KEY_N_HEAD,         hparams.n_head);
+            get_u32(KEY_N_FF,           hparams.n_intermediate);
+            get_u32(KEY_N_BLOCK,        hparams.n_layer);
+            get_u32(KEY_PROJ_DIM,       hparams.projection_dim);
+            get_f32(KEY_LAYER_NORM_EPS, hparams.eps);
+            get_u32(KEY_IMAGE_SIZE,     hparams.image_size);
+            get_u32(KEY_PATCH_SIZE,     hparams.patch_size);
+            get_u32(KEY_IMAGE_CROP_RESOLUTION,    hparams.image_crop_resolution, false);
             get_arr_int(KEY_IMAGE_GRID_PINPOINTS, hparams.image_grid_pinpoints, false);
 
+            ctx_clip.has_llava_projector = ctx_clip.proj_type == PROJECTOR_TYPE_MLP
+                                        || ctx_clip.proj_type == PROJECTOR_TYPE_MLP_NORM
+                                        || ctx_clip.proj_type == PROJECTOR_TYPE_LDP
+                                        || ctx_clip.proj_type == PROJECTOR_TYPE_LDPV2;
+
             {
                 std::string mm_patch_merge_type;
                 get_string(KEY_MM_PATCH_MERGE_TYPE, mm_patch_merge_type, false);
@@ -1491,32 +1467,56 @@ struct clip_model_loader {
             for (auto & layer : vision_feature_layer) {
                 hparams.vision_feature_layer.insert(layer);
             }
+
             // Calculate the deepest feature layer based on hparams and projector type
-            ctx_clip.max_feature_layer = get_deepest_feature_layer(&ctx_clip);
+            // NOTE: This is only used by build_graph_legacy()
+            {
+                // Get the index of the second to last layer; this is the default for models that have a llava projector
+                int n_layer = hparams.n_layer - 1;
+                int deepest_feature_layer = -1;
+
+                if (ctx_clip.proj_type == PROJECTOR_TYPE_MINICPMV
+                        || ctx_clip.proj_type == PROJECTOR_TYPE_GLM_EDGE
+                        || ctx_clip.proj_type == PROJECTOR_TYPE_QWEN2VL) {
+                    n_layer += 1;
+                }
+
+                // If we set explicit vision feature layers, only go up to the deepest one
+                // NOTE: only used by granite-vision models for now
+                for (const auto & feature_layer : hparams.vision_feature_layer) {
+                    if (feature_layer > deepest_feature_layer) {
+                        deepest_feature_layer = feature_layer;
+                    }
+                }
+                ctx_clip.max_feature_layer = deepest_feature_layer < 0 ? n_layer : deepest_feature_layer;
+            }
+
+            // model-specific params
+            switch (ctx_clip.proj_type) {
+                case PROJECTOR_TYPE_MINICPMV:
+                    {
+                        if (ctx_clip.minicpmv_version == 0) {
+                            ctx_clip.minicpmv_version = 2; // default to 2 if not set
+                        }
+                    } break;
+                case PROJECTOR_TYPE_IDEFICS3:
+                    {
+                        get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor, false);
+                    } break;
+                case PROJECTOR_TYPE_PIXTRAL:
+                    {
+                        hparams.rope_theta = 10000.0f;
+                    } break;
+                default:
+                    break;
+            }
 
-            LOG_INF("%s: text_encoder:       %d\n", __func__, ctx_clip.has_text_encoder);
-            LOG_INF("%s: vision_encoder:     %d\n", __func__, ctx_clip.has_vision_encoder);
-            LOG_INF("%s: llava_projector:    %d\n", __func__, ctx_clip.has_llava_projector);
-            LOG_INF("%s: minicpmv_projector: %d\n", __func__, ctx_clip.has_minicpmv_projector);
+            LOG_INF("%s: projector:          %s\n", __func__, proj_type.c_str());
+            LOG_INF("%s: has_llava_proj:     %d\n", __func__, ctx_clip.has_llava_projector);
             LOG_INF("%s: minicpmv_version:   %d\n", __func__, ctx_clip.minicpmv_version);
-            LOG_INF("%s: glm_projector:      %d\n", __func__, ctx_clip.has_glm_projector);
             LOG_INF("%s: model size:         %.2f MiB\n", __func__, model_size / 1024.0 / 1024.0);
             LOG_INF("%s: metadata size:      %.2f MiB\n", __func__, ggml_get_mem_size(ctx_meta.get()) / 1024.0 / 1024.0);
         }
-
-        // model-specific params
-        switch (ctx_clip.proj_type) {
-            case PROJECTOR_TYPE_IDEFICS3:
-                {
-                    get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor, false);
-                } break;
-            case PROJECTOR_TYPE_PIXTRAL:
-                {
-                    hparams.rope_theta = 10000.0f;
-                } break;
-            default:
-                break;
-        }
     }
 
     void load_tensors() {
@@ -1569,9 +1569,6 @@ struct clip_model_loader {
         vision_model.patch_bias = get_tensor(TN_PATCH_BIAS, false);
         vision_model.patch_embeddings_0 = get_tensor(TN_PATCH_EMBD,   false);
         vision_model.patch_embeddings_1 = get_tensor(TN_PATCH_EMBD_1, false);
-        if (vision_model.patch_embeddings_1 == nullptr) {
-            ctx_clip.has_qwen2vl_merger = false;
-        }
 
         vision_model.position_embeddings = get_tensor(string_format(TN_POS_EMBD, "v"), false);
 
@@ -1669,7 +1666,7 @@ struct clip_model_loader {
                     vision_model.mm_model_peg_0_w = get_tensor(string_format(TN_MVLM_PROJ_PEG, 0, "weight"));
                     vision_model.mm_model_peg_0_b = get_tensor(string_format(TN_MVLM_PROJ_PEG, 0, "bias"));
                 } break;
-            case PROJECTOR_TYPE_RESAMPLER:
+            case PROJECTOR_TYPE_MINICPMV:
                 {
                     // vision_model.mm_model_pos_embed = get_tensor(new_clip->ctx_data, TN_MINICPMV_POS_EMBD);
                     vision_model.mm_model_pos_embed_k = get_tensor(TN_MINICPMV_POS_EMBD_K);
@@ -1702,7 +1699,7 @@ struct clip_model_loader {
                     vision_model.mm_model_mlp_2_w = get_tensor(string_format(TN_GLM_ADAPTER_GATE,"weight"));
                     vision_model.mm_model_mlp_3_w = get_tensor(string_format(TN_GLM_ADAPTER_D_4H_2_H,"weight"));
                 } break;
-            case PROJECTOR_TYPE_MERGER:
+            case PROJECTOR_TYPE_QWEN2VL:
                 {
                     vision_model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight"));
                     vision_model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias"));
@@ -2479,11 +2476,6 @@ int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip) {
 // returns the normalized float tensor for llava-1.5, for spatial_unpad with anyres processing for llava-1.6 it returns the normalized image patch tensors as a vector
 // res_imgs memory is being allocated here, previous allocations will be freed if found
 bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, struct clip_image_f32_batch * res_imgs) {
-    if (!ctx->has_vision_encoder) {
-        LOG_ERR("%s: This gguf file seems to have no vision encoder\n", __func__);
-        return false;
-    }
-
     clip_image_size original_size{img->nx, img->ny};
     bool pad_to_square = true;
     auto & params = ctx->vision_model.hparams;
@@ -2504,7 +2496,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
         }
         return true;
     }
-    else if (ctx->has_qwen2vl_merger) {
+    else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) {
         clip_image_u8 resized;
         auto patch_size = clip_get_patch_size(ctx) * 2;
         int nx = ceil((float)img->nx / patch_size) * patch_size;
@@ -2518,7 +2510,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
         res_imgs->entries.push_back(std::move(img_f32));
         return true;
     }
-    else if (ctx->has_glm_projector
+    else if (ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE
             || ctx->proj_type == PROJECTOR_TYPE_GEMMA3
             || ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
         clip_image_u8 resized_image;
@@ -2646,7 +2638,7 @@ int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * i
 
     if (ctx->proj_type == PROJECTOR_TYPE_LDP || ctx->proj_type == PROJECTOR_TYPE_LDPV2 || ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) {
         n_patches /= 4;
-    } else if (ctx->proj_type == PROJECTOR_TYPE_RESAMPLER) {
+    } else if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV) {
         if (ctx->minicpmv_version == 2) {
             n_patches = 96;
         }
@@ -2656,7 +2648,10 @@ int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * i
         else if (ctx->minicpmv_version == 4) {
             n_patches = 64;
         }
-    } else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
+        else {
+            GGML_ABORT("Unknown minicpmv version");
+        }
+    } else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) {
         int patch_size = params.patch_size * 2;
         int x_patch = img->nx / patch_size + (int)(img->nx % patch_size > 0);
         int y_patch = img->ny / patch_size + (int)(img->ny % patch_size > 0);
@@ -2761,11 +2756,6 @@ static std::vector<std::vector<float>> get_2d_sincos_pos_embed(int embed_dim, co
 }
 
 bool clip_image_encode(struct clip_ctx * ctx, const int n_threads, clip_image_f32 * img, float * vec) {
-    if (!ctx->has_vision_encoder) {
-        LOG_ERR("%s: This gguf file seems to have no vision encoder\n", __func__);
-        return false;
-    }
-
     clip_image_f32_batch imgs;
     clip_image_f32_ptr img_copy(clip_image_f32_init());
     *img_copy = *img;
@@ -2776,20 +2766,11 @@ bool clip_image_encode(struct clip_ctx * ctx, const int n_threads, clip_image_f3
 
 bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_image_f32_batch * imgs_c_ptr, float * vec) {
     const clip_image_f32_batch & imgs = *imgs_c_ptr;
-
-    if (!ctx->has_vision_encoder) {
-        LOG_ERR("%s: This gguf file seems to have no vision encoder\n", __func__);
-        return false;
-    }
-
     int batch_size = imgs.entries.size();
-    if (ctx->has_llava_projector) {
-        GGML_ASSERT(batch_size == 1); // TODO: support multiple images
-    }
-    if (ctx->has_minicpmv_projector) {
-        GGML_ASSERT(batch_size == 1);
-    }
-    if (ctx->has_glm_projector) {
+
+    if (ctx->has_llava_projector
+            || ctx->proj_type == PROJECTOR_TYPE_MINICPMV
+            || ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) {
         GGML_ASSERT(batch_size == 1);
     }
 
@@ -2799,21 +2780,12 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
     ggml_backend_sched_alloc_graph(ctx->sched.get(), gf);
 
     // set inputs
-    const auto & model = ctx->vision_model;
+    const auto & model   = ctx->vision_model;
     const auto & hparams = model.hparams;
 
-    // TODO @ngxson : this is ugly, need to refactor later
-    bool support_dynamic_size = ctx->has_minicpmv_projector
-        || ctx->has_qwen2vl_merger
-        || ctx->proj_type == PROJECTOR_TYPE_PIXTRAL;
+    const int image_size_width  = imgs.entries[0]->nx;
+    const int image_size_height = imgs.entries[0]->ny;
 
-    const int image_size = hparams.image_size;
-    int image_size_width  = image_size;
-    int image_size_height = image_size;
-    if (support_dynamic_size) {
-        image_size_width  = imgs.entries[0]->nx;
-        image_size_height = imgs.entries[0]->ny;
-    }
     const int patch_size    = hparams.patch_size;
     const int num_patches   = ((image_size_width / patch_size) * (image_size_height / patch_size));
     const int num_positions = num_patches + (model.class_embedding ? 1 : 0);
@@ -2839,14 +2811,6 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
         for (size_t i = 0; i < imgs.entries.size(); i++) {
             const int nx = imgs.entries[i]->nx;
             const int ny = imgs.entries[i]->ny;
-
-            if (ctx->has_glm_projector
-                    || ctx->has_llava_projector
-                    || ctx->proj_type == PROJECTOR_TYPE_GEMMA3
-                    || ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
-                GGML_ASSERT(nx == image_size && ny == image_size);
-            }
-
             const int n = nx * ny;
 
             for (int b = 0; b < batch_size; b++) {
@@ -2864,13 +2828,15 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
         }
         ggml_backend_tensor_set(inp_raw, data, 0, ggml_nbytes(inp_raw));
     }
-    if (ctx->has_minicpmv_projector) {
+
+    if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV) {
         {
             // inspired from siglip:
             //    -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit
             //    -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit/blob/d66538faeba44480d0bfaa42145eef26f9423199/modeling_siglip.py#L316
             struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
-            int* positions_data = (int*)malloc(ggml_nbytes(positions));
+            std::vector<int> pos_data(ggml_nelements(positions));
+            int * data = pos_data.data();
             int bucket_coords_h[1024];
             int bucket_coords_w[1024];
             for (int i = 0; i < pos_h; i++){
@@ -2881,11 +2847,10 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
             }
             for (int i = 0, id = 0; i < pos_h; i++){
                 for (int j = 0; j < pos_w; j++){
-                    positions_data[id++] = bucket_coords_h[i]*70 + bucket_coords_w[j];
+                    data[id++] = bucket_coords_h[i]*70 + bucket_coords_w[j];
                 }
             }
-            ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
-            free(positions_data);
+            ggml_backend_tensor_set(positions, data, 0, ggml_nbytes(positions));
         }
 
         {
@@ -2903,30 +2868,28 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
             else if (ctx->minicpmv_version == 4) {
                 embed_dim = 3584;
             }
+            else {
+                GGML_ABORT("Unknown minicpmv version");
+            }
+
+            // TODO @ngxson : this is very inefficient, can we do this using ggml_sin and ggml_cos?
             auto pos_embed_t = get_2d_sincos_pos_embed(embed_dim, std::make_pair(pos_w, pos_h));
 
-            float * pos_embed_data = (float *)malloc(ggml_nbytes(pos_embed));
-            for(int i=0;i < pos_w * pos_h; ++i){
-                for(int j=0; j < embed_dim; ++j){
-                    pos_embed_data[i * embed_dim + j] = pos_embed_t[i][j];
+            std::vector<float> pos_data(ggml_nelements(pos_embed));
+            float * data = pos_data.data();
+            for(int i = 0; i < pos_w * pos_h; ++i){
+                for(int j = 0; j < embed_dim; ++j){
+                    data[i * embed_dim + j] = pos_embed_t[i][j];
                 }
             }
 
-            ggml_backend_tensor_set(pos_embed, pos_embed_data, 0, ggml_nbytes(pos_embed));
-            free(pos_embed_data);
+            ggml_backend_tensor_set(pos_embed, data, 0, ggml_nbytes(pos_embed));
         }
     }
     else {
-        if (model.class_embedding) {
-            struct ggml_tensor * embeddings = ggml_graph_get_tensor(gf, "embeddings");
+        // non-minicpmv models
 
-            void* zero_mem = malloc(ggml_nbytes(embeddings));
-            memset(zero_mem, 0, ggml_nbytes(embeddings));
-            ggml_backend_tensor_set(embeddings, zero_mem, 0, ggml_nbytes(embeddings));
-            free(zero_mem);
-        }
-
-        if (ctx->has_qwen2vl_merger) {
+        if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) {
             struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
 
             const int pw = image_size_width / patch_size;
@@ -2978,6 +2941,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
             ggml_backend_tensor_set(pos, pos_data.data(), 0, ggml_nbytes(pos));
         }
         else {
+            // llava and other models
             struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
 
             int* positions_data = (int*)malloc(ggml_nbytes(positions));
@@ -2987,7 +2951,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
             ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
             free(positions_data);
 
-            if (!ctx->has_glm_projector) {
+            if (ctx->proj_type != PROJECTOR_TYPE_GLM_EDGE) {
                 struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "patches");
                 // The patches vector is used to get rows to index into the embeds with;
                 // we should skip dim 0 only if we have CLS to avoid going out of bounds
@@ -3166,7 +3130,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
             return ctx->vision_model.mm_2_b->ne[0];
         case PROJECTOR_TYPE_MLP_NORM:
             return ctx->vision_model.mm_3_b->ne[0];
-        case PROJECTOR_TYPE_RESAMPLER:
+        case PROJECTOR_TYPE_MINICPMV:
             if (ctx->minicpmv_version == 2) {
                 return 4096;
             } else if (ctx->minicpmv_version == 3) {
@@ -3174,36 +3138,33 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
             } else if (ctx->minicpmv_version == 4) {
                 return 3584;
             }
-            break; // Should not happen if version is valid
+            GGML_ABORT("Unknown minicpmv version");
         case PROJECTOR_TYPE_GLM_EDGE:
             return ctx->vision_model.mm_model_mlp_3_w->ne[1];
-        case PROJECTOR_TYPE_MERGER:
+        case PROJECTOR_TYPE_QWEN2VL:
             return ctx->vision_model.mm_1_b->ne[0];
         case PROJECTOR_TYPE_GEMMA3:
             return ctx->vision_model.mm_input_proj_w->ne[0];
         case PROJECTOR_TYPE_IDEFICS3:
             return ctx->vision_model.projection->ne[1];
         default:
-            break; // Fall through to throw
+            GGML_ABORT("Unknown projector type");
     }
-
-    std::string proj_type = PROJECTOR_TYPE_NAMES[ctx->proj_type];
-    throw std::runtime_error(string_format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str()));
 }
 
 int clip_is_minicpmv(const struct clip_ctx * ctx) {
-    if (ctx->has_minicpmv_projector) {
+    if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV) {
         return ctx->minicpmv_version;
     }
     return 0;
 }
 
 bool clip_is_glm(const struct clip_ctx * ctx) {
-    return ctx->has_glm_projector;
+    return ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE;
 }
 
 bool clip_is_qwen2vl(const struct clip_ctx * ctx) {
-    return ctx->has_qwen2vl_merger;
+    return ctx->proj_type == PROJECTOR_TYPE_QWEN2VL;
 }
 
 bool clip_is_llava(const struct clip_ctx * ctx) {
@@ -3214,29 +3175,6 @@ bool clip_is_gemma3(const struct clip_ctx * ctx) {
     return ctx->proj_type == PROJECTOR_TYPE_GEMMA3;
 }
 
-// Determine the number of encoder layers to iterate over
-int get_deepest_feature_layer(const struct clip_ctx * ctx) {
-    // Get the index of the second to last layer; this is the
-    // default for models that have a llava projector
-    const auto & hparams = ctx->vision_model.hparams;
-    int n_layer = hparams.n_layer - 1;
-    int deepest_feature_layer = -1;
-
-    // Handle other projectors; incrementing here indicates that we
-    // should use the last encoder layer for the vision features.
-    if (ctx->has_minicpmv_projector || ctx->has_glm_projector || ctx->has_qwen2vl_merger) {
-        n_layer += 1;
-    }
-
-    // If we set explicit vision feature layers, only go up to the deepest one
-    for (const auto & feature_layer : hparams.vision_feature_layer) {
-        if (feature_layer > deepest_feature_layer) {
-            deepest_feature_layer = feature_layer;
-        }
-    }
-    return deepest_feature_layer < 0 ? n_layer : deepest_feature_layer;
-}
-
 bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec) {
     clip_image_f32 clip_img;
     clip_img.buf.resize(h * w * 3);
index 5fc45d3e2390486f35085d0228b8c55bf647185d..6ba42ad8921468843f21e30b6359e95550cb2ce6 100644 (file)
@@ -114,8 +114,6 @@ CLIP_API bool clip_is_qwen2vl(const struct clip_ctx * ctx);
 CLIP_API bool clip_is_llava(const struct clip_ctx * ctx);
 CLIP_API bool clip_is_gemma3(const struct clip_ctx * ctx);
 
-CLIP_API int get_deepest_feature_layer(const struct clip_ctx * ctx);
-
 CLIP_API bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec);