]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
clip : refactor graph builder (#13321)
authorXuan-Son Nguyen <redacted>
Tue, 6 May 2025 20:40:24 +0000 (22:40 +0200)
committerGitHub <redacted>
Tue, 6 May 2025 20:40:24 +0000 (22:40 +0200)
* mtmd : refactor graph builder

* fix qwen2vl

* clean up siglip cgraph

* pixtral migrated

* move minicpmv to a dedicated build function

* move max_feature_layer to build_llava

* use build_attn for minicpm resampler

* fix windows build

* add comment for batch_size

* also support tinygemma3 test model

* qwen2vl does not use RMS norm

* fix qwen2vl norm (2)

convert_hf_to_gguf.py
tools/mtmd/clip.cpp

index de6d55cb082c1eb1c5afac3fab4f0f124736a9ef..a6aaf883464b2198e454b835c55be46820891f3c 100755 (executable)
@@ -3915,6 +3915,16 @@ class Gemma3VisionModel(VisionModel):
         # default values below are taken from HF tranformers code
         self.gguf_writer.add_vision_attention_layernorm_eps(hparams.get("layer_norm_eps", 1e-6))
         self.gguf_writer.add_vision_use_gelu(True)
+        # calculate proj_scale_factor (used by tinygemma3 test model)
+        image_seq_length = self.preprocessor_config.get("image_seq_length", 256)
+        n_per_side = int(image_seq_length ** 0.5)
+        image_size = self.hparams["image_size"]
+        patch_size = self.hparams["patch_size"]
+        proj_scale_factor = (image_size // patch_size) // n_per_side
+        if proj_scale_factor > 0 and proj_scale_factor != 4:
+            # we only need to write this if it's not the default value
+            # in this case, we are converting a test model
+            self.gguf_writer.add_vision_projector_scale_factor(proj_scale_factor)
 
     def tensor_force_quant(self, name, new_name, bid, n_dims):
         del bid, new_name, n_dims  # unused
@@ -3928,6 +3938,9 @@ class Gemma3VisionModel(VisionModel):
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         del bid  # unused
 
+        if "vision_model.head." in name:
+            return [] # skip redundant tensors for tinygemma3
+
         if name.startswith("multi_modal_projector.") or name.startswith("vision_tower.") \
                 or name.startswith("multimodal_projector.") or name.startswith("vision_model."):
             # process vision tensors
index 1414f6ba4f5ef62a15a5cfc4963ca5530c5a20ad..4432fb7193d7de349b373888e33b15722a11a6e1 100644 (file)
 #include <limits>
 #include <array>
 #include <numeric>
+#include <functional>
 
 struct clip_logger_state g_logger_state = {GGML_LOG_LEVEL_CONT, clip_log_callback_default, NULL};
 
+enum ffn_op_type {
+    FFN_GELU,
+    FFN_SILU,
+    FFN_GELU_QUICK,
+};
+
+enum norm_type {
+    NORM_TYPE_NORMAL,
+    NORM_TYPE_RMS,
+};
+
 //#define CLIP_DEBUG_FUNCTIONS
 
 #ifdef CLIP_DEBUG_FUNCTIONS
@@ -162,6 +174,8 @@ struct clip_hparams {
     int32_t n_layer;
     int32_t proj_scale_factor = 0; // idefics3
 
+    ffn_op_type ffn_op = FFN_GELU;
+
     patch_merge_type mm_patch_merge_type = PATCH_MERGE_FLAT;
 
     float eps = 1e-6;
@@ -177,136 +191,136 @@ struct clip_hparams {
 
 struct clip_layer {
     // attention
-    struct ggml_tensor * k_w = nullptr;
-    struct ggml_tensor * k_b = nullptr;
-    struct ggml_tensor * q_w = nullptr;
-    struct ggml_tensor * q_b = nullptr;
-    struct ggml_tensor * v_w = nullptr;
-    struct ggml_tensor * v_b = nullptr;
+    ggml_tensor * k_w = nullptr;
+    ggml_tensor * k_b = nullptr;
+    ggml_tensor * q_w = nullptr;
+    ggml_tensor * q_b = nullptr;
+    ggml_tensor * v_w = nullptr;
+    ggml_tensor * v_b = nullptr;
 
-    struct ggml_tensor * o_w = nullptr;
-    struct ggml_tensor * o_b = nullptr;
+    ggml_tensor * o_w = nullptr;
+    ggml_tensor * o_b = nullptr;
 
     // layernorm 1
-    struct ggml_tensor * ln_1_w = nullptr;
-    struct ggml_tensor * ln_1_b = nullptr;
+    ggml_tensor * ln_1_w = nullptr;
+    ggml_tensor * ln_1_b = nullptr;
 
-    struct ggml_tensor * ff_up_w = nullptr;
-    struct ggml_tensor * ff_up_b = nullptr;
-    struct ggml_tensor * ff_gate_w = nullptr;
-    struct ggml_tensor * ff_gate_b = nullptr;
-    struct ggml_tensor * ff_down_w = nullptr;
-    struct ggml_tensor * ff_down_b = nullptr;
+    ggml_tensor * ff_up_w = nullptr;
+    ggml_tensor * ff_up_b = nullptr;
+    ggml_tensor * ff_gate_w = nullptr;
+    ggml_tensor * ff_gate_b = nullptr;
+    ggml_tensor * ff_down_w = nullptr;
+    ggml_tensor * ff_down_b = nullptr;
 
     // layernorm 2
-    struct ggml_tensor * ln_2_w = nullptr;
-    struct ggml_tensor * ln_2_b = nullptr;
+    ggml_tensor * ln_2_w = nullptr;
+    ggml_tensor * ln_2_b = nullptr;
 };
 
 struct clip_vision_model {
     struct clip_hparams hparams;
 
     // embeddings
-    struct ggml_tensor * class_embedding = nullptr;
-    struct ggml_tensor * patch_embeddings_0 = nullptr;
-    struct ggml_tensor * patch_embeddings_1 = nullptr;  // second Conv2D kernel when we decouple Conv3D along temproal dimension (Qwen2VL)
-    struct ggml_tensor * patch_bias = nullptr;
-    struct ggml_tensor * position_embeddings = nullptr;
+    ggml_tensor * class_embedding = nullptr;
+    ggml_tensor * patch_embeddings_0 = nullptr;
+    ggml_tensor * patch_embeddings_1 = nullptr;  // second Conv2D kernel when we decouple Conv3D along temproal dimension (Qwen2VL)
+    ggml_tensor * patch_bias = nullptr;
+    ggml_tensor * position_embeddings = nullptr;
 
-    struct ggml_tensor * pre_ln_w = nullptr;
-    struct ggml_tensor * pre_ln_b = nullptr;
+    ggml_tensor * pre_ln_w = nullptr;
+    ggml_tensor * pre_ln_b = nullptr;
 
     std::vector<clip_layer> layers;
 
-    struct ggml_tensor * post_ln_w;
-    struct ggml_tensor * post_ln_b;
+    ggml_tensor * post_ln_w;
+    ggml_tensor * post_ln_b;
 
-    struct ggml_tensor * projection;
+    ggml_tensor * projection;
 
     // LLaVA projection
-    struct ggml_tensor * mm_input_norm_w = nullptr;
-    struct ggml_tensor * mm_0_w = nullptr;
-    struct ggml_tensor * mm_0_b = nullptr;
-    struct ggml_tensor * mm_2_w = nullptr;
-    struct ggml_tensor * mm_2_b = nullptr;
+    ggml_tensor * mm_input_norm_w = nullptr;
+    ggml_tensor * mm_0_w = nullptr;
+    ggml_tensor * mm_0_b = nullptr;
+    ggml_tensor * mm_2_w = nullptr;
+    ggml_tensor * mm_2_b = nullptr;
 
-    struct ggml_tensor * image_newline = nullptr;
+    ggml_tensor * image_newline = nullptr;
 
     // Yi type models with mlp+normalization projection
-    struct ggml_tensor * mm_1_w = nullptr; // Yi type models have 0, 1, 3, 4
-    struct ggml_tensor * mm_1_b = nullptr;
-    struct ggml_tensor * mm_3_w = nullptr;
-    struct ggml_tensor * mm_3_b = nullptr;
-    struct ggml_tensor * mm_4_w = nullptr;
-    struct ggml_tensor * mm_4_b = nullptr;
+    ggml_tensor * mm_1_w = nullptr; // Yi type models have 0, 1, 3, 4
+    ggml_tensor * mm_1_b = nullptr;
+    ggml_tensor * mm_3_w = nullptr;
+    ggml_tensor * mm_3_b = nullptr;
+    ggml_tensor * mm_4_w = nullptr;
+    ggml_tensor * mm_4_b = nullptr;
 
     // GLMV-Edge projection
-    struct ggml_tensor * mm_model_adapter_conv_w = nullptr;
-    struct ggml_tensor * mm_model_adapter_conv_b = nullptr;
-    struct ggml_tensor * mm_glm_tok_boi = nullptr;
-    struct ggml_tensor * mm_glm_tok_eoi = nullptr;
+    ggml_tensor * mm_model_adapter_conv_w = nullptr;
+    ggml_tensor * mm_model_adapter_conv_b = nullptr;
+    ggml_tensor * mm_glm_tok_boi = nullptr;
+    ggml_tensor * mm_glm_tok_eoi = nullptr;
 
     // MobileVLM projection
-    struct ggml_tensor * mm_model_mlp_1_w = nullptr;
-    struct ggml_tensor * mm_model_mlp_1_b = nullptr;
-    struct ggml_tensor * mm_model_mlp_3_w = nullptr;
-    struct ggml_tensor * mm_model_mlp_3_b = nullptr;
-    struct ggml_tensor * mm_model_block_1_block_0_0_w = nullptr;
-    struct ggml_tensor * mm_model_block_1_block_0_1_w = nullptr;
-    struct ggml_tensor * mm_model_block_1_block_0_1_b = nullptr;
-    struct ggml_tensor * mm_model_block_1_block_1_fc1_w = nullptr;
-    struct ggml_tensor * mm_model_block_1_block_1_fc1_b = nullptr;
-    struct ggml_tensor * mm_model_block_1_block_1_fc2_w = nullptr;
-    struct ggml_tensor * mm_model_block_1_block_1_fc2_b = nullptr;
-    struct ggml_tensor * mm_model_block_1_block_2_0_w = nullptr;
-    struct ggml_tensor * mm_model_block_1_block_2_1_w = nullptr;
-    struct ggml_tensor * mm_model_block_1_block_2_1_b = nullptr;
-    struct ggml_tensor * mm_model_block_2_block_0_0_w = nullptr;
-    struct ggml_tensor * mm_model_block_2_block_0_1_w = nullptr;
-    struct ggml_tensor * mm_model_block_2_block_0_1_b = nullptr;
-    struct ggml_tensor * mm_model_block_2_block_1_fc1_w = nullptr;
-    struct ggml_tensor * mm_model_block_2_block_1_fc1_b = nullptr;
-    struct ggml_tensor * mm_model_block_2_block_1_fc2_w = nullptr;
-    struct ggml_tensor * mm_model_block_2_block_1_fc2_b = nullptr;
-    struct ggml_tensor * mm_model_block_2_block_2_0_w = nullptr;
-    struct ggml_tensor * mm_model_block_2_block_2_1_w = nullptr;
-    struct ggml_tensor * mm_model_block_2_block_2_1_b = nullptr;
+    ggml_tensor * mm_model_mlp_1_w = nullptr;
+    ggml_tensor * mm_model_mlp_1_b = nullptr;
+    ggml_tensor * mm_model_mlp_3_w = nullptr;
+    ggml_tensor * mm_model_mlp_3_b = nullptr;
+    ggml_tensor * mm_model_block_1_block_0_0_w = nullptr;
+    ggml_tensor * mm_model_block_1_block_0_1_w = nullptr;
+    ggml_tensor * mm_model_block_1_block_0_1_b = nullptr;
+    ggml_tensor * mm_model_block_1_block_1_fc1_w = nullptr;
+    ggml_tensor * mm_model_block_1_block_1_fc1_b = nullptr;
+    ggml_tensor * mm_model_block_1_block_1_fc2_w = nullptr;
+    ggml_tensor * mm_model_block_1_block_1_fc2_b = nullptr;
+    ggml_tensor * mm_model_block_1_block_2_0_w = nullptr;
+    ggml_tensor * mm_model_block_1_block_2_1_w = nullptr;
+    ggml_tensor * mm_model_block_1_block_2_1_b = nullptr;
+    ggml_tensor * mm_model_block_2_block_0_0_w = nullptr;
+    ggml_tensor * mm_model_block_2_block_0_1_w = nullptr;
+    ggml_tensor * mm_model_block_2_block_0_1_b = nullptr;
+    ggml_tensor * mm_model_block_2_block_1_fc1_w = nullptr;
+    ggml_tensor * mm_model_block_2_block_1_fc1_b = nullptr;
+    ggml_tensor * mm_model_block_2_block_1_fc2_w = nullptr;
+    ggml_tensor * mm_model_block_2_block_1_fc2_b = nullptr;
+    ggml_tensor * mm_model_block_2_block_2_0_w = nullptr;
+    ggml_tensor * mm_model_block_2_block_2_1_w = nullptr;
+    ggml_tensor * mm_model_block_2_block_2_1_b = nullptr;
 
     // MobileVLM_V2 projection
-    struct ggml_tensor * mm_model_mlp_0_w = nullptr;
-    struct ggml_tensor * mm_model_mlp_0_b = nullptr;
-    struct ggml_tensor * mm_model_mlp_2_w = nullptr;
-    struct ggml_tensor * mm_model_mlp_2_b = nullptr;
-    struct ggml_tensor * mm_model_peg_0_w = nullptr;
-    struct ggml_tensor * mm_model_peg_0_b = nullptr;
+    ggml_tensor * mm_model_mlp_0_w = nullptr;
+    ggml_tensor * mm_model_mlp_0_b = nullptr;
+    ggml_tensor * mm_model_mlp_2_w = nullptr;
+    ggml_tensor * mm_model_mlp_2_b = nullptr;
+    ggml_tensor * mm_model_peg_0_w = nullptr;
+    ggml_tensor * mm_model_peg_0_b = nullptr;
 
     // MINICPMV projection
-    struct ggml_tensor * mm_model_pos_embed_k = nullptr;
-    struct ggml_tensor * mm_model_query = nullptr;
-    struct ggml_tensor * mm_model_proj = nullptr;
-    struct ggml_tensor * mm_model_kv_proj = nullptr;
-    struct ggml_tensor * mm_model_attn_q_w = nullptr;
-    struct ggml_tensor * mm_model_attn_q_b = nullptr;
-    struct ggml_tensor * mm_model_attn_k_w = nullptr;
-    struct ggml_tensor * mm_model_attn_k_b = nullptr;
-    struct ggml_tensor * mm_model_attn_v_w = nullptr;
-    struct ggml_tensor * mm_model_attn_v_b = nullptr;
-    struct ggml_tensor * mm_model_attn_o_w = nullptr;
-    struct ggml_tensor * mm_model_attn_o_b = nullptr;
-    struct ggml_tensor * mm_model_ln_q_w = nullptr;
-    struct ggml_tensor * mm_model_ln_q_b = nullptr;
-    struct ggml_tensor * mm_model_ln_kv_w = nullptr;
-    struct ggml_tensor * mm_model_ln_kv_b = nullptr;
-    struct ggml_tensor * mm_model_ln_post_w = nullptr;
-    struct ggml_tensor * mm_model_ln_post_b = nullptr;
+    ggml_tensor * mm_model_pos_embed_k = nullptr;
+    ggml_tensor * mm_model_query = nullptr;
+    ggml_tensor * mm_model_proj = nullptr;
+    ggml_tensor * mm_model_kv_proj = nullptr;
+    ggml_tensor * mm_model_attn_q_w = nullptr;
+    ggml_tensor * mm_model_attn_q_b = nullptr;
+    ggml_tensor * mm_model_attn_k_w = nullptr;
+    ggml_tensor * mm_model_attn_k_b = nullptr;
+    ggml_tensor * mm_model_attn_v_w = nullptr;
+    ggml_tensor * mm_model_attn_v_b = nullptr;
+    ggml_tensor * mm_model_attn_o_w = nullptr;
+    ggml_tensor * mm_model_attn_o_b = nullptr;
+    ggml_tensor * mm_model_ln_q_w = nullptr;
+    ggml_tensor * mm_model_ln_q_b = nullptr;
+    ggml_tensor * mm_model_ln_kv_w = nullptr;
+    ggml_tensor * mm_model_ln_kv_b = nullptr;
+    ggml_tensor * mm_model_ln_post_w = nullptr;
+    ggml_tensor * mm_model_ln_post_b = nullptr;
 
     // gemma3
-    struct ggml_tensor * mm_input_proj_w = nullptr;
-    struct ggml_tensor * mm_soft_emb_norm_w = nullptr;
+    ggml_tensor * mm_input_proj_w = nullptr;
+    ggml_tensor * mm_soft_emb_norm_w = nullptr;
 
     // pixtral
-    struct ggml_tensor * token_embd_img_break = nullptr;
-    struct ggml_tensor * mm_patch_merger_w = nullptr;
+    ggml_tensor * token_embd_img_break = nullptr;
+    ggml_tensor * mm_patch_merger_w = nullptr;
 };
 
 struct clip_ctx {
@@ -316,11 +330,8 @@ struct clip_ctx {
     struct clip_vision_model vision_model;
     projector_type proj_type = PROJECTOR_TYPE_MLP;
 
-    int32_t max_feature_layer; // unused in newer models like gemma3
     float image_mean[3];
     float image_std[3];
-    bool use_gelu = false;
-    bool use_silu = false;
 
     gguf_context_ptr ctx_gguf;
     ggml_context_ptr ctx_data;
@@ -370,1239 +381,1252 @@ struct clip_ctx {
     }
 };
 
-static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_image_f32 & img) {
-    const auto & model = ctx->vision_model;
-    const auto & hparams = model.hparams;
-
-    int image_size_width  = img.nx;
-    int image_size_height = img.ny;
-
-    const int patch_size  = hparams.patch_size;
-    const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size));
-    const int n_embd      = hparams.n_embd;
-    const int n_head      = hparams.n_head;
-    const int d_head      = n_embd / n_head;
-    const int n_layer     = hparams.n_layer;
-    const float eps       = hparams.eps;
-
-    struct ggml_init_params params = {
-        /*.mem_size   =*/ ctx->buf_compute_meta.size(),
-        /*.mem_buffer =*/ ctx->buf_compute_meta.data(),
-        /*.no_alloc   =*/ true,
-    };
-
-    ggml_context_ptr ctx0_ptr(ggml_init(params));
-    auto ctx0 = ctx0_ptr.get();
-
-    struct ggml_cgraph * gf = ggml_new_graph(ctx0);
+struct clip_graph {
+    clip_ctx * ctx;
+    const clip_vision_model & model;
+    const clip_hparams & hparams;
+
+    // we only support single image per batch
+    const clip_image_f32 & img;
+
+    const int patch_size;
+    const int n_patches_x;
+    const int n_patches_y;
+    const int n_patches;
+    const int n_embd;
+    const int n_head;
+    const int d_head;
+    const int n_layer;
+    const float eps;
+    const float kq_scale;
+
+    ggml_context_ptr ctx0_ptr;
+    ggml_context * ctx0;
+    ggml_cgraph * gf;
+
+    clip_graph(clip_ctx * ctx, const clip_image_f32 & img) :
+            ctx(ctx),
+            model(ctx->vision_model),
+            hparams(model.hparams),
+            img(img),
+            patch_size(hparams.patch_size),
+            n_patches_x(img.nx / patch_size),
+            n_patches_y(img.ny / patch_size),
+            n_patches(n_patches_x * n_patches_y),
+            n_embd(hparams.n_embd),
+            n_head(hparams.n_head),
+            d_head(n_embd / n_head),
+            n_layer(hparams.n_layer),
+            eps(hparams.eps),
+            kq_scale(1.0f / sqrtf((float)d_head)) {
+        struct ggml_init_params params = {
+            /*.mem_size   =*/ ctx->buf_compute_meta.size(),
+            /*.mem_buffer =*/ ctx->buf_compute_meta.data(),
+            /*.no_alloc   =*/ true,
+        };
+        ctx0_ptr.reset(ggml_init(params));
+        ctx0 = ctx0_ptr.get();
+        gf = ggml_new_graph(ctx0);
+    }
+
+    ggml_cgraph * build_siglip() {
+        ggml_tensor * inp = build_inp();
+        ggml_tensor * cur = build_vit(
+                                inp, n_patches,
+                                NORM_TYPE_NORMAL,
+                                hparams.ffn_op,
+                                model.position_embeddings,
+                                nullptr);
+
+        if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
+            const int batch_size = 1;
+            GGML_ASSERT(n_patches_x == n_patches_y);
+            const int patches_per_image = n_patches_x;
+            const int kernel_size = hparams.proj_scale_factor;
+
+            cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
+            cur = ggml_reshape_4d(ctx0, cur, patches_per_image, patches_per_image, n_embd, batch_size);
+
+            // doing a pool2d to reduce the number of output tokens
+            cur = ggml_pool_2d(ctx0, cur, GGML_OP_POOL_AVG, kernel_size, kernel_size, kernel_size, kernel_size, 0, 0);
+            cur = ggml_reshape_3d(ctx0, cur, cur->ne[0] * cur->ne[0], n_embd, batch_size);
+            cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
+
+            // apply norm before projection
+            cur = ggml_rms_norm(ctx0, cur, eps);
+            cur = ggml_mul(ctx0, cur, model.mm_soft_emb_norm_w);
+
+            // apply projection
+            cur = ggml_mul_mat(ctx0,
+                ggml_cont(ctx0, ggml_transpose(ctx0, model.mm_input_proj_w)),
+                cur);
+
+        } else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
+            // https://github.com/huggingface/transformers/blob/0a950e0bbe1ed58d5401a6b547af19f15f0c195e/src/transformers/models/idefics3/modeling_idefics3.py#L578
+
+            const int scale_factor = model.hparams.proj_scale_factor;
+            const int n_embd = cur->ne[0];
+            const int seq    = cur->ne[1];
+            const int bsz    = 1; // batch size, always 1 for now since we don't support batching
+            const int height = std::sqrt(seq);
+            const int width  = std::sqrt(seq);
+            GGML_ASSERT(scale_factor != 0);
+            cur = ggml_reshape_4d(ctx0, cur, n_embd * scale_factor, width / scale_factor, height, bsz);
+            cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+            cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur),
+                n_embd * scale_factor * scale_factor,
+                height / scale_factor,
+                width / scale_factor,
+                bsz);
+            cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+            cur = ggml_reshape_3d(ctx0, ggml_cont(ctx0, cur),
+                n_embd * scale_factor * scale_factor,
+                seq / (scale_factor * scale_factor),
+                bsz);
+
+            cur = ggml_mul_mat(ctx0, model.projection, cur);
+        } else {
+            GGML_ABORT("SigLIP: Unsupported projector type");
+        }
 
-    // input raw
-    struct ggml_tensor * inp_raw = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, image_size_width, image_size_height, 3);
-    ggml_set_name(inp_raw, "inp_raw");
-    ggml_set_input(inp_raw);
+        // build the graph
+        ggml_build_forward_expand(gf, cur);
 
-    struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
-    inp = ggml_reshape_2d(ctx0, inp, num_patches, n_embd);
-    inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
-    inp = ggml_add(ctx0, inp, model.patch_bias);
+        return gf;
+    }
 
-    // position embeddings
-    struct ggml_tensor * embeddings = ggml_add(ctx0, inp, model.position_embeddings);
+    ggml_cgraph * build_pixtral() {
+        const int n_merge = hparams.spatial_merge_size;
 
-    // loop over layers
-    for (int il = 0; il < n_layer; il++) {
-        struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states
+        // 2D input positions
+        ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches);
+        ggml_set_name(pos_h, "pos_h");
+        ggml_set_input(pos_h);
 
-        // layernorm1
-        {
-            cur = ggml_norm(ctx0, cur, eps);
-            cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_1_w), model.layers[il].ln_1_b);
-        }
+        ggml_tensor * pos_w = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches);
+        ggml_set_name(pos_w, "pos_w");
+        ggml_set_input(pos_w);
 
-        // self-attention
-        {
-
-            struct ggml_tensor * Q =
-                ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].q_w, cur), model.layers[il].q_b);
+        auto add_pos = [&](ggml_tensor * cur, const clip_layer &) {
+            return build_rope_2d(ctx0, cur, pos_h, pos_w, hparams.rope_theta);
+        };
 
-            Q = ggml_reshape_3d(ctx0, Q, d_head, n_head, num_patches);
-            Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
+        ggml_tensor * inp = build_inp();
+        ggml_tensor * cur = build_vit(
+                                inp, n_patches,
+                                NORM_TYPE_RMS,
+                                hparams.ffn_op,
+                                nullptr, // no learned pos embd
+                                add_pos);
 
-            struct ggml_tensor * K =
-                ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].k_w, cur), model.layers[il].k_b);
+        // mistral small 3.1 patch merger
+        // ref: https://github.com/huggingface/transformers/blob/7a3e208892c06a5e278144eaf38c8599a42f53e7/src/transformers/models/mistral3/modeling_mistral3.py#L67
+        if (model.mm_patch_merger_w) {
+            GGML_ASSERT(hparams.spatial_merge_size > 0);
 
-            K = ggml_reshape_3d(ctx0, K, d_head, n_head, num_patches);
-            K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
+            cur = ggml_mul(ctx0, ggml_rms_norm(ctx0, cur, eps), model.mm_input_norm_w);
 
-            struct ggml_tensor * V =
-                ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].v_w, cur), model.layers[il].v_b);
+            // reshape image tokens to 2D grid
+            cur = ggml_reshape_3d(ctx0, cur, n_embd, n_patches_x, n_patches_y);
+            cur = ggml_permute(ctx0, cur, 2, 0, 1, 3); // [x, y, n_embd]
+            cur = ggml_cont(ctx0, cur);
 
-            V = ggml_reshape_3d(ctx0, V, d_head, n_head, num_patches);
-            V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
+            // torch.nn.functional.unfold is just an im2col under the hood
+            // we just need a dummy kernel to make it work
+            ggml_tensor * kernel = ggml_view_3d(ctx0, cur, n_merge, n_merge, cur->ne[2], 0, 0, 0);
+            cur = ggml_im2col(ctx0, kernel, cur, n_merge, n_merge, 0, 0, 1, 1, true, inp->type);
 
-            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
-            KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f);
+            // project to n_embd
+            cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], cur->ne[1] * cur->ne[2]);
+            cur = ggml_mul_mat(ctx0, model.mm_patch_merger_w, cur);
+        }
 
-            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
-            KQV = ggml_reshape_3d(ctx0, KQV, d_head, num_patches, n_head);
-            KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+        // LlavaMultiModalProjector (always using GELU activation)
+        {
+            cur = ggml_mul_mat(ctx0, model.mm_1_w, cur);
+            if (model.mm_1_b) {
+                cur = ggml_add(ctx0, cur, model.mm_1_b);
+            }
 
-            cur = ggml_cont_2d(ctx0, KQV, n_embd, num_patches);
+            cur = ggml_gelu(ctx0, cur);
+            cur = ggml_mul_mat(ctx0, model.mm_2_w, cur);
+            if (model.mm_2_b) {
+                cur = ggml_add(ctx0, cur, model.mm_2_b);
+            }
         }
 
-        // attention output
-        cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].o_w, cur), model.layers[il].o_b);
-
-        // re-add the layer input, e.g., residual
-        cur = ggml_add(ctx0, cur, embeddings);
+        // arrangement of the [IMG_BREAK] token
+        {
+            // not efficient, but works
+            // the trick is to view the embeddings as a 3D tensor with shape [n_embd, n_patches_per_row, n_rows]
+            // and then concatenate the [IMG_BREAK] token to the end of each row, aka n_patches_per_row dimension
+            // after the concatenation, we have a tensor with shape [n_embd, n_patches_per_row + 1, n_rows]
 
-        embeddings = cur; // embeddings = residual, cur = hidden_states
+            const int p_y             = n_merge > 0 ? n_patches_y / n_merge : n_patches_y;
+            const int p_x             = n_merge > 0 ? n_patches_x / n_merge : n_patches_x;
+            const int p_total         = p_x * p_y;
+            const int n_embd_text     = cur->ne[0];
+            const int n_tokens_output = p_total + p_y - 1; // one [IMG_BREAK] per row, except the last row
 
-        // layernorm2
-        {
-            cur = ggml_norm(ctx0, cur, eps);
-            cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_2_w), model.layers[il].ln_2_b);
+            ggml_tensor * tmp = ggml_reshape_3d(ctx0, cur, n_embd_text, p_x, p_y);
+            ggml_tensor * tok = ggml_new_tensor_3d(ctx0, tmp->type, n_embd_text, 1, p_y);
+            tok = ggml_scale(ctx0, tok, 0.0); // clear the tensor
+            tok = ggml_add(ctx0, tok, model.token_embd_img_break);
+            tmp = ggml_concat(ctx0, tmp, tok, 1);
+            cur = ggml_view_2d(ctx0, tmp,
+                n_embd_text, n_tokens_output,
+                ggml_row_size(tmp->type, n_embd_text), 0);
         }
 
-        cur = ggml_mul_mat(ctx0, model.layers[il].ff_up_w, cur);
-        cur = ggml_add(ctx0, cur, model.layers[il].ff_up_b);
+        // build the graph
+        ggml_build_forward_expand(gf, cur);
 
-        // siglip uses gelu
-        cur = ggml_gelu(ctx0, cur);
+        return gf;
+    }
 
-        cur = ggml_mul_mat(ctx0, model.layers[il].ff_down_w, cur);
-        cur = ggml_add(ctx0, cur, model.layers[il].ff_down_b);
+    // Qwen2VL and Qwen2.5VL use M-RoPE
+    ggml_cgraph * build_qwen2vl() {
+        const int batch_size       = 1;
+        const bool use_window_attn = hparams.n_wa_pattern > 0;
+        const int n_wa_pattern     = hparams.n_wa_pattern;
+        const int n_pos            = n_patches;
+        const int num_position_ids = n_pos * 4; // m-rope requires 4 dim per position
 
-        // residual 2
-        cur = ggml_add(ctx0, embeddings, cur);
+        norm_type norm_t = ctx->proj_type == PROJECTOR_TYPE_QWEN25VL
+            ? NORM_TYPE_RMS // qwen 2.5 vl
+            : NORM_TYPE_NORMAL; // qwen 2 vl
 
-        embeddings = cur;
-    }
+        int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};
 
-    // post-layernorm
-    if (model.post_ln_w) {
-        embeddings = ggml_norm(ctx0, embeddings, eps);
-        ggml_set_name(embeddings, "post_ln");
+        ggml_tensor * inp_raw = build_inp_raw();
+        ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
 
-        embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_ln_w), model.post_ln_b);
-    }
+        GGML_ASSERT(img.nx % (patch_size * 2) == 0);
+        GGML_ASSERT(img.ny % (patch_size * 2) == 0);
 
-    if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
-        const int batch_size = 1;
-        const int mm_tokens_per_image = 256; // default value for gemma3
-        const int tokens_per_side = sqrt(mm_tokens_per_image);
-        const int patches_per_image = sqrt(num_patches);
-        const int kernel_size = patches_per_image / tokens_per_side;
+        // second conv dimension
+        {
+            auto inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
+            inp = ggml_add(ctx0, inp, inp_1);
+
+            inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 2, 0, 3));  // [w, h, c, b] -> [c, w, h, b]
+            inp = ggml_reshape_4d(
+                ctx0, inp,
+                n_embd * 2, n_patches_x / 2, n_patches_y, batch_size);
+            inp = ggml_reshape_4d(
+                ctx0, inp,
+                n_embd * 2, n_patches_x / 2, 2, batch_size * (n_patches_y / 2));
+            inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 0, 2, 1, 3));
+            inp = ggml_reshape_3d(
+                ctx0, inp,
+                n_embd, n_patches_x * n_patches_y, batch_size);
+        }
+
+        if (model.patch_bias) {
+            inp = ggml_add(ctx0, inp, model.patch_bias);
+        }
+
+        ggml_tensor * inpL           = inp;
+        ggml_tensor * window_mask    = nullptr;
+        ggml_tensor * window_idx     = nullptr;
+        ggml_tensor * inv_window_idx = nullptr;
+
+        ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_position_ids);
+        ggml_set_name(positions, "positions");
+        ggml_set_input(positions);
+
+        // pre-layernorm
+        if (model.pre_ln_w) {
+            inpL = build_norm(inpL, model.pre_ln_w, model.pre_ln_b, norm_t, eps, -1);
+        }
+
+        if (use_window_attn) {
+            // handle window attention inputs
+            inv_window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos / 4);
+            ggml_set_name(inv_window_idx, "inv_window_idx");
+            ggml_set_input(inv_window_idx);
+            // mask for window attention
+            window_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_pos, n_pos);
+            ggml_set_name(window_mask, "window_mask");
+            ggml_set_input(window_mask);
+
+            // inpL shape: [n_embd, n_patches_x * n_patches_y, batch_size]
+            GGML_ASSERT(batch_size == 1);
+            inpL = ggml_reshape_2d(ctx0, inpL, n_embd * 4, n_patches_x * n_patches_y * batch_size / 4);
+            inpL = ggml_get_rows(ctx0, inpL, inv_window_idx);
+            inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_patches_x * n_patches_y, batch_size);
+        }
+
+        // loop over layers
+        for (int il = 0; il < n_layer; il++) {
+            auto & layer = model.layers[il];
+            const bool full_attn = use_window_attn ? (il + 1) % n_wa_pattern == 0 : true;
 
-        embeddings = ggml_cont(ctx0, ggml_transpose(ctx0, embeddings));
-        embeddings = ggml_reshape_4d(ctx0, embeddings, patches_per_image, patches_per_image, n_embd, batch_size);
+            ggml_tensor * cur = inpL; // inpL = residual, cur = hidden_states
 
-        // doing a pool2d to reduce the number of output tokens to 256
-        embeddings = ggml_pool_2d(ctx0, embeddings, GGML_OP_POOL_AVG, kernel_size, kernel_size, kernel_size, kernel_size, 0, 0);
-        embeddings = ggml_reshape_3d(ctx0, embeddings, embeddings->ne[0] * embeddings->ne[0], n_embd, batch_size);
-        embeddings = ggml_cont(ctx0, ggml_transpose(ctx0, embeddings));
+            // layernorm1
+            cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, norm_t, eps, il);
+            cb(cur, "ln1", il);
 
-        // apply norm before projection
-        embeddings = ggml_rms_norm(ctx0, embeddings, eps);
-        embeddings = ggml_mul(ctx0, embeddings, model.mm_soft_emb_norm_w);
+            // self-attention
+            {
+                ggml_tensor * Qcur = ggml_add(ctx0,
+                    ggml_mul_mat(ctx0, layer.q_w, cur), layer.q_b);
+                ggml_tensor * Kcur = ggml_add(ctx0,
+                    ggml_mul_mat(ctx0, layer.k_w, cur), layer.k_b);
+                ggml_tensor * Vcur = ggml_add(ctx0,
+                    ggml_mul_mat(ctx0, layer.v_w, cur), layer.v_b);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_patches);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, n_patches);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head, n_patches);
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                // apply M-RoPE
+                Qcur = ggml_rope_multi(
+                    ctx0, Qcur, positions, nullptr,
+                    d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
+                Kcur = ggml_rope_multi(
+                    ctx0, Kcur, positions, nullptr,
+                    d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
 
-        // apply projection
-        embeddings = ggml_mul_mat(ctx0,
-            ggml_cont(ctx0, ggml_transpose(ctx0, model.mm_input_proj_w)),
-            embeddings);
+                cb(Qcur, "Qcur_rope", il);
+                cb(Kcur, "Kcur_rope", il);
 
-    } else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
-        // https://github.com/huggingface/transformers/blob/0a950e0bbe1ed58d5401a6b547af19f15f0c195e/src/transformers/models/idefics3/modeling_idefics3.py#L578
-
-        ggml_tensor * cur = embeddings;
-        const int scale_factor = model.hparams.proj_scale_factor;
-        const int n_embd = cur->ne[0];
-        const int seq    = cur->ne[1];
-        const int bsz    = 1; // batch size, always 1 for now since we don't support batching
-        const int height = std::sqrt(seq);
-        const int width  = std::sqrt(seq);
-        GGML_ASSERT(scale_factor != 0);
-        cur = ggml_reshape_4d(ctx0, cur, n_embd * scale_factor, width / scale_factor, height, bsz);
-        cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
-        cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur),
-            n_embd * scale_factor * scale_factor,
-            height / scale_factor,
-            width / scale_factor,
-            bsz);
-        cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
-        cur = ggml_reshape_3d(ctx0, ggml_cont(ctx0, cur),
-            n_embd * scale_factor * scale_factor,
-            seq / (scale_factor * scale_factor),
-            bsz);
-
-        cur = ggml_mul_mat(ctx0, model.projection, cur);
-        embeddings = cur;
-    } else {
-        GGML_ABORT("SigLIP: Unsupported projector type");
-    }
-
-    // build the graph
-    ggml_build_forward_expand(gf, embeddings);
-
-    return gf;
-}
+                ggml_tensor * attn_mask = full_attn ? nullptr : window_mask;
 
-// implementation of the 2D RoPE without adding a new op in ggml
-// this is not efficient (use double the memory), but works on all backends
-// TODO: there was a more efficient which relies on ggml_view and ggml_rope_ext_inplace, but the rope inplace does not work well with non-contiguous tensors ; we should fix that and revert back to the original implementation in https://github.com/ggml-org/llama.cpp/pull/13065
-static ggml_tensor * build_rope_2d(
-    ggml_context * ctx0,
-    ggml_tensor * cur,
-    ggml_tensor * pos_h,
-    ggml_tensor * pos_w,
-    const float freq_base
-) {
-    const int64_t n_dim  = cur->ne[0];
-    const int64_t n_head = cur->ne[1];
-    const int64_t n_pos  = cur->ne[2];
-
-    // for example, if we have cur tensor of shape (n_dim=8, n_head, n_pos)
-    // we will have a list of 4 inv_freq: 1e-0, 1e-1, 1e-2, 1e-3
-    // first half of cur will use 1e-0, 1e-2 (even)
-    // second half of cur will use 1e-1, 1e-3 (odd)
-    // the trick here is to rotate just half of n_dim, so inv_freq will automatically be even
-    //  ^ don't ask me why, it's math! -2(2i) / n_dim == -2i / (n_dim/2)
-    // then for the second half, we use freq_scale to shift the inv_freq
-    //  ^ why? replace (2i) with (2i+1) in the above equation
-    const float freq_scale_odd = std::pow(freq_base, (float)-2/n_dim);
-
-    // first half
-    ggml_tensor * first;
-    {
-        first = ggml_view_3d(ctx0, cur,
-            n_dim/2, n_head, n_pos,
-            ggml_row_size(cur->type, n_dim),
-            ggml_row_size(cur->type, n_dim*n_head),
-            0);
-        first = ggml_rope_ext(
-            ctx0,
-            first,
-            pos_h,      // positions
-            nullptr,    // freq factors
-            n_dim/2,    // n_dims
-            0, 0, freq_base,
-            1.0f, 0.0f, 1.0f, 0.0f, 0.0f
-        );
-    }
+                cur = build_attn(layer.o_w, layer.o_b,
+                    Qcur, Kcur, Vcur, attn_mask, kq_scale, il);
+                cb(cur, "attn_out", il);
+            }
 
-    // second half
-    ggml_tensor * second;
-    {
-        second = ggml_view_3d(ctx0, cur,
-            n_dim/2, n_head, n_pos,
-            ggml_row_size(cur->type, n_dim),
-            ggml_row_size(cur->type, n_dim*n_head),
-            n_dim/2 * ggml_element_size(cur));
-        second = ggml_cont(ctx0, second); // copy, because ggml_rope don't play well with non-contiguous tensors
-        second = ggml_rope_ext(
-            ctx0,
-            second,
-            pos_w,      // positions
-            nullptr,    // freq factors
-            n_dim/2,    // n_dims
-            0, 0, freq_base,
-            freq_scale_odd,
-            0.0f, 1.0f, 0.0f, 0.0f
-        );
-    }
+            // re-add the layer input, e.g., residual
+            cur = ggml_add(ctx0, cur, inpL);
 
-    cur = ggml_concat(ctx0, first, second, 0);
-    return cur;
-}
+            inpL = cur; // inpL = residual, cur = hidden_states
 
-static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_image_f32 & img) {
-    const auto & model = ctx->vision_model;
-    const auto & hparams = model.hparams;
+            cb(cur, "ffn_inp", il);
 
-    GGML_ASSERT(ctx->proj_type == PROJECTOR_TYPE_PIXTRAL);
-
-    int image_size_width  = img.nx;
-    int image_size_height = img.ny;
-
-    const int patch_size  = hparams.patch_size;
-    const int n_patches_x = image_size_width  / patch_size;
-    const int n_patches_y = image_size_height / patch_size;
-    const int num_patches = n_patches_x * n_patches_y;
-    const int n_embd      = hparams.n_embd;
-    const int n_head      = hparams.n_head;
-    const int d_head      = n_embd / n_head;
-    const int n_layer     = hparams.n_layer;
-    const float eps       = hparams.eps;
-    const int n_merge     = hparams.spatial_merge_size;
-
-    struct ggml_init_params params = {
-        /*.mem_size   =*/ ctx->buf_compute_meta.size(),
-        /*.mem_buffer =*/ ctx->buf_compute_meta.data(),
-        /*.no_alloc   =*/ true,
-    };
+            // layernorm2
+            cur = build_norm(cur, layer.ln_2_w, layer.ln_2_b, norm_t, eps, il);
+            cb(cur, "ffn_inp_normed", il);
 
-    ggml_context_ptr ctx0_ptr(ggml_init(params));
-    auto ctx0 = ctx0_ptr.get();
+            // ffn
+            cur = build_ffn(cur,
+                layer.ff_up_w, layer.ff_up_b,
+                layer.ff_gate_w, layer.ff_gate_b,
+                layer.ff_down_w, layer.ff_down_b,
+                hparams.ffn_op, il);
 
-    struct ggml_cgraph * gf = ggml_new_graph(ctx0);
+            cb(cur, "ffn_out", il);
 
-    // input raw
-    struct ggml_tensor * inp_raw = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, image_size_width, image_size_height, 3);
-    ggml_set_name(inp_raw, "inp_raw");
-    ggml_set_input(inp_raw);
+            // residual 2
+            cur = ggml_add(ctx0, inpL, cur);
+            cb(cur, "layer_out", il);
 
-    // 2D input positions
-    struct ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_patches);
-    ggml_set_name(pos_h, "pos_h");
-    ggml_set_input(pos_h);
-    struct ggml_tensor * pos_w = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_patches);
-    ggml_set_name(pos_w, "pos_w");
-    ggml_set_input(pos_w);
+            inpL = cur;
+        }
+
+        // post-layernorm
+        if (model.post_ln_w) {
+            inpL = build_norm(inpL, model.post_ln_w, model.post_ln_b, norm_t, eps, n_layer);
+        }
 
-    struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
-    inp = ggml_reshape_2d(ctx0, inp, num_patches, n_embd);
-    inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
+        // multimodal projection
+        ggml_tensor * embeddings = inpL;
+        embeddings = ggml_reshape_3d(ctx0, embeddings, n_embd * 4, n_pos / 4, batch_size);
 
-    struct ggml_tensor * embeddings = inp;
+        embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
+        embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
 
-    // pre-layer norm
-    embeddings = ggml_mul(ctx0, ggml_rms_norm(ctx0, embeddings, eps), model.pre_ln_w);
+        // GELU activation
+        embeddings = ggml_gelu(ctx0, embeddings);
 
-    // loop over layers
-    for (int il = 0; il < n_layer; il++) {
-        struct ggml_tensor * cur = embeddings;
+        // Second linear layer
+        embeddings = ggml_mul_mat(ctx0, model.mm_1_w, embeddings);
+        embeddings = ggml_add(ctx0, embeddings, model.mm_1_b);
 
-        // pre-attention norm
-        cur = ggml_mul(ctx0, ggml_rms_norm(ctx0, cur, eps), model.layers[il].ln_1_w);
+        if (use_window_attn) {
+            window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos / 4);
+            ggml_set_name(window_idx, "window_idx");
+            ggml_set_input(window_idx);
 
-        // self-attention
-        {
-            struct ggml_tensor * Q = ggml_mul_mat(ctx0, model.layers[il].q_w, cur);
+            // embeddings shape: [n_embd, n_patches_x * n_patches_y, batch_size]
+            GGML_ASSERT(batch_size == 1);
+            embeddings = ggml_reshape_2d(ctx0, embeddings, hparams.projection_dim, n_patches_x * n_patches_y / 4);
+            embeddings = ggml_get_rows(ctx0, embeddings, window_idx);
+            embeddings = ggml_reshape_3d(ctx0, embeddings, hparams.projection_dim, n_patches_x * n_patches_y / 4, batch_size);
+        }
 
-            Q = ggml_reshape_3d(ctx0, Q, d_head, n_head, num_patches);
-            Q = build_rope_2d(ctx0, Q, pos_h, pos_w, hparams.rope_theta);
-            Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
+        // build the graph
+        ggml_build_forward_expand(gf, embeddings);
 
-            struct ggml_tensor * K = ggml_mul_mat(ctx0, model.layers[il].k_w, cur);
+        return gf;
+    }
 
-            K = ggml_reshape_3d(ctx0, K, d_head, n_head, num_patches);
-            K = build_rope_2d(ctx0, K, pos_h, pos_w, hparams.rope_theta);
-            K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
+    ggml_cgraph * build_minicpmv() {
+        const int batch_size = 1;
 
-            struct ggml_tensor * V = ggml_mul_mat(ctx0, model.layers[il].v_w, cur);
+        GGML_ASSERT(model.class_embedding == nullptr);
+        const int n_pos = n_patches;
 
-            V = ggml_reshape_3d(ctx0, V, d_head, n_head, num_patches);
-            V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
+        // position embeddings for the projector (not for ViT)
+        int n_output_dim = clip_n_mmproj_embd(ctx);
+        ggml_tensor * pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_output_dim, n_pos, batch_size);
+        ggml_set_name(pos_embed, "pos_embed");
+        ggml_set_input(pos_embed);
 
-            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
-            KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f);
+        // for selecting learned pos embd, used by ViT
+        struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos);
+        ggml_set_name(positions, "positions");
+        ggml_set_input(positions);
 
-            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
-            KQV = ggml_reshape_3d(ctx0, KQV, d_head, num_patches, n_head);
-            KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+        ggml_tensor * learned_pos_embd = ggml_get_rows(ctx0, model.position_embeddings, positions);
 
-            cur = ggml_cont_2d(ctx0, KQV, n_embd, num_patches);
+        ggml_tensor * inp = build_inp();
+        ggml_tensor * embeddings = build_vit(
+                                inp, n_patches,
+                                NORM_TYPE_NORMAL,
+                                hparams.ffn_op,
+                                learned_pos_embd,
+                                nullptr);
 
-            cur = ggml_mul_mat(ctx0, model.layers[il].o_w, cur);
-        }
+        // resampler projector (it is just another transformer)
 
-        // re-add the layer input, e.g., residual
-        cur = ggml_add(ctx0, cur, embeddings);
+        ggml_tensor * q = model.mm_model_query;
+        ggml_tensor * v = ggml_mul_mat(ctx0, model.mm_model_kv_proj, embeddings);
 
-        embeddings = cur; // embeddings = residual, cur = hidden_states
+        // norm
+        q = build_norm(q, model.mm_model_ln_q_w, model.mm_model_ln_q_b, NORM_TYPE_NORMAL, eps, -1);
+        v = build_norm(v, model.mm_model_ln_kv_w, model.mm_model_ln_kv_b, NORM_TYPE_NORMAL, eps, -1);
 
-        // pre-ffn norm
-        cur = ggml_mul(ctx0, ggml_rms_norm(ctx0, cur, eps), model.layers[il].ln_2_w);
+        // k = v + pos_embed
+        ggml_tensor * k = ggml_add(ctx0, v, pos_embed);
 
-        // feed-forward
+        // attention
         {
-            ggml_tensor * gate_proj = ggml_mul_mat(ctx0, model.layers[il].ff_gate_w, cur);
-            ggml_tensor * up_proj   = ggml_mul_mat(ctx0, model.layers[il].ff_up_w,   cur);
-            if (ctx->use_silu) {
-                gate_proj = ggml_silu(ctx0, gate_proj);
-            } else if (ctx->use_gelu) {
-                gate_proj = ggml_gelu(ctx0, gate_proj);
-            } else {
-                GGML_ABORT("Pixtral: Unsupported activation");
+            int n_embd = clip_n_mmproj_embd(ctx);
+            const int d_head = 128;
+            int n_head = n_embd/d_head;
+            int num_query = 96;
+            if (ctx->minicpmv_version == 2) {
+                num_query = 96;
+            } else if (ctx->minicpmv_version == 3) {
+                num_query = 64;
+            } else if (ctx->minicpmv_version == 4) {
+                num_query = 64;
             }
-            cur = ggml_mul(ctx0, up_proj, gate_proj);
-            cur = ggml_mul_mat(ctx0, model.layers[il].ff_down_w, cur);
-        }
 
-        // residual 2
-        cur = ggml_add(ctx0, embeddings, cur);
+            ggml_tensor * Q = ggml_add(ctx0,
+                ggml_mul_mat(ctx0, model.mm_model_attn_q_w, q),
+                model.mm_model_attn_q_b);
+            ggml_tensor * K = ggml_add(ctx0,
+                ggml_mul_mat(ctx0, model.mm_model_attn_k_w, k),
+                model.mm_model_attn_k_b);
+            ggml_tensor * V = ggml_add(ctx0,
+                ggml_mul_mat(ctx0, model.mm_model_attn_v_w, v),
+                model.mm_model_attn_v_b);
+
+            Q = ggml_reshape_3d(ctx0, Q, d_head, n_head, num_query);
+            K = ggml_reshape_3d(ctx0, K, d_head, n_head, n_pos);
+            V = ggml_reshape_3d(ctx0, V, d_head, n_head, n_pos);
+
+            cb(Q, "resampler_Q", -1);
+            cb(K, "resampler_K", -1);
+            cb(V, "resampler_V", -1);
+
+            embeddings = build_attn(
+                model.mm_model_attn_o_w,
+                model.mm_model_attn_o_b,
+                Q, K, V, nullptr, kq_scale, -1);
+            cb(embeddings, "resampler_attn_out", -1);
+        }
+        // layernorm
+        embeddings = build_norm(embeddings, model.mm_model_ln_post_w, model.mm_model_ln_post_b, NORM_TYPE_NORMAL, eps, -1);
+
+        // projection
+        embeddings = ggml_mul_mat(ctx0, model.mm_model_proj, embeddings);
+
+        // build the graph
+        ggml_build_forward_expand(gf, embeddings);
 
-        embeddings = cur;
+        return gf;
     }
 
-    // mistral small 3.1 patch merger
-    // ref: https://github.com/huggingface/transformers/blob/7a3e208892c06a5e278144eaf38c8599a42f53e7/src/transformers/models/mistral3/modeling_mistral3.py#L67
-    if (model.mm_patch_merger_w) {
-        GGML_ASSERT(hparams.spatial_merge_size > 0);
-
-        ggml_tensor * cur = embeddings;
-        cur = ggml_mul(ctx0, ggml_rms_norm(ctx0, cur, eps), model.mm_input_norm_w);
+    // this graph is used by llava, granite and glm
+    // due to having embedding_stack (used by granite), we cannot reuse build_vit
+    ggml_cgraph * build_llava() {
+        const int batch_size = 1;
+        const int n_pos = n_patches + (model.class_embedding ? 1 : 0);
 
-        // reshape image tokens to 2D grid
-        cur = ggml_reshape_3d(ctx0, cur, n_embd, n_patches_x, n_patches_y);
-        cur = ggml_permute(ctx0, cur, 2, 0, 1, 3); // [x, y, n_embd]
-        cur = ggml_cont(ctx0, cur);
+        GGML_ASSERT(n_patches_x == n_patches_y && "only square images supported");
 
-        // torch.nn.functional.unfold is just an im2col under the hood
-        // we just need a dummy kernel to make it work
-        ggml_tensor * kernel = ggml_view_3d(ctx0, cur, n_merge, n_merge, cur->ne[2], 0, 0, 0);
-        cur = ggml_im2col(ctx0, kernel, cur, n_merge, n_merge, 0, 0, 1, 1, true, inp->type);
+        // Calculate the deepest feature layer based on hparams and projector type
+        int max_feature_layer = n_layer;
+        {
+            // Get the index of the second to last layer; this is the default for models that have a llava projector
+            int il_last = hparams.n_layer - 1;
+            int deepest_feature_layer = -1;
 
-        // project to n_embd
-        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], cur->ne[1] * cur->ne[2]);
-        cur = ggml_mul_mat(ctx0, model.mm_patch_merger_w, cur);
-        embeddings = cur;
-    }
+            if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV || ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) {
+                il_last += 1;
+            }
 
-    // LlavaMultiModalProjector (always using GELU activation)
-    {
-        embeddings = ggml_mul_mat(ctx0, model.mm_1_w, embeddings);
-        if (model.mm_1_b) {
-            embeddings = ggml_add(ctx0, embeddings, model.mm_1_b);
+            // If we set explicit vision feature layers, only go up to the deepest one
+            // NOTE: only used by granite-vision models for now
+            for (const auto & feature_layer : hparams.vision_feature_layer) {
+                if (feature_layer > deepest_feature_layer) {
+                    deepest_feature_layer = feature_layer;
+                }
+            }
+            max_feature_layer = deepest_feature_layer < 0 ? il_last : deepest_feature_layer;
         }
 
-        embeddings = ggml_gelu(ctx0, embeddings);
-        embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings);
-        if (model.mm_2_b) {
-            embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
+        ggml_tensor * inp = build_inp();
+
+        if (model.patch_bias) {
+            inp = ggml_add(ctx0, inp, model.patch_bias);
         }
-    }
 
-    // arrangement of the [IMG_BREAK] token
-    {
-        // not efficient, but works
-        // the trick is to view the embeddings as a 3D tensor with shape [n_embd, n_patches_per_row, n_rows]
-        // and then concatenate the [IMG_BREAK] token to the end of each row, aka n_patches_per_row dimension
-        // after the concatenation, we have a tensor with shape [n_embd, n_patches_per_row + 1, n_rows]
-
-        const int p_y             = n_merge > 0 ? n_patches_y / n_merge : n_patches_y;
-        const int p_x             = n_merge > 0 ? n_patches_x / n_merge : n_patches_x;
-        const int p_total         = p_x * p_y;
-        const int n_embd_text     = embeddings->ne[0];
-        const int n_tokens_output = p_total + p_y - 1; // one [IMG_BREAK] per row, except the last row
-
-        ggml_tensor * cur = ggml_reshape_3d(ctx0, embeddings, n_embd_text, p_x, p_y);
-        ggml_tensor * tok = ggml_new_tensor_3d(ctx0, embeddings->type, n_embd_text, 1, p_y);
-        tok = ggml_scale(ctx0, tok, 0.0); // clear the tensor
-        tok = ggml_add(ctx0, tok, model.token_embd_img_break);
-        cur = ggml_concat(ctx0, cur, tok, 1);
-        embeddings = ggml_view_2d(ctx0, cur,
-            n_embd_text, n_tokens_output,
-            ggml_row_size(cur->type, n_embd_text), 0);
-    }
-
-    // build the graph
-    ggml_build_forward_expand(gf, embeddings);
-
-    return gf;
-}
+        // concat class_embeddings and patch_embeddings
+        if (model.class_embedding) {
+            inp = ggml_concat(ctx0, inp, model.class_embedding, 1);
+        }
 
-static ggml_cgraph * clip_image_build_graph_qwen25vl(clip_ctx * ctx, const clip_image_f32_batch & imgs) {
-    const auto & model = ctx->vision_model;
-    const auto & hparams = model.hparams;
+        ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos);
+        ggml_set_name(positions, "positions");
+        ggml_set_input(positions);
 
-    const int image_size_width  = imgs.entries[0]->nx;
-    const int image_size_height = imgs.entries[0]->ny;
+        inp = ggml_add(ctx0, inp, ggml_get_rows(ctx0, model.position_embeddings, positions));
 
-    const bool use_window_attn = hparams.n_wa_pattern > 0;
-
-    const int n_wa_pattern         = hparams.n_wa_pattern;
-    const int patch_size           = hparams.patch_size;
-    const int num_patches          = ((image_size_width / patch_size) * (image_size_height / patch_size));
-    const int patches_w            = image_size_width / patch_size;
-    const int patches_h            = image_size_height / patch_size;
-    const int num_positions        = num_patches + (model.class_embedding ? 1 : 0);
-    const int num_position_ids     = num_positions * 4; // m-rope requires 4 dim per position
-    const int n_embd               = hparams.n_embd;
-    const int n_head               = hparams.n_head;
-    const int d_head               = n_embd / n_head;
-    const int n_layer              = hparams.n_layer;
-    const float eps                = hparams.eps;
-
-    int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};
-
-    const int batch_size = imgs.entries.size();
-    GGML_ASSERT(batch_size == 1);
-
-    struct ggml_init_params params = {
-        /*.mem_size   =*/ ctx->buf_compute_meta.size(),
-        /*.mem_buffer =*/ ctx->buf_compute_meta.data(),
-        /*.no_alloc   =*/ true,
-    };
+        ggml_tensor * inpL = inp;
 
-    ggml_context_ptr ctx0_ptr(ggml_init(params));
-    auto ctx0 = ctx0_ptr.get();
+        // pre-layernorm
+        if (model.pre_ln_w) {
+            inpL = build_norm(inpL, model.pre_ln_w, model.pre_ln_b, NORM_TYPE_NORMAL, eps, -1);
+            cb(inpL, "pre_ln", -1);
+        }
 
-    struct ggml_cgraph * gf = ggml_new_graph(ctx0);
+        std::vector<ggml_tensor *> embedding_stack;
+        const auto & vision_feature_layer = hparams.vision_feature_layer;
 
-    struct ggml_tensor * inp_raw = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, image_size_width, image_size_height, 3, batch_size);
-    ggml_set_name(inp_raw, "inp_raw");
-    ggml_set_input(inp_raw);
+        // loop over layers
+        for (int il = 0; il < max_feature_layer; il++) {
+            auto & layer = model.layers[il];
+            ggml_tensor * cur = inpL; // inpL = residual, cur = hidden_states
 
-    struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
+            // If this is an embedding feature layer, save the output.
+            // NOTE: 0 index here refers to the input to the encoder.
+            if (vision_feature_layer.find(il) != vision_feature_layer.end()) {
+                embedding_stack.push_back(cur);
+            }
 
-    GGML_ASSERT(image_size_width  % (patch_size * 2) == 0);
-    GGML_ASSERT(image_size_height % (patch_size * 2) == 0);
+            // layernorm1
+            cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, NORM_TYPE_NORMAL, eps, il);
+            cb(cur, "layer_inp_normed", il);
 
-    auto inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
-    inp = ggml_add(ctx0, inp, inp_1);
+            // self-attention
+            {
+                ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.q_w, cur);
+                if (layer.q_b) {
+                    Qcur = ggml_add(ctx0, Qcur, layer.q_b);
+                }
 
-    inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 2, 0, 3));  // [w, h, c, b] -> [c, w, h, b]
-    inp = ggml_reshape_4d(
-        ctx0, inp,
-        n_embd * 2, patches_w / 2, patches_h, batch_size);
-    inp = ggml_reshape_4d(
-        ctx0, inp,
-        n_embd * 2, patches_w / 2, 2, batch_size * (patches_h / 2));
-    inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 0, 2, 1, 3));
-    inp = ggml_reshape_3d(
-        ctx0, inp,
-        n_embd, patches_w * patches_h, batch_size);
+                ggml_tensor * Kcur = ggml_mul_mat(ctx0, layer.k_w, cur);
+                if (layer.k_b) {
+                    Kcur = ggml_add(ctx0, Kcur, layer.k_b);
+                }
 
-    if (model.patch_bias) {
-        // inp = ggml_add(ctx0, inp, ggml_repeat(ctx0, model.patch_bias, inp));
-        inp = ggml_add(ctx0, inp, model.patch_bias);
-    }
-    struct ggml_tensor * embeddings     = inp;
-    struct ggml_tensor * window_mask    = nullptr;
-    struct ggml_tensor * window_idx     = nullptr;
-    struct ggml_tensor * inv_window_idx = nullptr;
+                ggml_tensor * Vcur = ggml_mul_mat(ctx0, layer.v_w, cur);
+                if (layer.v_b) {
+                    Vcur = ggml_add(ctx0, Vcur, layer.v_b);
+                }
 
-    struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_position_ids);
-    ggml_set_name(positions, "positions");
-    ggml_set_input(positions);
+                Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_pos);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, n_pos);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head, n_pos);
 
-    // pre-layernorm
-    if (model.pre_ln_w) {
-        embeddings = ggml_rms_norm(ctx0, embeddings, eps);
-        ggml_set_name(embeddings, "pre_ln");
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
 
-        embeddings = ggml_mul(ctx0, embeddings, model.pre_ln_w);
-    }
+                cur = build_attn(layer.o_w, layer.o_b,
+                    Qcur, Kcur, Vcur, nullptr, kq_scale, il);
+                cb(cur, "attn_out", il);
+            }
 
-    if (use_window_attn) {
-        // handle window attention inputs
-        inv_window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions / 4);
-        ggml_set_name(inv_window_idx, "inv_window_idx");
-        ggml_set_input(inv_window_idx);
-        // mask for window attention
-        window_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, num_positions, num_positions);
-        ggml_set_name(window_mask, "window_mask");
-        ggml_set_input(window_mask);
+            // re-add the layer input, e.g., residual
+            cur = ggml_add(ctx0, cur, inpL);
 
-        // embeddings shape: [n_embd, patches_w * patches_h, batch_size]
-        GGML_ASSERT(batch_size == 1);
-        embeddings = ggml_reshape_2d(ctx0, embeddings, n_embd * 4, patches_w * patches_h * batch_size / 4);
-        embeddings = ggml_get_rows(ctx0, embeddings, inv_window_idx);
-        embeddings = ggml_reshape_3d(ctx0, embeddings, n_embd, patches_w * patches_h, batch_size);
-    }
+            inpL = cur; // inpL = residual, cur = hidden_states
 
-    // loop over layers
-    for (int il = 0; il < n_layer; il++) {
-        struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states
+            cb(cur, "ffn_inp", il);
 
-        // rmsnorm1
-        cur = ggml_rms_norm(ctx0, cur, eps);
-        cur = ggml_mul(ctx0, cur, model.layers[il].ln_1_w);
+            // layernorm2
+            cur = build_norm(cur, layer.ln_2_w, layer.ln_2_b, NORM_TYPE_NORMAL, eps, il);
+            cb(cur, "ffn_inp_normed", il);
 
-        // self-attention
-        {
+            // ffn
+            cur = build_ffn(cur,
+                layer.ff_up_w, layer.ff_up_b,
+                layer.ff_gate_w, layer.ff_gate_b,
+                layer.ff_down_w, layer.ff_down_b,
+                hparams.ffn_op, il);
 
-            struct ggml_tensor * Q =
-                ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].q_w, cur), model.layers[il].q_b);
+            cb(cur, "ffn_out", il);
 
-            Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_positions, batch_size);
-            Q = ggml_rope_multi(
-                ctx0, Q, positions, nullptr,
-                d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
-            Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
-            Q = ggml_reshape_3d(ctx0, Q, d_head, num_positions, n_head * batch_size);
+            // residual 2
+            cur = ggml_add(ctx0, inpL, cur);
+            cb(cur, "layer_out", il);
 
-            struct ggml_tensor * K =
-                ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].k_w, cur), model.layers[il].k_b);
+            inpL = cur;
+        }
 
-            K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size);
-            K = ggml_rope_multi(
-                ctx0, K, positions, nullptr,
-                d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
-            K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
-            K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size);
+        // post-layernorm
+        if (model.post_ln_w) {
+            inpL = build_norm(inpL, model.post_ln_w, model.post_ln_b, NORM_TYPE_NORMAL, eps, -1);
+        }
 
-            struct ggml_tensor * V =
-                ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].v_w, cur), model.layers[il].v_b);
+        ggml_tensor * embeddings = inpL;
 
-            V = ggml_reshape_4d(ctx0, V, d_head, n_head, num_positions, batch_size);
-            V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
-            V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size);
+        // process vision feature layers (used by granite)
+        {
+            // final layer is a vision feature layer
+            if (vision_feature_layer.find(max_feature_layer) != vision_feature_layer.end()) {
+                embedding_stack.push_back(inpL);
+            }
 
-            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
-            const bool full_attn = use_window_attn ? (il + 1) % n_wa_pattern == 0 : true;
-            if (full_attn) {
-                KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f);
-            } else {
-                KQ = ggml_soft_max_ext(ctx0, KQ, window_mask, 1.0f / sqrtf((float)d_head), 0.0f);
+            // If feature layers are explicitly set, stack them (if we have multiple)
+            if (!embedding_stack.empty()) {
+                embeddings = embedding_stack[0];
+                for (size_t i = 1; i < embedding_stack.size(); i++) {
+                    embeddings = ggml_concat(ctx0, embeddings, embedding_stack[i], 0);
+                }
             }
+        }
 
-            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
-            KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_positions, n_head, batch_size);
-            KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+        // llava projector (also used by granite)
+        if (ctx->has_llava_projector) {
+            embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]);
 
-            cur = ggml_cont_3d(ctx0, KQV, n_embd, num_positions, batch_size);
-        }
+            ggml_tensor * patches = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches);
+            ggml_set_name(patches, "patches");
+            ggml_set_input(patches);
 
-        // attention output
-        cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].o_w, cur), model.layers[il].o_b);
+            // shape [1, 576, 1024]
+            // ne is whcn, ne = [1024, 576, 1, 1]
+            embeddings = ggml_get_rows(ctx0, embeddings, patches);
 
-        // re-add the layer input, e.g., residual
-        cur = ggml_add(ctx0, cur, embeddings);
+            // print_tensor_info(embeddings, "embeddings");
 
-        embeddings = cur; // embeddings = residual, cur = hidden_states
+            // llava projector
+            if (ctx->proj_type == PROJECTOR_TYPE_MLP) {
+                embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
+                embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
 
-        // rms norm2
-        cur = ggml_rms_norm(ctx0, cur, eps);
-        cur = ggml_mul(ctx0, cur, model.layers[il].ln_2_w);
+                embeddings = ggml_gelu(ctx0, embeddings);
+                if (model.mm_2_w) {
+                    embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings);
+                    embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
+                }
+            }
+            else if (ctx->proj_type == PROJECTOR_TYPE_MLP_NORM) {
+                embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
+                embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
+                // ggml_tensor_printf(embeddings, "mm_0_w",0,true,false);
+                // First LayerNorm
+                embeddings = ggml_norm(ctx0, embeddings, eps);
+                embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_1_w),
+                                    model.mm_1_b);
+
+                // GELU activation
+                embeddings = ggml_gelu(ctx0, embeddings);
+
+                // Second linear layer
+                embeddings = ggml_mul_mat(ctx0, model.mm_3_w, embeddings);
+                embeddings = ggml_add(ctx0, embeddings, model.mm_3_b);
+
+                // Second LayerNorm
+                embeddings = ggml_norm(ctx0, embeddings, eps);
+                embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_4_w),
+                                    model.mm_4_b);
+            }
+            else if (ctx->proj_type == PROJECTOR_TYPE_LDP) {
+                // MobileVLM projector
+                int n_patch = 24;
+                ggml_tensor * mlp_1 = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w, embeddings);
+                mlp_1 = ggml_add(ctx0, mlp_1, model.mm_model_mlp_1_b);
+                mlp_1 = ggml_gelu(ctx0, mlp_1);
+                ggml_tensor * mlp_3 = ggml_mul_mat(ctx0, model.mm_model_mlp_3_w, mlp_1);
+                mlp_3 = ggml_add(ctx0, mlp_3, model.mm_model_mlp_3_b);
+                // mlp_3 shape = [1, 576, 2048], ne = [2048, 576, 1, 1]
+
+                // block 1
+                ggml_tensor * block_1 = nullptr;
+                {
+                    // transpose from [1, 576, 2048] --> [1, 2048, 576] --> [1, 2048, 24, 24]
+                    mlp_3 = ggml_cont(ctx0, ggml_permute(ctx0, mlp_3, 1, 0, 2, 3));
+                    mlp_3 = ggml_reshape_4d(ctx0, mlp_3, n_patch, n_patch, mlp_3->ne[1], mlp_3->ne[2]);
+                    // stride = 1, padding = 1, bias is nullptr
+                    block_1 = ggml_conv_2d_dw(ctx0, model.mm_model_block_1_block_0_0_w, mlp_3, 1, 1, 1, 1, 1, 1);
+
+                    // layer norm
+                    // // block_1 shape = [1, 2048, 24, 24], ne = [24, 24, 2048, 1]
+                    block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 1, 2, 0, 3));
+                    // block_1 shape = [1, 24, 24, 2048], ne = [2048, 24, 24, 1]
+                    block_1 = ggml_norm(ctx0, block_1, eps);
+                    block_1 = ggml_add(ctx0, ggml_mul(ctx0, block_1, model.mm_model_block_1_block_0_1_w), model.mm_model_block_1_block_0_1_b);
+                    block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 2, 0, 1, 3));
+
+                    // block_1 shape = [1, 2048, 24, 24], ne = [24, 24, 2048, 1]
+                    // hardswish
+                    ggml_tensor * block_1_hw = ggml_hardswish(ctx0, block_1);
+
+                    block_1 = ggml_pool_2d(ctx0, block_1_hw, GGML_OP_POOL_AVG, block_1_hw->ne[0], block_1_hw->ne[1], block_1_hw->ne[0], block_1_hw->ne[1], 0, 0);
+                    // block_1 shape = [1, 2048, 1, 1], ne = [1, 1, 2048, 1]
+                    // pointwise conv
+                    block_1 = ggml_reshape_2d(ctx0, block_1, block_1->ne[0]*block_1->ne[1]*block_1->ne[2], block_1->ne[3]);
+                    block_1 = ggml_mul_mat(ctx0, model.mm_model_block_1_block_1_fc1_w, block_1);
+                    block_1 = ggml_add(ctx0, block_1, model.mm_model_block_1_block_1_fc1_b);
+                    block_1 = ggml_relu(ctx0, block_1);
+                    block_1 = ggml_mul_mat(ctx0, model.mm_model_block_1_block_1_fc2_w, block_1);
+                    block_1 = ggml_add(ctx0, block_1, model.mm_model_block_1_block_1_fc2_b);
+                    block_1 = ggml_hardsigmoid(ctx0, block_1);
+                    // block_1_hw shape = [1, 2048, 24, 24], ne = [24, 24, 2048, 1], block_1 shape = [1, 2048], ne = [2048, 1, 1, 1]
+                    block_1 = ggml_reshape_4d(ctx0, block_1, 1, 1, block_1->ne[0], block_1->ne[1]);
+                    block_1 = ggml_mul(ctx0, block_1_hw, block_1);
+
+                    int w = block_1->ne[0], h = block_1->ne[1];
+                    block_1 = ggml_reshape_3d(ctx0, block_1, w*h, block_1->ne[2], block_1->ne[3]);
+                    block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 1, 0, 2, 3));
+
+                    // block_1 shape = [1, 24*24, 2048], ne = [24*24, 2048, 1]
+                    block_1 = ggml_mul_mat(ctx0, model.mm_model_block_1_block_2_0_w, block_1);
+                    block_1 = ggml_reshape_4d(ctx0, block_1, block_1->ne[0], w, h, block_1->ne[3]);
+
+                    // block_1 shape = [1, 24, 24, 2048], ne = [2048, 24, 24, 1]
+                    block_1 = ggml_norm(ctx0, block_1, eps);
+                    block_1 = ggml_add(ctx0, ggml_mul(ctx0, block_1, model.mm_model_block_1_block_2_1_w), model.mm_model_block_1_block_2_1_b);
+                    block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 2, 0, 1, 3));
+                    // block1 shape = [1, 2048, 24, 24], ne = [24, 24, 2048, 1]
+                    // residual
+                    block_1 = ggml_add(ctx0, mlp_3, block_1);
+                }
 
-        // mlp
-        // ffn_up
-        auto cur_up = ggml_mul_mat(ctx0, model.layers[il].ff_up_w, cur);
-        cur_up = ggml_add(ctx0, cur_up, model.layers[il].ff_up_b);
+                // block_2
+                {
+                    // stride = 2
+                    block_1 = ggml_conv_2d_dw(ctx0, model.mm_model_block_2_block_0_0_w, block_1, 2, 2, 1, 1, 1, 1);
+
+                    // block_1 shape = [1, 2048, 12, 12], ne = [12, 12, 2048, 1]
+                    // layer norm
+                    block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 1, 2, 0, 3));
+                    // block_1 shape = [1, 12, 12, 2048], ne = [2048, 12, 12, 1]
+                    block_1 = ggml_norm(ctx0, block_1, eps);
+                    block_1 = ggml_add(ctx0, ggml_mul(ctx0, block_1, model.mm_model_block_2_block_0_1_w), model.mm_model_block_2_block_0_1_b);
+                    block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 2, 0, 1, 3));
+                    // block_1 shape = [1, 2048, 12, 12], ne = [12, 12, 2048, 1]
+                    // hardswish
+                    ggml_tensor * block_1_hw = ggml_hardswish(ctx0, block_1);
+
+                    // not sure the parameters is right for globalAvgPooling
+                    block_1 = ggml_pool_2d(ctx0, block_1_hw, GGML_OP_POOL_AVG, block_1_hw->ne[0], block_1_hw->ne[1], block_1_hw->ne[0], block_1_hw->ne[1], 0, 0);
+                    // block_1 shape = [1, 2048, 1, 1], ne = [1, 1, 2048, 1]
+                    // pointwise conv
+                    block_1 = ggml_reshape_2d(ctx0, block_1, block_1->ne[0]*block_1->ne[1]*block_1->ne[2], block_1->ne[3]);
+                    block_1 = ggml_mul_mat(ctx0, model.mm_model_block_2_block_1_fc1_w, block_1);
+                    block_1 = ggml_add(ctx0, block_1, model.mm_model_block_2_block_1_fc1_b);
+                    block_1 = ggml_relu(ctx0, block_1);
+                    block_1 = ggml_mul_mat(ctx0, model.mm_model_block_2_block_1_fc2_w, block_1);
+                    block_1 = ggml_add(ctx0, block_1, model.mm_model_block_2_block_1_fc2_b);
+                    block_1 = ggml_hardsigmoid(ctx0, block_1);
+
+                    // block_1_hw shape = [1, 2048, 12, 12], ne = [12, 12, 2048, 1], block_1 shape = [1, 2048, 1, 1], ne = [1, 1, 2048, 1]
+                    block_1 = ggml_reshape_4d(ctx0, block_1, 1, 1, block_1->ne[0], block_1->ne[1]);
+                    block_1 = ggml_mul(ctx0, block_1_hw, block_1);
+
+                    int w = block_1->ne[0], h = block_1->ne[1];
+                    block_1 = ggml_reshape_3d(ctx0, block_1, w*h, block_1->ne[2], block_1->ne[3]);
+                    block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 1, 0, 2, 3));
+                    // block_1 shape = [1, 24*24, 2048], ne = [24*24, 2048, 1]
+                    block_1 = ggml_mul_mat(ctx0, model.mm_model_block_2_block_2_0_w, block_1);
+                    block_1 = ggml_reshape_4d(ctx0, block_1, block_1->ne[0], w, h, block_1->ne[3]);
+
+
+                    // block_1 shape = [1, 12, 12, 2048], ne = [2048, 12, 12, 1]
+                    block_1 = ggml_norm(ctx0, block_1, eps);
+                    block_1 = ggml_add(ctx0, ggml_mul(ctx0, block_1, model.mm_model_block_2_block_2_1_w), model.mm_model_block_2_block_2_1_b);
+                    block_1 = ggml_reshape_3d(ctx0, block_1, block_1->ne[0], block_1->ne[1] * block_1->ne[2], block_1->ne[3]);
+                    // block_1 shape = [1, 144, 2048], ne = [2048, 144, 1]
+                }
+                embeddings = block_1;
+            }
+            else if (ctx->proj_type == PROJECTOR_TYPE_LDPV2)
+            {
+                int n_patch = 24;
+                ggml_tensor * mlp_0 = ggml_mul_mat(ctx0, model.mm_model_mlp_0_w, embeddings);
+                mlp_0 = ggml_add(ctx0, mlp_0, model.mm_model_mlp_0_b);
+                mlp_0 = ggml_gelu(ctx0, mlp_0);
+                ggml_tensor * mlp_2 = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, mlp_0);
+                mlp_2 = ggml_add(ctx0, mlp_2, model.mm_model_mlp_2_b);
+                // mlp_2 ne = [2048, 576, 1, 1]
+                // // AVG Pool Layer 2*2, strides = 2
+                mlp_2 = ggml_cont(ctx0, ggml_permute(ctx0, mlp_2, 1, 0, 2, 3));
+                // mlp_2 ne = [576, 2048, 1, 1]
+                mlp_2 = ggml_reshape_4d(ctx0, mlp_2, n_patch, n_patch, mlp_2->ne[1], mlp_2->ne[2]);
+                // mlp_2 ne [24, 24, 2048, 1]
+                mlp_2 = ggml_pool_2d(ctx0, mlp_2, GGML_OP_POOL_AVG, 2, 2, 2, 2, 0, 0);
+                // weight ne = [3, 3, 2048, 1]
+                ggml_tensor * peg_0 = ggml_conv_2d_dw(ctx0, model.mm_model_peg_0_w, mlp_2, 1, 1, 1, 1, 1, 1);
+                peg_0 = ggml_cont(ctx0, ggml_permute(ctx0, peg_0, 1, 2, 0, 3));
+                peg_0 = ggml_add(ctx0, peg_0, model.mm_model_peg_0_b);
+                mlp_2 = ggml_cont(ctx0, ggml_permute(ctx0, mlp_2, 1, 2, 0, 3));
+                peg_0 = ggml_add(ctx0, peg_0, mlp_2);
+                peg_0 = ggml_reshape_3d(ctx0, peg_0, peg_0->ne[0], peg_0->ne[1] * peg_0->ne[2], peg_0->ne[3]);
+                embeddings = peg_0;
+            }
+            else {
+                GGML_ABORT("fatal error");
+            }
+        }
 
-        auto cur_gate = ggml_mul_mat(ctx0, model.layers[il].ff_gate_w, cur);
-        cur_gate = ggml_add(ctx0, cur_gate, model.layers[il].ff_gate_b);
-        // TODO : only 2 of these 3 are actually used, should we remove one of them?
-        if (ctx->use_gelu) {
-            cur_gate = ggml_gelu_inplace(ctx0, cur_gate);
-        } else if (ctx->use_silu) {
-            cur_gate = ggml_silu_inplace(ctx0, cur_gate);
-        } else {
-            cur_gate = ggml_gelu_quick_inplace(ctx0, cur_gate);
+        // glm projector
+        else if (ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) {
+            size_t gridsz = (size_t)sqrt(embeddings->ne[1]);
+            embeddings = ggml_cont(ctx0, ggml_permute(ctx0,embeddings,1,0,2,3));
+            embeddings = ggml_reshape_3d(ctx0, embeddings, gridsz, gridsz, embeddings->ne[1]);
+            embeddings = ggml_conv_2d(ctx0, model.mm_model_adapter_conv_w, embeddings, 2, 2, 0, 0, 1, 1);
+            embeddings = ggml_reshape_3d(ctx0, embeddings,embeddings->ne[0]*embeddings->ne[1] , embeddings->ne[2], batch_size);
+            embeddings = ggml_cont(ctx0, ggml_permute(ctx0,embeddings, 1, 0, 2, 3));
+            embeddings = ggml_add(ctx0, embeddings, model.mm_model_adapter_conv_b);
+            // GLU
+            {
+                embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_0_w, embeddings);
+                embeddings = ggml_norm(ctx0, embeddings, eps);
+                embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_model_ln_q_w), model.mm_model_ln_q_b);
+                embeddings = ggml_gelu_inplace(ctx0, embeddings);
+                ggml_tensor * x = embeddings;
+                embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, embeddings);
+                x = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w,x);
+                embeddings = ggml_silu_inplace(ctx0, embeddings);
+                embeddings = ggml_mul(ctx0, embeddings,x);
+                embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_3_w, embeddings);
+            }
+            // arrangement of BOI/EOI token embeddings
+            // note: these embeddings are not present in text model, hence we cannot process them as text tokens
+            // see: https://huggingface.co/THUDM/glm-edge-v-2b/blob/main/siglip.py#L53
+            {
+                embeddings = ggml_concat(ctx0, model.mm_glm_tok_boi, embeddings, 1); // BOI
+                embeddings = ggml_concat(ctx0, embeddings, model.mm_glm_tok_eoi, 1); // EOI
+            }
         }
-        cur = ggml_mul(ctx0, cur_gate, cur_up);
 
-        // ffn_down
-        cur = ggml_mul_mat(ctx0, model.layers[il].ff_down_w, cur);
-        cur = ggml_add(ctx0, cur, model.layers[il].ff_down_b);
+        else {
+            GGML_ABORT("llava: unknown projector type");
+        }
 
-        // residual 2
-        cur = ggml_add(ctx0, embeddings, cur);
+        // build the graph
+        ggml_build_forward_expand(gf, embeddings);
 
-        embeddings = cur;
+        return gf;
     }
 
-    // post-layernorm
-    if (model.post_ln_w) {
-        embeddings = ggml_rms_norm(ctx0, embeddings, eps);
-        ggml_set_name(embeddings, "post_ln");
+private:
+    //
+    // utility functions
+    //
 
-        embeddings = ggml_mul(ctx0, embeddings, model.post_ln_w);
+    void cb(ggml_tensor * cur, const char * name, int il) const {
+        // TODO: implement this
+        GGML_UNUSED(cur);
+        GGML_UNUSED(name);
+        GGML_UNUSED(il);
     }
 
-    embeddings = ggml_reshape_3d(ctx0, embeddings, n_embd * 4, num_positions / 4, batch_size);
-
-    embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
-    embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
-
-    // GELU activation
-    embeddings = ggml_gelu(ctx0, embeddings);
-
-    // Second linear layer
-    embeddings = ggml_mul_mat(ctx0, model.mm_1_w, embeddings);
-    embeddings = ggml_add(ctx0, embeddings, model.mm_1_b);
+    // build vision transformer (ViT) cgraph
+    // this function should cover most of the models
+    // if your model has specific features, you should probably duplicate this function
+    ggml_tensor * build_vit(
+                ggml_tensor * inp,
+                int64_t n_pos,
+                norm_type norm_t,
+                ffn_op_type ffn_t,
+                ggml_tensor * learned_pos_embd,
+                std::function<ggml_tensor *(ggml_tensor *, const clip_layer &)> add_pos
+            ) {
+        if (model.patch_bias) {
+            inp = ggml_add(ctx0, inp, model.patch_bias);
+            cb(inp, "patch_bias", -1);
+        }
 
-    if (use_window_attn) {
-        window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_positions / 4);
-        ggml_set_name(window_idx, "window_idx");
-        ggml_set_input(window_idx);
+        if (learned_pos_embd) {
+            inp = ggml_add(ctx0, inp, learned_pos_embd);
+            cb(inp, "pos_embed", -1);
+        }
 
-        // embeddings shape: [n_embd, patches_w * patches_h, batch_size]
-        GGML_ASSERT(batch_size == 1);
-        embeddings = ggml_reshape_2d(ctx0, embeddings, hparams.projection_dim, patches_w * patches_h / 4);
-        embeddings = ggml_get_rows(ctx0, embeddings, window_idx);
-        embeddings = ggml_reshape_3d(ctx0, embeddings, hparams.projection_dim, patches_w * patches_h / 4, batch_size);
-    }
+        ggml_tensor * inpL = inp;
 
-    // build the graph
-    ggml_build_forward_expand(gf, embeddings);
+        // pre-layernorm
+        if (model.pre_ln_w) {
+            inpL = build_norm(inpL, model.pre_ln_w, model.pre_ln_b, norm_t, eps, -1);
+            cb(inpL, "pre_ln", -1);
+        }
 
-    return gf;
-}
+        // loop over layers
+        for (int il = 0; il < n_layer; il++) {
+            auto & layer = model.layers[il];
+            ggml_tensor * cur = inpL; // inpL = residual, cur = hidden_states
 
-static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_image_f32_batch & imgs, struct clip_image_size load_image_size, bool is_inf = false) {
-    const auto & model = ctx->vision_model;
-    const auto & hparams = model.hparams;
+            // layernorm1
+            cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, norm_t, eps, il);
+            cb(cur, "layer_inp_normed", il);
 
-    const int image_size = hparams.image_size;
-    int image_size_width  = image_size;
-    int image_size_height = image_size;
+            // self-attention
+            {
+                ggml_tensor * Qcur = ggml_mul_mat(ctx0, layer.q_w, cur);
+                if (layer.q_b) {
+                    Qcur = ggml_add(ctx0, Qcur, layer.q_b);
+                }
 
-    if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV) {
-        LOG_DBG("%s: %d %d\n", __func__, load_image_size.width, load_image_size.height);
-        image_size_width  = load_image_size.width;
-        image_size_height = load_image_size.height;
-        if (is_inf) {
-            image_size_width  = imgs.entries[0]->nx;
-            image_size_height = imgs.entries[0]->ny;
-        }
-    }
+                ggml_tensor * Kcur = ggml_mul_mat(ctx0, layer.k_w, cur);
+                if (layer.k_b) {
+                    Kcur = ggml_add(ctx0, Kcur, layer.k_b);
+                }
 
-    else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) {
-        // use the image's native resolution when image is avaible
-        if (is_inf) {
-        // if (imgs->data->nx && imgs->data->ny) {
-            image_size_width  = imgs.entries[0]->nx;
-            image_size_height = imgs.entries[0]->ny;
-        }
-    }
+                ggml_tensor * Vcur = ggml_mul_mat(ctx0, layer.v_w, cur);
+                if (layer.v_b) {
+                    Vcur = ggml_add(ctx0, Vcur, layer.v_b);
+                }
 
-    const int patch_size           = hparams.patch_size;
-    const int num_patches          = ((image_size_width / patch_size) * (image_size_height / patch_size));
-    const int patches_w            = image_size_width / patch_size;
-    const int patches_h            = image_size_height / patch_size;
-    const int num_positions        = num_patches + (model.class_embedding ? 1 : 0);
-    const int num_position_ids     = ctx->proj_type == PROJECTOR_TYPE_QWEN2VL ? num_positions * 4 : num_positions;
-    const int n_embd               = hparams.n_embd;
-    const int n_head               = hparams.n_head;
-    const int d_head               = n_embd / n_head;
-    const float eps                = hparams.eps;
-    int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};
+                Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_pos);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, n_pos);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head, n_pos);
 
-    const int batch_size = imgs.entries.size();
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
 
-    if (ctx->has_llava_projector
-            || ctx->proj_type == PROJECTOR_TYPE_MINICPMV
-            || ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) {
-        GGML_ASSERT(batch_size == 1);
-    }
+                if (add_pos) {
+                    Qcur = add_pos(Qcur, layer);
+                    Kcur = add_pos(Kcur, layer);
+                    cb(Qcur, "Qcur_pos", il);
+                    cb(Kcur, "Kcur_pos", il);
+                }
 
-    struct ggml_init_params params = {
-        /*.mem_size   =*/ ctx->buf_compute_meta.size(),
-        /*.mem_buffer =*/ ctx->buf_compute_meta.data(),
-        /*.no_alloc   =*/ true,
-    };
+                cur = build_attn(layer.o_w, layer.o_b,
+                    Qcur, Kcur, Vcur, nullptr, kq_scale, il);
+                cb(cur, "attn_out", il);
+            }
 
-    ggml_context_ptr ctx0_ptr(ggml_init(params));
-    auto ctx0 = ctx0_ptr.get();
+            // re-add the layer input, e.g., residual
+            cur = ggml_add(ctx0, cur, inpL);
 
-    struct ggml_cgraph * gf = ggml_new_graph(ctx0);
+            inpL = cur; // inpL = residual, cur = hidden_states
 
-    struct ggml_tensor * inp_raw = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, image_size_width, image_size_height, 3, batch_size);
-    ggml_set_name(inp_raw, "inp_raw");
-    ggml_set_input(inp_raw);
+            cb(cur, "ffn_inp", il);
 
-    struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
+            // layernorm2
+            cur = build_norm(cur, layer.ln_2_w, layer.ln_2_b, norm_t, eps, il);
+            cb(cur, "ffn_inp_normed", il);
 
-    if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) {
-        GGML_ASSERT(image_size_width  % (patch_size * 2) == 0);
-        GGML_ASSERT(image_size_height % (patch_size * 2) == 0);
+            // ffn
+            cur = build_ffn(cur,
+                layer.ff_up_w, layer.ff_up_b,
+                layer.ff_gate_w, layer.ff_gate_b,
+                layer.ff_down_w, layer.ff_down_b,
+                ffn_t, il);
 
-        auto inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
-        inp = ggml_add(ctx0, inp, inp_1);
-        inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 2, 0, 3));  // [w, h, c, b] -> [c, w, h, b]
-        inp = ggml_reshape_4d(
-            ctx0, inp,
-            n_embd * 2, patches_w / 2, patches_h, batch_size);
-        inp = ggml_reshape_4d(
-            ctx0, inp,
-            n_embd * 2, patches_w / 2, 2, batch_size * (patches_h / 2));
-        inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 0, 2, 1, 3));
-        inp = ggml_reshape_3d(
-            ctx0, inp,
-            n_embd, patches_w * patches_h, batch_size);
-    }
-    else {
-        inp = ggml_reshape_3d(ctx0, inp, num_patches, n_embd, batch_size);
-        inp = ggml_cont(ctx0, ggml_permute(ctx0, inp, 1, 0, 2, 3));
-    }
+            cb(cur, "ffn_out", il);
 
-    if (model.patch_bias) {
-        // inp = ggml_add(ctx0, inp, ggml_repeat(ctx0, model.patch_bias, inp));
-        inp = ggml_add(ctx0, inp, model.patch_bias);
-    }
-    struct ggml_tensor * embeddings = inp;
-    struct ggml_tensor * pos_embed = nullptr;
-
-    // concat class_embeddings and patch_embeddings
-    if (model.class_embedding) {
-        embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd, num_positions, batch_size);
-        embeddings = ggml_scale(ctx0, embeddings, 0.0f); // set to all zeros
-        embeddings = ggml_acc(ctx0, embeddings, model.class_embedding,
-                embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], 0);
-        embeddings = ggml_acc(ctx0, embeddings, inp,
-                embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], model.class_embedding->nb[1]);
-    }
+            // residual 2
+            cur = ggml_add(ctx0, inpL, cur);
+            cb(cur, "layer_out", il);
 
-    struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_position_ids);
-    ggml_set_name(positions, "positions");
-    ggml_set_input(positions);
+            inpL = cur;
+        }
 
-    if (ctx->proj_type != PROJECTOR_TYPE_QWEN2VL) { // qwen2vl does NOT use learned position embeddings
-        embeddings =
-            ggml_add(ctx0, embeddings, ggml_get_rows(ctx0, model.position_embeddings, positions));
+        // post-layernorm
+        if (model.post_ln_w) {
+            inpL = build_norm(inpL, model.post_ln_w, model.post_ln_b, norm_t, eps, -1);
+        }
+        return inpL;
     }
 
-    if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV) {
-        int pos_w = image_size_width/patch_size;
-        int pos_h = image_size_height/patch_size;
-        int n_output_dim = clip_n_mmproj_embd(ctx);
-        pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_output_dim, pos_w * pos_h, 1);
-        ggml_set_name(pos_embed, "pos_embed");
-        ggml_set_input(pos_embed);
+    // build the input after conv2d (inp_raw --> patches)
+    // returns tensor with shape [n_embd, n_patches]
+    ggml_tensor * build_inp() {
+        ggml_tensor * inp_raw = build_inp_raw();
+        ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
+        inp = ggml_reshape_2d(ctx0, inp, n_patches, n_embd);
+        inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
+        return inp;
     }
 
-    // pre-layernorm
-    if (model.pre_ln_w) {
-        embeddings = ggml_norm(ctx0, embeddings, eps);
-        ggml_set_name(embeddings, "pre_ln");
-
-        embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.pre_ln_w), model.pre_ln_b);
+    ggml_tensor * build_inp_raw() {
+        ggml_tensor * inp_raw = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, img.nx, img.ny, 3);
+        ggml_set_name(inp_raw, "inp_raw");
+        ggml_set_input(inp_raw);
+        return inp_raw;
     }
 
-    std::vector<struct ggml_tensor *> embedding_stack;
-    const auto & vision_feature_layer = hparams.vision_feature_layer;
+    ggml_tensor * build_norm(
+            ggml_tensor * cur,
+            ggml_tensor * mw,
+            ggml_tensor * mb,
+            norm_type type,
+            float norm_eps,
+            int il) const {
 
-    // loop over layers
-    for (int il = 0; il < ctx->max_feature_layer; il++) {
-        struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states
+        cur = type == NORM_TYPE_RMS
+            ? ggml_rms_norm(ctx0, cur, norm_eps)
+            : ggml_norm(ctx0, cur, norm_eps);
 
-        // If this is an embedding feature layer, save the output.
-        // NOTE: 0 index here refers to the input to the encoder.
-        if (vision_feature_layer.find(il) != vision_feature_layer.end()) {
-            embedding_stack.push_back(embeddings);
+        if (mw || mb) {
+            cb(cur, "norm", il);
         }
 
-        //const size_t nb_q_w = model.layers[il].q_w->nb[0];
-
-        // layernorm1
-        {
-            cur = ggml_norm(ctx0, cur, eps);
-
-            cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_1_w),
-                           model.layers[il].ln_1_b);
-        }
-
-        // self-attention
-        {
-
-            struct ggml_tensor * Q =
-                ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].q_w, cur), model.layers[il].q_b);
-
-            Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_positions, batch_size);
-            if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) {
-                Q = ggml_rope_multi(
-                    ctx0, Q, positions, nullptr,
-                    d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
+        if (mw) {
+            cur = ggml_mul(ctx0, cur, mw);
+            if (mb) {
+                cb(cur, "norm_w", il);
             }
-            Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
-            Q = ggml_reshape_3d(ctx0, Q, d_head, num_positions, n_head * batch_size);
-
-            struct ggml_tensor * K =
-                ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].k_w, cur), model.layers[il].k_b);
-
-            K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size);
-            if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) {
-                K = ggml_rope_multi(
-                    ctx0, K, positions, nullptr,
-                    d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
-            }
-            K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
-            K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size);
-
-            struct ggml_tensor * V =
-                ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].v_w, cur), model.layers[il].v_b);
-
-            V = ggml_reshape_4d(ctx0, V, d_head, n_head, num_positions, batch_size);
-            V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
-            V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size);
-
-            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
-            KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f);
-            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
-            KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_positions, n_head, batch_size);
-            KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
-
-            cur = ggml_cont_3d(ctx0, KQV, n_embd, num_positions, batch_size);
         }
 
-        // attention output
-        cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].o_w, cur), model.layers[il].o_b);
+        if (mb) {
+            cur = ggml_add(ctx0, cur, mb);
+        }
 
-        // re-add the layer input, e.g., residual
-        cur = ggml_add(ctx0, cur, embeddings);
+        return cur;
+    }
 
-        embeddings = cur; // embeddings = residual, cur = hidden_states
+    ggml_tensor * build_ffn(
+            ggml_tensor * cur,
+            ggml_tensor * up,
+            ggml_tensor * up_b,
+            ggml_tensor * gate,
+            ggml_tensor * gate_b,
+            ggml_tensor * down,
+            ggml_tensor * down_b,
+            ffn_op_type type_op,
+            int il) const {
 
-        // layernorm2
-        {
-            cur = ggml_norm(ctx0, cur, eps);
+        ggml_tensor * tmp = up ? ggml_mul_mat(ctx0, up, cur) : cur;
+        cb(tmp, "ffn_up", il);
 
-            cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ln_2_w), model.layers[il].ln_2_b);
+        if (up_b) {
+            tmp = ggml_add(ctx0, tmp, up_b);
+            cb(tmp, "ffn_up_b", il);
         }
 
-        cur = ggml_mul_mat(ctx0, model.layers[il].ff_up_w, cur);
-        cur = ggml_add(ctx0, cur, model.layers[il].ff_up_b);
+        if (gate) {
+            cur = ggml_mul_mat(ctx0, gate, cur);
+            cb(cur, "ffn_gate", il);
 
-        if (ctx->use_gelu) {
-            cur = ggml_gelu_inplace(ctx0, cur);
-        } else if (ctx->use_silu) {
-            cur = ggml_silu_inplace(ctx0, cur);
+            if (gate_b) {
+                cur = ggml_add(ctx0, cur, gate_b);
+                cb(cur, "ffn_gate_b", il);
+            }
         } else {
-            cur = ggml_gelu_quick_inplace(ctx0, cur);
+            cur = tmp;
         }
 
-        cur = ggml_mul_mat(ctx0, model.layers[il].ff_down_w, cur);
-        cur = ggml_add(ctx0, cur, model.layers[il].ff_down_b);
+        switch (type_op) {
+            case FFN_SILU:
+                {
+                    cur = ggml_silu(ctx0, cur);
+                    cb(cur, "ffn_silu", il);
+                } break;
+            case FFN_GELU:
+                {
+                    cur = ggml_gelu(ctx0, cur);
+                    cb(cur, "ffn_gelu", il);
+                } break;
+            case FFN_GELU_QUICK:
+                {
+                    cur = ggml_gelu_quick(ctx0, cur);
+                    cb(cur, "ffn_relu", il);
+                } break;
+        }
 
-        // residual 2
-        cur = ggml_add(ctx0, embeddings, cur);
+        // we only support parallel ffn for now
+        if (gate) {
+            cur = ggml_mul(ctx0, cur, tmp);
+            cb(cur, "ffn_gate_par", il);
+        }
 
-        embeddings = cur;
-    }
+        if (down) {
+            cur = ggml_mul_mat(ctx0, down, cur);
+        }
 
-    // post-layernorm
-    if (model.post_ln_w) {
-        embeddings = ggml_norm(ctx0, embeddings, eps);
-        ggml_set_name(embeddings, "post_ln");
+        if (down_b) {
+            cb(cur, "ffn_down", il);
+        }
 
-        embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_ln_w), model.post_ln_b);
-    }
+        if (down_b) {
+            cur = ggml_add(ctx0, cur, down_b);
+        }
 
-    // final layer is a vision feature layer
-    if (vision_feature_layer.find(ctx->max_feature_layer) != vision_feature_layer.end()) {
-        embedding_stack.push_back(embeddings);
+        return cur;
     }
 
-    // If feature layers are explicitly set, stack them (if we have multiple)
-    if (!embedding_stack.empty()) {
-        embeddings = embedding_stack[0];
-        for (size_t i = 1; i < embedding_stack.size(); i++) {
-            embeddings = ggml_concat(ctx0, embeddings, embedding_stack[i], 0);
-        }
-    }
+    ggml_tensor * build_attn(
+            ggml_tensor * wo,
+            ggml_tensor * wo_b,
+            ggml_tensor * q_cur,
+            ggml_tensor * k_cur,
+            ggml_tensor * v_cur,
+            ggml_tensor * kq_mask,
+            float kq_scale,
+            int il) const {
+        // these nodes are added to the graph together so that they are not reordered
+        // by doing so, the number of splits in the graph is reduced
+        ggml_build_forward_expand(gf, q_cur);
+        ggml_build_forward_expand(gf, k_cur);
+        ggml_build_forward_expand(gf, v_cur);
 
-    // llava projector
-    if (ctx->has_llava_projector) {
-        embeddings = ggml_reshape_2d(ctx0, embeddings, embeddings->ne[0], embeddings->ne[1]);
+        ggml_tensor * q = ggml_permute(ctx0, q_cur, 0, 2, 1, 3);
+        //cb(q, "q", il);
 
-        struct ggml_tensor * patches = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_patches);
-        ggml_set_name(patches, "patches");
-        ggml_set_input(patches);
+        ggml_tensor * k = ggml_permute(ctx0, k_cur, 0, 2, 1, 3);
+        //cb(k, "k", il);
 
-        // shape [1, 576, 1024]
-        // ne is whcn, ne = [1024, 576, 1, 1]
-        embeddings = ggml_get_rows(ctx0, embeddings, patches);
+        ggml_tensor * v = ggml_permute(ctx0, v_cur, 1, 2, 0, 3);
+        v = ggml_cont(ctx0, v);
+        //cb(k, "v", il);
 
-        // print_tensor_info(embeddings, "embeddings");
+        ggml_tensor * cur;
 
-        // llava projector
-        if (ctx->proj_type == PROJECTOR_TYPE_MLP) {
-            embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
-            embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
+        // TODO @ngxson : support flash attention
+        {
+            const auto n_tokens = q->ne[1];
+            const auto n_head   = q->ne[2];
+            // const auto n_kv     = k->ne[1]; // for flash attention
 
-            embeddings = ggml_gelu(ctx0, embeddings);
-            if (model.mm_2_w) {
-                embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings);
-                embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
-            }
-        }
-        else if (ctx->proj_type == PROJECTOR_TYPE_MLP_NORM) {
-            embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
-            embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
-            // ggml_tensor_printf(embeddings, "mm_0_w",0,true,false);
-            // First LayerNorm
-            embeddings = ggml_norm(ctx0, embeddings, eps);
-            embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_1_w),
-                                model.mm_1_b);
-
-            // GELU activation
-            embeddings = ggml_gelu(ctx0, embeddings);
-
-            // Second linear layer
-            embeddings = ggml_mul_mat(ctx0, model.mm_3_w, embeddings);
-            embeddings = ggml_add(ctx0, embeddings, model.mm_3_b);
-
-            // Second LayerNorm
-            embeddings = ggml_norm(ctx0, embeddings, eps);
-            embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_4_w),
-                                model.mm_4_b);
-        }
-        else if (ctx->proj_type == PROJECTOR_TYPE_LDP) {
-            // MobileVLM projector
-            int n_patch = 24;
-            struct ggml_tensor * mlp_1 = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w, embeddings);
-            mlp_1 = ggml_add(ctx0, mlp_1, model.mm_model_mlp_1_b);
-            mlp_1 = ggml_gelu(ctx0, mlp_1);
-            struct ggml_tensor * mlp_3 = ggml_mul_mat(ctx0, model.mm_model_mlp_3_w, mlp_1);
-            mlp_3 = ggml_add(ctx0, mlp_3, model.mm_model_mlp_3_b);
-            // mlp_3 shape = [1, 576, 2048], ne = [2048, 576, 1, 1]
-
-            // block 1
-            struct ggml_tensor * block_1 = nullptr;
-            {
-                // transpose from [1, 576, 2048] --> [1, 2048, 576] --> [1, 2048, 24, 24]
-                mlp_3 = ggml_cont(ctx0, ggml_permute(ctx0, mlp_3, 1, 0, 2, 3));
-                mlp_3 = ggml_reshape_4d(ctx0, mlp_3, n_patch, n_patch, mlp_3->ne[1], mlp_3->ne[2]);
-                // stride = 1, padding = 1, bias is nullptr
-                block_1 = ggml_conv_2d_dw(ctx0, model.mm_model_block_1_block_0_0_w, mlp_3, 1, 1, 1, 1, 1, 1);
-
-                // layer norm
-                // // block_1 shape = [1, 2048, 24, 24], ne = [24, 24, 2048, 1]
-                block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 1, 2, 0, 3));
-                // block_1 shape = [1, 24, 24, 2048], ne = [2048, 24, 24, 1]
-                block_1 = ggml_norm(ctx0, block_1, eps);
-                block_1 = ggml_add(ctx0, ggml_mul(ctx0, block_1, model.mm_model_block_1_block_0_1_w), model.mm_model_block_1_block_0_1_b);
-                block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 2, 0, 1, 3));
-
-                // block_1 shape = [1, 2048, 24, 24], ne = [24, 24, 2048, 1]
-                // hardswish
-                struct ggml_tensor * block_1_hw = ggml_hardswish(ctx0, block_1);
-
-                block_1 = ggml_pool_2d(ctx0, block_1_hw, GGML_OP_POOL_AVG, block_1_hw->ne[0], block_1_hw->ne[1], block_1_hw->ne[0], block_1_hw->ne[1], 0, 0);
-                // block_1 shape = [1, 2048, 1, 1], ne = [1, 1, 2048, 1]
-                // pointwise conv
-                block_1 = ggml_reshape_2d(ctx0, block_1, block_1->ne[0]*block_1->ne[1]*block_1->ne[2], block_1->ne[3]);
-                block_1 = ggml_mul_mat(ctx0, model.mm_model_block_1_block_1_fc1_w, block_1);
-                block_1 = ggml_add(ctx0, block_1, model.mm_model_block_1_block_1_fc1_b);
-                block_1 = ggml_relu(ctx0, block_1);
-                block_1 = ggml_mul_mat(ctx0, model.mm_model_block_1_block_1_fc2_w, block_1);
-                block_1 = ggml_add(ctx0, block_1, model.mm_model_block_1_block_1_fc2_b);
-                block_1 = ggml_hardsigmoid(ctx0, block_1);
-                // block_1_hw shape = [1, 2048, 24, 24], ne = [24, 24, 2048, 1], block_1 shape = [1, 2048], ne = [2048, 1, 1, 1]
-                block_1 = ggml_reshape_4d(ctx0, block_1, 1, 1, block_1->ne[0], block_1->ne[1]);
-                block_1 = ggml_mul(ctx0, block_1_hw, block_1);
-
-                int w = block_1->ne[0], h = block_1->ne[1];
-                block_1 = ggml_reshape_3d(ctx0, block_1, w*h, block_1->ne[2], block_1->ne[3]);
-                block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 1, 0, 2, 3));
-
-                // block_1 shape = [1, 24*24, 2048], ne = [24*24, 2048, 1]
-                block_1 = ggml_mul_mat(ctx0, model.mm_model_block_1_block_2_0_w, block_1);
-                block_1 = ggml_reshape_4d(ctx0, block_1, block_1->ne[0], w, h, block_1->ne[3]);
-
-                // block_1 shape = [1, 24, 24, 2048], ne = [2048, 24, 24, 1]
-                block_1 = ggml_norm(ctx0, block_1, eps);
-                block_1 = ggml_add(ctx0, ggml_mul(ctx0, block_1, model.mm_model_block_1_block_2_1_w), model.mm_model_block_1_block_2_1_b);
-                block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 2, 0, 1, 3));
-                // block1 shape = [1, 2048, 24, 24], ne = [24, 24, 2048, 1]
-                // residual
-                block_1 = ggml_add(ctx0, mlp_3, block_1);
-            }
+            ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
+            // F32 may not needed for vision encoders?
+            // ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
 
-            // block_2
-            {
-                // stride = 2
-                block_1 = ggml_conv_2d_dw(ctx0, model.mm_model_block_2_block_0_0_w, block_1, 2, 2, 1, 1, 1, 1);
-
-                // block_1 shape = [1, 2048, 12, 12], ne = [12, 12, 2048, 1]
-                // layer norm
-                block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 1, 2, 0, 3));
-                // block_1 shape = [1, 12, 12, 2048], ne = [2048, 12, 12, 1]
-                block_1 = ggml_norm(ctx0, block_1, eps);
-                block_1 = ggml_add(ctx0, ggml_mul(ctx0, block_1, model.mm_model_block_2_block_0_1_w), model.mm_model_block_2_block_0_1_b);
-                block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 2, 0, 1, 3));
-                // block_1 shape = [1, 2048, 12, 12], ne = [12, 12, 2048, 1]
-                // hardswish
-                struct ggml_tensor * block_1_hw = ggml_hardswish(ctx0, block_1);
-
-                // not sure the parameters is right for globalAvgPooling
-                block_1 = ggml_pool_2d(ctx0, block_1_hw, GGML_OP_POOL_AVG, block_1_hw->ne[0], block_1_hw->ne[1], block_1_hw->ne[0], block_1_hw->ne[1], 0, 0);
-                // block_1 shape = [1, 2048, 1, 1], ne = [1, 1, 2048, 1]
-                // pointwise conv
-                block_1 = ggml_reshape_2d(ctx0, block_1, block_1->ne[0]*block_1->ne[1]*block_1->ne[2], block_1->ne[3]);
-                block_1 = ggml_mul_mat(ctx0, model.mm_model_block_2_block_1_fc1_w, block_1);
-                block_1 = ggml_add(ctx0, block_1, model.mm_model_block_2_block_1_fc1_b);
-                block_1 = ggml_relu(ctx0, block_1);
-                block_1 = ggml_mul_mat(ctx0, model.mm_model_block_2_block_1_fc2_w, block_1);
-                block_1 = ggml_add(ctx0, block_1, model.mm_model_block_2_block_1_fc2_b);
-                block_1 = ggml_hardsigmoid(ctx0, block_1);
-
-                // block_1_hw shape = [1, 2048, 12, 12], ne = [12, 12, 2048, 1], block_1 shape = [1, 2048, 1, 1], ne = [1, 1, 2048, 1]
-                block_1 = ggml_reshape_4d(ctx0, block_1, 1, 1, block_1->ne[0], block_1->ne[1]);
-                block_1 = ggml_mul(ctx0, block_1_hw, block_1);
-
-                int w = block_1->ne[0], h = block_1->ne[1];
-                block_1 = ggml_reshape_3d(ctx0, block_1, w*h, block_1->ne[2], block_1->ne[3]);
-                block_1 = ggml_cont(ctx0, ggml_permute(ctx0, block_1, 1, 0, 2, 3));
-                // block_1 shape = [1, 24*24, 2048], ne = [24*24, 2048, 1]
-                block_1 = ggml_mul_mat(ctx0, model.mm_model_block_2_block_2_0_w, block_1);
-                block_1 = ggml_reshape_4d(ctx0, block_1, block_1->ne[0], w, h, block_1->ne[3]);
-
-
-                // block_1 shape = [1, 12, 12, 2048], ne = [2048, 12, 12, 1]
-                block_1 = ggml_norm(ctx0, block_1, eps);
-                block_1 = ggml_add(ctx0, ggml_mul(ctx0, block_1, model.mm_model_block_2_block_2_1_w), model.mm_model_block_2_block_2_1_b);
-                block_1 = ggml_reshape_3d(ctx0, block_1, block_1->ne[0], block_1->ne[1] * block_1->ne[2], block_1->ne[3]);
-                // block_1 shape = [1, 144, 2048], ne = [2048, 144, 1]
-            }
-            embeddings = block_1;
-        }
-        else if (ctx->proj_type == PROJECTOR_TYPE_LDPV2)
-        {
-            int n_patch = 24;
-            struct ggml_tensor * mlp_0 = ggml_mul_mat(ctx0, model.mm_model_mlp_0_w, embeddings);
-            mlp_0 = ggml_add(ctx0, mlp_0, model.mm_model_mlp_0_b);
-            mlp_0 = ggml_gelu(ctx0, mlp_0);
-            struct ggml_tensor * mlp_2 = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, mlp_0);
-            mlp_2 = ggml_add(ctx0, mlp_2, model.mm_model_mlp_2_b);
-            // mlp_2 ne = [2048, 576, 1, 1]
-            // // AVG Pool Layer 2*2, strides = 2
-            mlp_2 = ggml_cont(ctx0, ggml_permute(ctx0, mlp_2, 1, 0, 2, 3));
-            // mlp_2 ne = [576, 2048, 1, 1]
-            mlp_2 = ggml_reshape_4d(ctx0, mlp_2, n_patch, n_patch, mlp_2->ne[1], mlp_2->ne[2]);
-            // mlp_2 ne [24, 24, 2048, 1]
-            mlp_2 = ggml_pool_2d(ctx0, mlp_2, GGML_OP_POOL_AVG, 2, 2, 2, 2, 0, 0);
-            // weight ne = [3, 3, 2048, 1]
-            struct ggml_tensor * peg_0 = ggml_conv_2d_dw(ctx0, model.mm_model_peg_0_w, mlp_2, 1, 1, 1, 1, 1, 1);
-            peg_0 = ggml_cont(ctx0, ggml_permute(ctx0, peg_0, 1, 2, 0, 3));
-            peg_0 = ggml_add(ctx0, peg_0, model.mm_model_peg_0_b);
-            mlp_2 = ggml_cont(ctx0, ggml_permute(ctx0, mlp_2, 1, 2, 0, 3));
-            peg_0 = ggml_add(ctx0, peg_0, mlp_2);
-            peg_0 = ggml_reshape_3d(ctx0, peg_0, peg_0->ne[0], peg_0->ne[1] * peg_0->ne[2], peg_0->ne[3]);
-            embeddings = peg_0;
-        }
-        else {
-            GGML_ABORT("fatal error");
-        }
-    }
-    // minicpmv projector
-    else if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV) {
-        struct ggml_tensor * q = model.mm_model_query;
-        { // layernorm
-            q = ggml_norm(ctx0, q, eps);
-            q = ggml_add(ctx0, ggml_mul(ctx0, q, model.mm_model_ln_q_w), model.mm_model_ln_q_b);
-        }
-        struct ggml_tensor * v = ggml_mul_mat(ctx0, model.mm_model_kv_proj, embeddings);
-        { // layernorm
-            v = ggml_norm(ctx0, v, eps);
-            v = ggml_add(ctx0, ggml_mul(ctx0, v, model.mm_model_ln_kv_w), model.mm_model_ln_kv_b);
-        }
-        struct ggml_tensor * k;
-        { // position
-            // q = ggml_add(ctx0, q, model.mm_model_pos_embed);
-            k = ggml_add(ctx0, v, pos_embed);
+            kq = ggml_soft_max_ext(ctx0, kq, kq_mask, kq_scale, 0.0f);
+
+            ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);
+            cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
+            cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
         }
 
-        { // attention
-            int n_embd = clip_n_mmproj_embd(ctx);
-            const int d_head = 128;
-            int n_head = n_embd/d_head;
-            int num_query = 96;
-            if (ctx->minicpmv_version == 2) {
-                num_query = 96;
-            }
-            else if (ctx->minicpmv_version == 3) {
-                num_query = 64;
-            }
-            else if (ctx->minicpmv_version == 4) {
-                num_query = 64;
-            }
+        cb(cur, "kqv_out", il);
 
-            struct ggml_tensor * Q = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_q_w, q), model.mm_model_attn_q_b);
-            struct ggml_tensor * K = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_k_w, k), model.mm_model_attn_k_b);
-            struct ggml_tensor * V = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_v_w, v), model.mm_model_attn_v_b);
-            // permute
-            Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_query, batch_size);
-            Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
-            Q = ggml_reshape_3d(ctx0, Q, d_head, num_query, n_head * batch_size);
-            K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size);
-            K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
-            K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size);
-            V = ggml_reshape_4d(ctx0, V, d_head, n_head, num_positions, batch_size);
-            V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
-            V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size);
-            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
-            KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f);
-            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
-            KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_query, n_head, batch_size);
-            KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
-            KQV = ggml_cont_3d(ctx0, KQV, n_embd, num_query, batch_size);
-
-            embeddings = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_o_w, KQV), model.mm_model_attn_o_b);
-        }
-        { // layernorm
-            embeddings = ggml_norm(ctx0, embeddings, eps);
-            embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_model_ln_post_w), model.mm_model_ln_post_b);
+        if (wo) {
+            cur = ggml_mul_mat(ctx0, wo, cur);
         }
-        embeddings = ggml_mul_mat(ctx0, model.mm_model_proj, embeddings);
-    }
 
-    // glm projector
-    else if (ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) {
-        size_t gridsz = (size_t)sqrt(embeddings->ne[1]);
-        embeddings = ggml_cont(ctx0, ggml_permute(ctx0,embeddings,1,0,2,3));
-        embeddings = ggml_reshape_3d(ctx0, embeddings, gridsz, gridsz, embeddings->ne[1]);
-        embeddings = ggml_conv_2d(ctx0, model.mm_model_adapter_conv_w, embeddings, 2, 2, 0, 0, 1, 1);
-        embeddings = ggml_reshape_3d(ctx0, embeddings,embeddings->ne[0]*embeddings->ne[1] , embeddings->ne[2], batch_size);
-        embeddings = ggml_cont(ctx0, ggml_permute(ctx0,embeddings, 1, 0, 2, 3));
-        embeddings = ggml_add(ctx0, embeddings, model.mm_model_adapter_conv_b);
-        // GLU
-        {
-            embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_0_w, embeddings);
-            embeddings = ggml_norm(ctx0, embeddings, eps);
-            embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_model_ln_q_w), model.mm_model_ln_q_b);
-            embeddings = ggml_gelu_inplace(ctx0, embeddings);
-            struct ggml_tensor * x = embeddings;
-            embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, embeddings);
-            x = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w,x);
-            embeddings = ggml_silu_inplace(ctx0, embeddings);
-            embeddings = ggml_mul(ctx0, embeddings,x);
-            embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_3_w, embeddings);
-        }
-        // arrangement of BOI/EOI token embeddings
-        // note: these embeddings are not present in text model, hence we cannot process them as text tokens
-        // see: https://huggingface.co/THUDM/glm-edge-v-2b/blob/main/siglip.py#L53
-        {
-            embeddings = ggml_concat(ctx0, model.mm_glm_tok_boi, embeddings, 1); // BOI
-            embeddings = ggml_concat(ctx0, embeddings, model.mm_glm_tok_eoi, 1); // EOI
+        if (wo_b) {
+            cur = ggml_add(ctx0, cur, wo_b);
         }
-    }
 
-    else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) {
-        embeddings = ggml_reshape_3d(ctx0, embeddings, n_embd * 4, num_positions / 4, batch_size);
+        return cur;
+    }
 
-        embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
-        embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
+    // implementation of the 2D RoPE without adding a new op in ggml
+    // this is not efficient (use double the memory), but works on all backends
+    // TODO: there was a more efficient which relies on ggml_view and ggml_rope_ext_inplace, but the rope inplace does not work well with non-contiguous tensors ; we should fix that and revert back to the original implementation in https://github.com/ggml-org/llama.cpp/pull/13065
+    static ggml_tensor * build_rope_2d(
+        ggml_context * ctx0,
+        ggml_tensor * cur,
+        ggml_tensor * pos_h,
+        ggml_tensor * pos_w,
+        const float freq_base
+    ) {
+        const int64_t n_dim  = cur->ne[0];
+        const int64_t n_head = cur->ne[1];
+        const int64_t n_pos  = cur->ne[2];
 
-        // GELU activation
-        embeddings = ggml_gelu(ctx0, embeddings);
+        // for example, if we have cur tensor of shape (n_dim=8, n_head, n_pos)
+        // we will have a list of 4 inv_freq: 1e-0, 1e-1, 1e-2, 1e-3
+        // first half of cur will use 1e-0, 1e-2 (even)
+        // second half of cur will use 1e-1, 1e-3 (odd)
+        // the trick here is to rotate just half of n_dim, so inv_freq will automatically be even
+        //  ^ don't ask me why, it's math! -2(2i) / n_dim == -2i / (n_dim/2)
+        // then for the second half, we use freq_scale to shift the inv_freq
+        //  ^ why? replace (2i) with (2i+1) in the above equation
+        const float freq_scale_odd = std::pow(freq_base, (float)-2/n_dim);
 
-        // Second linear layer
-        embeddings = ggml_mul_mat(ctx0, model.mm_1_w, embeddings);
-        embeddings = ggml_add(ctx0, embeddings, model.mm_1_b);
+        // first half
+        ggml_tensor * first;
+        {
+            first = ggml_view_3d(ctx0, cur,
+                n_dim/2, n_head, n_pos,
+                ggml_row_size(cur->type, n_dim),
+                ggml_row_size(cur->type, n_dim*n_head),
+                0);
+            first = ggml_rope_ext(
+                ctx0,
+                first,
+                pos_h,      // positions
+                nullptr,    // freq factors
+                n_dim/2,    // n_dims
+                0, 0, freq_base,
+                1.0f, 0.0f, 1.0f, 0.0f, 0.0f
+            );
+        }
+
+        // second half
+        ggml_tensor * second;
+        {
+            second = ggml_view_3d(ctx0, cur,
+                n_dim/2, n_head, n_pos,
+                ggml_row_size(cur->type, n_dim),
+                ggml_row_size(cur->type, n_dim*n_head),
+                n_dim/2 * ggml_element_size(cur));
+            second = ggml_cont(ctx0, second); // copy, because ggml_rope don't play well with non-contiguous tensors
+            second = ggml_rope_ext(
+                ctx0,
+                second,
+                pos_w,      // positions
+                nullptr,    // freq factors
+                n_dim/2,    // n_dims
+                0, 0, freq_base,
+                freq_scale_odd,
+                0.0f, 1.0f, 0.0f, 0.0f
+            );
+        }
+
+        cur = ggml_concat(ctx0, first, second, 0);
+        return cur;
     }
 
-    // build the graph
-    ggml_build_forward_expand(gf, embeddings);
+};
 
-    return gf;
-}
+static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch & imgs) {
+    GGML_ASSERT(imgs.entries.size() == 1 && "n_batch > 1 is not supported");
+    clip_graph graph(ctx, *imgs.entries[0]);
 
-static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32_batch & imgs, struct clip_image_size load_image_size, bool is_inf = false) {
     ggml_cgraph * res;
+
     switch (ctx->proj_type) {
         case PROJECTOR_TYPE_GEMMA3:
         case PROJECTOR_TYPE_IDEFICS3:
             {
-                GGML_ASSERT(imgs.entries.size() == 1);
-                res = clip_image_build_graph_siglip(ctx, *imgs.entries[0]);
+                res = graph.build_siglip();
             } break;
         case PROJECTOR_TYPE_PIXTRAL:
             {
-                GGML_ASSERT(imgs.entries.size() == 1);
-                res = clip_image_build_graph_pixtral(ctx, *imgs.entries[0]);
+                res = graph.build_pixtral();
             } break;
+        case PROJECTOR_TYPE_QWEN2VL:
         case PROJECTOR_TYPE_QWEN25VL:
             {
-                res = clip_image_build_graph_qwen25vl(ctx, imgs);
+                res = graph.build_qwen2vl();
+            } break;
+        case PROJECTOR_TYPE_MINICPMV:
+            {
+                res = graph.build_minicpmv();
             } break;
         default:
             {
-                // TODO: we should have one build_* function per model
-                res = clip_image_build_graph_legacy(ctx, imgs, load_image_size, is_inf);
+                res = graph.build_llava();
             } break;
     }
     return res;
@@ -1656,7 +1680,7 @@ struct clip_model_loader {
                 const char * name = gguf_get_tensor_name(ctx_gguf.get(), i);
                 const size_t offset = gguf_get_tensor_offset(ctx_gguf.get(), i);
                 enum ggml_type type = gguf_get_tensor_type(ctx_gguf.get(), i);
-                struct ggml_tensor * cur = ggml_get_tensor(meta, name);
+                ggml_tensor * cur = ggml_get_tensor(meta, name);
                 size_t tensor_size = ggml_nbytes(cur);
                 model_size += tensor_size;
                 LOG_DBG("%s: tensor[%d]: n_dims = %d, name = %s, tensor_size=%zu, offset=%zu, shape:[%" PRIu64 ", %" PRIu64 ", %" PRIu64 ", %" PRIu64 "], type = %s\n",
@@ -1667,6 +1691,7 @@ struct clip_model_loader {
 
     void load_hparams() {
         auto & hparams = ctx_clip.vision_model.hparams;
+        std::string log_ffn_op; // for logging
 
         // projector type
         std::string proj_type;
@@ -1682,10 +1707,7 @@ struct clip_model_loader {
 
         // other hparams
         {
-            get_i32(KEY_MINICPMV_VERSION, ctx_clip.minicpmv_version, false);
-
-            get_bool(KEY_USE_GELU, ctx_clip.use_gelu, false);
-            get_bool(KEY_USE_SILU, ctx_clip.use_silu, false);
+            get_i32(KEY_MINICPMV_VERSION, ctx_clip.minicpmv_version, false); // legacy
 
             get_u32(KEY_N_EMBD,         hparams.n_embd);
             get_u32(KEY_N_HEAD,         hparams.n_head);
@@ -1703,6 +1725,26 @@ struct clip_model_loader {
                                         || ctx_clip.proj_type == PROJECTOR_TYPE_LDP
                                         || ctx_clip.proj_type == PROJECTOR_TYPE_LDPV2;
 
+            {
+                bool use_gelu = false;
+                bool use_silu = false;
+                get_bool(KEY_USE_GELU, use_gelu, false);
+                get_bool(KEY_USE_SILU, use_silu, false);
+                if (use_gelu && use_silu) {
+                    throw std::runtime_error(string_format("%s: both use_gelu and use_silu are set to true\n", __func__));
+                }
+                if (use_gelu) {
+                    hparams.ffn_op = FFN_GELU;
+                    log_ffn_op = "gelu";
+                } else if (use_silu) {
+                    hparams.ffn_op = FFN_SILU;
+                    log_ffn_op = "silu";
+                } else {
+                    hparams.ffn_op = FFN_GELU_QUICK;
+                    log_ffn_op = "gelu_quick";
+                }
+            }
+
             {
                 std::string mm_patch_merge_type;
                 get_string(KEY_MM_PATCH_MERGE_TYPE, mm_patch_merge_type, false);
@@ -1736,30 +1778,6 @@ struct clip_model_loader {
                 hparams.vision_feature_layer.insert(layer);
             }
 
-            // Calculate the deepest feature layer based on hparams and projector type
-            // NOTE: This is only used by build_graph_legacy()
-            {
-                // Get the index of the second to last layer; this is the default for models that have a llava projector
-                int n_layer = hparams.n_layer - 1;
-                int deepest_feature_layer = -1;
-
-                if (ctx_clip.proj_type == PROJECTOR_TYPE_MINICPMV
-                        || ctx_clip.proj_type == PROJECTOR_TYPE_GLM_EDGE
-                        || ctx_clip.proj_type == PROJECTOR_TYPE_QWEN2VL
-                        || ctx_clip.proj_type == PROJECTOR_TYPE_QWEN25VL) {
-                    n_layer += 1;
-                }
-
-                // If we set explicit vision feature layers, only go up to the deepest one
-                // NOTE: only used by granite-vision models for now
-                for (const auto & feature_layer : hparams.vision_feature_layer) {
-                    if (feature_layer > deepest_feature_layer) {
-                        deepest_feature_layer = feature_layer;
-                    }
-                }
-                ctx_clip.max_feature_layer = deepest_feature_layer < 0 ? n_layer : deepest_feature_layer;
-            }
-
             // model-specific params
             switch (ctx_clip.proj_type) {
                 case PROJECTOR_TYPE_MINICPMV:
@@ -1777,6 +1795,14 @@ struct clip_model_loader {
                         hparams.rope_theta = 10000.0f;
                         get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.spatial_merge_size, false);
                     } break;
+                case PROJECTOR_TYPE_GEMMA3:
+                    {
+                        // default value (used by all model sizes in gemma 3 family)
+                        // number of patches for each **side** is reduced by a factor of 4
+                        hparams.proj_scale_factor = 4;
+                        // test model (tinygemma3) has a different value, we optionally read it
+                        get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor, false);
+                    } break;
                 case PROJECTOR_TYPE_QWEN25VL:
                     {
                         get_u32(KEY_WIN_ATTN_PATTERN, hparams.n_wa_pattern);
@@ -1786,12 +1812,19 @@ struct clip_model_loader {
             }
 
             LOG_INF("%s: projector:          %s\n", __func__, proj_type.c_str());
+            LOG_INF("%s: n_embd:             %d\n", __func__, hparams.n_embd);
+            LOG_INF("%s: n_head:             %d\n", __func__, hparams.n_head);
+            LOG_INF("%s: n_ff:               %d\n", __func__, hparams.n_ff);
+            LOG_INF("%s: n_layer:            %d\n", __func__, hparams.n_layer);
+            LOG_INF("%s: projection_dim:     %d\n", __func__, hparams.projection_dim);
+            LOG_INF("%s: image_size:         %d\n", __func__, hparams.image_size);
+            LOG_INF("%s: patch_size:         %d\n", __func__, hparams.patch_size);
+            LOG_INF("\n");
             LOG_INF("%s: has_llava_proj:     %d\n", __func__, ctx_clip.has_llava_projector);
             LOG_INF("%s: minicpmv_version:   %d\n", __func__, ctx_clip.minicpmv_version);
             LOG_INF("%s: proj_scale_factor:  %d\n", __func__, hparams.proj_scale_factor);
             LOG_INF("%s: n_wa_pattern:       %d\n", __func__, hparams.n_wa_pattern);
-            LOG_INF("%s: use_silu:           %d\n", __func__, ctx_clip.use_silu);
-            LOG_INF("%s: use_gelu:           %d\n", __func__, ctx_clip.use_gelu);
+            LOG_INF("%s: ffn_op:             %s\n", __func__, log_ffn_op.c_str());
             LOG_INF("%s: model size:         %.2f MiB\n", __func__, model_size / 1024.0 / 1024.0);
             LOG_INF("%s: metadata size:      %.2f MiB\n", __func__, ggml_get_mem_size(ctx_meta.get()) / 1024.0 / 1024.0);
         }
@@ -1821,14 +1854,14 @@ struct clip_model_loader {
 
         // helper function
         auto get_tensor = [&](const std::string & name, bool required = true) {
-            struct ggml_tensor * cur = ggml_get_tensor(ctx_meta.get(), name.c_str());
+            ggml_tensor * cur = ggml_get_tensor(ctx_meta.get(), name.c_str());
             if (!cur && required) {
                 throw std::runtime_error(string_format("%s: unable to find tensor %s\n", __func__, name.c_str()));
             }
             if (cur) {
                 tensors_to_load.push_back(cur);
                 // add tensors to context
-                struct ggml_tensor * data_tensor = ggml_dup_tensor(ctx_clip.ctx_data.get(), cur);
+                ggml_tensor * data_tensor = ggml_dup_tensor(ctx_clip.ctx_data.get(), cur);
                 ggml_set_name(data_tensor, cur->name);
                 cur = data_tensor;
             }
@@ -2034,7 +2067,7 @@ struct clip_model_loader {
             ctx_clip.buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(ctx_clip.ctx_data.get(), buft));
             ggml_backend_buffer_set_usage(ctx_clip.buf.get(), GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
             for (auto & t : tensors_to_load) {
-                struct ggml_tensor * cur = ggml_get_tensor(ctx_clip.ctx_data.get(), t->name);
+                ggml_tensor * cur = ggml_get_tensor(ctx_clip.ctx_data.get(), t->name);
                 const size_t offset = tensor_offset[t->name];
                 fin.seekg(offset, std::ios::beg);
                 if (!fin) {
@@ -2063,15 +2096,12 @@ struct clip_model_loader {
         // create a fake batch
         clip_image_f32_batch batch;
         clip_image_f32_ptr img(clip_image_f32_init());
-        clip_image_size image_size;
-        image_size.width  = ctx_clip.vision_model.hparams.image_size;
-        image_size.height = ctx_clip.vision_model.hparams.image_size;
-        img->nx = image_size.width;
-        img->ny = image_size.height;
-        img->buf.resize(image_size.width * image_size.height * 3);
+        img->nx = ctx_clip.vision_model.hparams.image_size;
+        img->ny = ctx_clip.vision_model.hparams.image_size;
+        img->buf.resize(img->nx * img->ny * 3);
         batch.entries.push_back(std::move(img));
 
-        ggml_cgraph * gf = clip_image_build_graph(&ctx_clip, batch, image_size, false);
+        ggml_cgraph * gf = clip_image_build_graph(&ctx_clip, batch);
         ggml_backend_sched_reserve(ctx_clip.sched.get(), gf);
         for (size_t i = 0; i < ctx_clip.backend_ptrs.size(); ++i) {
             ggml_backend_t backend = ctx_clip.backend_ptrs[i];
@@ -2976,11 +3006,13 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
         int y_patch = img->ny / patch_size + (int)(img->ny % patch_size > 0);
         n_patches = x_patch * y_patch;
     } else if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
-        n_patches = 256;
+        int n_per_side = params.image_size / params.patch_size;
+        int n_per_side_2d_pool = n_per_side / params.proj_scale_factor;
+        n_patches = n_per_side_2d_pool * n_per_side_2d_pool;
     } else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
-        n_patches /= ctx->vision_model.hparams.proj_scale_factor;
+        n_patches /= params.proj_scale_factor;
     } else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) {
-        int n_merge = ctx->vision_model.hparams.spatial_merge_size;
+        int n_merge = params.spatial_merge_size;
         int n_patches_x = img->nx / params.patch_size / (n_merge > 0 ? n_merge : 1);
         int n_patches_y = img->ny / params.patch_size / (n_merge > 0 ? n_merge : 1);
         n_patches = n_patches_y*n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row
@@ -3088,15 +3120,15 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
     const clip_image_f32_batch & imgs = *imgs_c_ptr;
     int batch_size = imgs.entries.size();
 
-    if (ctx->has_llava_projector
-            || ctx->proj_type == PROJECTOR_TYPE_MINICPMV
-            || ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) {
-        GGML_ASSERT(batch_size == 1);
+    // TODO @ngxson : implement batch size > 1 as a loop
+    //                we don't need true batching support because the cgraph will gonna be big anyway
+    if (batch_size != 1) {
+        return false; // only support batch size of 1
     }
 
     // build the inference graph
     ggml_backend_sched_reset(ctx->sched.get());
-    ggml_cgraph * gf = clip_image_build_graph(ctx, imgs, ctx->load_image_size, true);
+    ggml_cgraph * gf = clip_image_build_graph(ctx, imgs);
     ggml_backend_sched_alloc_graph(ctx->sched.get(), gf);
 
     // set inputs
@@ -3108,14 +3140,14 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
 
     const int patch_size    = hparams.patch_size;
     const int num_patches   = ((image_size_width / patch_size) * (image_size_height / patch_size));
-    const int num_positions = num_patches + (model.class_embedding ? 1 : 0);
+    const int n_pos = num_patches + (model.class_embedding ? 1 : 0);
     const int pos_w = ctx->load_image_size.width  / patch_size;
     const int pos_h = ctx->load_image_size.height / patch_size;
 
     const bool use_window_attn = hparams.n_wa_pattern > 0; // for qwen2.5vl
 
     auto get_inp_tensor = [&gf](const char * name) {
-        struct ggml_tensor * inp = ggml_graph_get_tensor(gf, name);
+        ggml_tensor * inp = ggml_graph_get_tensor(gf, name);
         if (inp == nullptr) {
             GGML_ABORT("Failed to get tensor %s", name);
         }
@@ -3224,7 +3256,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
                 const int merge_ratio = 2;
                 const int pw = image_size_width  / patch_size;
                 const int ph = image_size_height / patch_size;
-                std::vector<int> positions(num_positions * 4);
+                std::vector<int> positions(n_pos * 4);
                 int ptr = 0;
                 for (int y = 0; y < ph; y += merge_ratio) {
                     for (int x = 0; x < pw; x += merge_ratio) {
@@ -3301,7 +3333,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
                 }
 
                 const int mpow = merge_ratio * merge_ratio;
-                std::vector<int> positions(num_positions * 4);
+                std::vector<int> positions(n_pos * 4);
 
                 int ptr = 0;
                 for (int y = 0; y < iph; y += merge_ratio) {
@@ -3327,14 +3359,14 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
             {
                 // set the 2D positions
                 int n_patches_per_col = image_size_width / patch_size;
-                std::vector<int> pos_data(num_positions);
+                std::vector<int> pos_data(n_pos);
                 // dimension H
-                for (int i = 0; i < num_positions; i++) {
+                for (int i = 0; i < n_pos; i++) {
                     pos_data[i] = i / n_patches_per_col;
                 }
                 set_input_i32("pos_h", pos_data);
                 // dimension W
-                for (int i = 0; i < num_positions; i++) {
+                for (int i = 0; i < n_pos; i++) {
                     pos_data[i] = i % n_patches_per_col;
                 }
                 set_input_i32("pos_w", pos_data);
@@ -3342,8 +3374,8 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
         case PROJECTOR_TYPE_GLM_EDGE:
         {
             // llava and other models
-            std::vector<int32_t> positions(num_positions);
-            for (int i = 0; i < num_positions; i++) {
+            std::vector<int32_t> positions(n_pos);
+            for (int i = 0; i < n_pos; i++) {
                 positions[i] = i;
             }
             set_input_i32("positions", positions);
@@ -3354,8 +3386,8 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
         case PROJECTOR_TYPE_LDPV2:
             {
                 // llava and other models
-                std::vector<int32_t> positions(num_positions);
-                for (int i = 0; i < num_positions; i++) {
+                std::vector<int32_t> positions(n_pos);
+                for (int i = 0; i < n_pos; i++) {
                     positions[i] = i;
                 }
                 set_input_i32("positions", positions);
@@ -3396,7 +3428,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
     }
 
     // the last node is the embedding tensor
-    struct ggml_tensor * embeddings = ggml_graph_node(gf, -1);
+    ggml_tensor * embeddings = ggml_graph_node(gf, -1);
 
     // copy the embeddings to the location passed by the user
     ggml_backend_tensor_get(embeddings, vec, 0, ggml_nbytes(embeddings));
@@ -3427,7 +3459,7 @@ bool clip_model_quantize(const char * fname_inp, const char * fname_out, const i
 
     for (int i = 0; i < n_tensors; ++i) {
         const char * name = gguf_get_tensor_name(ctx_src, i);
-        struct ggml_tensor * cur = ggml_get_tensor(ctx_data, name);
+        ggml_tensor * cur = ggml_get_tensor(ctx_data, name);
         gguf_add_tensor(ctx_out, cur);
     }
 
@@ -3448,7 +3480,7 @@ bool clip_model_quantize(const char * fname_inp, const char * fname_out, const i
 
     for (int i = 0; i < n_tensors; ++i) {
         const std::string name = gguf_get_tensor_name(ctx_src, i);
-        struct ggml_tensor * cur = ggml_get_tensor(ctx_data, name.c_str());
+        ggml_tensor * cur = ggml_get_tensor(ctx_data, name.c_str());
 
         enum ggml_type new_type;
         void * new_data;