]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
clip : refactor set input for cgraph + fix qwen2.5vl input (#13136)
authorXuan-Son Nguyen <redacted>
Mon, 28 Apr 2025 10:18:59 +0000 (12:18 +0200)
committerGitHub <redacted>
Mon, 28 Apr 2025 10:18:59 +0000 (12:18 +0200)
* clip : refactor set input for cgraph

* more strict assert

* minicpmv : use clip_n_mmproj_embd instead of copying the same code everywhere

* split qwen2 and qwen2.5 code blocks

* minor style fix

examples/llava/clip.cpp

index 3cd27d5b17a083f436af4a93d6fc20be1d69f556..8c5d56cc17ae94f5eb2fbb33987047a9acc9649a 100644 (file)
@@ -170,8 +170,8 @@ struct clip_hparams {
     std::vector<int32_t> image_grid_pinpoints;
     int32_t image_crop_resolution;
     std::unordered_set<int32_t> vision_feature_layer;
-    int32_t attn_window_size;
-    int32_t n_wa_pattern;
+    int32_t attn_window_size = 0;
+    int32_t n_wa_pattern = 0;
 };
 
 struct clip_layer {
@@ -325,7 +325,6 @@ struct clip_ctx {
     float image_std[3];
     bool use_gelu = false;
     bool use_silu = false;
-    int32_t ftype = 1;
 
     gguf_context_ptr ctx_gguf;
     ggml_context_ptr ctx_data;
@@ -776,7 +775,6 @@ static ggml_cgraph * clip_image_build_graph_qwen25vl(clip_ctx * ctx, const clip_
     const int image_size_width  = imgs.entries[0]->nx;
     const int image_size_height = imgs.entries[0]->ny;
 
-    const bool use_mrope       = ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL;
     const bool use_window_attn = hparams.n_wa_pattern > 0;
 
     const int n_wa_pattern         = hparams.n_wa_pattern;
@@ -785,10 +783,11 @@ static ggml_cgraph * clip_image_build_graph_qwen25vl(clip_ctx * ctx, const clip_
     const int patches_w            = image_size_width / patch_size;
     const int patches_h            = image_size_height / patch_size;
     const int num_positions        = num_patches + (model.class_embedding ? 1 : 0);
-    const int num_position_ids     = use_mrope ? num_positions * 4 : num_positions;
+    const int num_position_ids     = num_positions * 4; // m-rope requires 4 dim per position
     const int hidden_size          = hparams.hidden_size;
     const int n_head               = hparams.n_head;
     const int d_head               = hidden_size / n_head;
+    const int n_layer              = hparams.n_layer;
     const float eps                = hparams.eps;
 
     int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4};
@@ -870,7 +869,7 @@ static ggml_cgraph * clip_image_build_graph_qwen25vl(clip_ctx * ctx, const clip_
     }
 
     // loop over layers
-    for (int il = 0; il < ctx->max_feature_layer; il++) {
+    for (int il = 0; il < n_layer; il++) {
         struct ggml_tensor * cur = embeddings; // embeddings = residual, cur = hidden_states
 
         // rmsnorm1
@@ -1115,15 +1114,8 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
     if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV) {
         int pos_w = image_size_width/patch_size;
         int pos_h = image_size_height/patch_size;
-        if (ctx->minicpmv_version == 2) {
-            pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 4096, pos_w * pos_h, 1);
-        }
-        else if (ctx->minicpmv_version == 3) {
-            pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 3584, pos_w * pos_h, 1);
-        }
-        else if (ctx->minicpmv_version == 4) {
-            pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 3584, pos_w * pos_h, 1);
-        }
+        int n_output_dim = clip_n_mmproj_embd(ctx);
+        pos_embed = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_output_dim, pos_w * pos_h, 1);
         ggml_set_name(pos_embed, "pos_embed");
         ggml_set_input(pos_embed);
     }
@@ -1461,23 +1453,17 @@ static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_im
         }
 
         { // attention
-            int hidden_size = 4096;
+            int hidden_size = clip_n_mmproj_embd(ctx);
             const int d_head = 128;
             int n_head = hidden_size/d_head;
             int num_query = 96;
             if (ctx->minicpmv_version == 2) {
-                hidden_size = 4096;
-                n_head = hidden_size/d_head;
                 num_query = 96;
             }
             else if (ctx->minicpmv_version == 3) {
-                hidden_size = 3584;
-                n_head = hidden_size/d_head;
                 num_query = 64;
             }
             else if (ctx->minicpmv_version == 4) {
-                hidden_size = 3584;
-                n_head = hidden_size/d_head;
                 num_query = 64;
             }
 
@@ -1760,6 +1746,8 @@ struct clip_model_loader {
             LOG_INF("%s: projector:          %s\n", __func__, proj_type.c_str());
             LOG_INF("%s: has_llava_proj:     %d\n", __func__, ctx_clip.has_llava_projector);
             LOG_INF("%s: minicpmv_version:   %d\n", __func__, ctx_clip.minicpmv_version);
+            LOG_INF("%s: proj_scale_factor:  %d\n", __func__, hparams.proj_scale_factor);
+            LOG_INF("%s: n_wa_pattern:       %d\n", __func__, hparams.n_wa_pattern);
             LOG_INF("%s: model size:         %.2f MiB\n", __func__, model_size / 1024.0 / 1024.0);
             LOG_INF("%s: metadata size:      %.2f MiB\n", __func__, ggml_get_mem_size(ctx_meta.get()) / 1024.0 / 1024.0);
         }
@@ -3038,15 +3026,43 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
     const int patch_size    = hparams.patch_size;
     const int num_patches   = ((image_size_width / patch_size) * (image_size_height / patch_size));
     const int num_positions = num_patches + (model.class_embedding ? 1 : 0);
-    const int pos_w = ctx->load_image_size.width / patch_size;
+    const int pos_w = ctx->load_image_size.width  / patch_size;
     const int pos_h = ctx->load_image_size.height / patch_size;
 
     const bool use_window_attn = hparams.n_wa_pattern > 0; // for qwen2.5vl
 
+    auto get_inp_tensor = [&gf](const char * name) {
+        struct ggml_tensor * inp = ggml_graph_get_tensor(gf, name);
+        if (inp == nullptr) {
+            GGML_ABORT("Failed to get tensor %s", name);
+        }
+        if (!(inp->flags & GGML_TENSOR_FLAG_INPUT)) {
+            GGML_ABORT("Tensor %s is not an input tensor", name);
+        }
+        return inp;
+    };
+
+    auto set_input_f32 = [&get_inp_tensor](const char * name, std::vector<float> & values) {
+        ggml_tensor * cur = get_inp_tensor(name);
+        GGML_ASSERT(cur->type == GGML_TYPE_F32);
+        GGML_ASSERT(ggml_nelements(cur) == (int64_t)values.size());
+        ggml_backend_tensor_set(cur, values.data(), 0, ggml_nbytes(cur));
+    };
+
+    auto set_input_i32 = [&get_inp_tensor](const char * name, std::vector<int32_t> & values) {
+        ggml_tensor * cur = get_inp_tensor(name);
+        GGML_ASSERT(cur->type == GGML_TYPE_I32);
+        GGML_ASSERT(ggml_nelements(cur) == (int64_t)values.size());
+        ggml_backend_tensor_set(cur, values.data(), 0, ggml_nbytes(cur));
+    };
+
+    // set input pixel values
     {
-        struct ggml_tensor * inp_raw = ggml_graph_get_tensor(gf, "inp_raw");
-        std::vector<float> inp_data(ggml_nelements(inp_raw));
-        float * data = inp_data.data();
+        size_t nelem = 0;
+        for (const auto & img : imgs.entries) {
+            nelem += img->nx * img->ny * 3;
+        }
+        std::vector<float> inp_raw(nelem);
 
         // layout of data (note: the channel dim is unrolled to better visualize the layout):
         //
@@ -3065,7 +3081,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
             const int n = nx * ny;
 
             for (int b = 0; b < batch_size; b++) {
-                float * batch_entry = data + b * (3*n);
+                float * batch_entry = inp_raw.data() + b * (3*n);
                 for (int y = 0; y < ny; y++) {
                     for (int x = 0; x < nx; x++) {
                         size_t base_src = 3*(y * nx + x); // idx of the first channel
@@ -3077,266 +3093,207 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
                 }
             }
         }
-        ggml_backend_tensor_set(inp_raw, data, 0, ggml_nbytes(inp_raw));
+        set_input_f32("inp_raw", inp_raw);
     }
 
-    if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV) {
-        {
-            // inspired from siglip:
-            //    -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit
-            //    -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit/blob/d66538faeba44480d0bfaa42145eef26f9423199/modeling_siglip.py#L316
-            struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
-            std::vector<int> pos_data(ggml_nelements(positions));
-            int * data = pos_data.data();
-            int bucket_coords_h[1024];
-            int bucket_coords_w[1024];
-            for (int i = 0; i < pos_h; i++){
-                bucket_coords_h[i] = std::floor(70.0*i/pos_h);
-            }
-            for (int i = 0; i < pos_w; i++){
-                bucket_coords_w[i] = std::floor(70.0*i/pos_w);
-            }
-            for (int i = 0, id = 0; i < pos_h; i++){
-                for (int j = 0; j < pos_w; j++){
-                    data[id++] = bucket_coords_h[i]*70 + bucket_coords_w[j];
+    // set input per projector
+    switch (ctx->proj_type) {
+        case PROJECTOR_TYPE_MINICPMV:
+            {
+                // inspired from siglip:
+                //    -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit
+                //    -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit/blob/d66538faeba44480d0bfaa42145eef26f9423199/modeling_siglip.py#L316
+                std::vector<int32_t> positions(pos_h * pos_w);
+                int bucket_coords_h[1024];
+                int bucket_coords_w[1024];
+                for (int i = 0; i < pos_h; i++){
+                    bucket_coords_h[i] = std::floor(70.0*i/pos_h);
                 }
-            }
-            ggml_backend_tensor_set(positions, data, 0, ggml_nbytes(positions));
-        }
+                for (int i = 0; i < pos_w; i++){
+                    bucket_coords_w[i] = std::floor(70.0*i/pos_w);
+                }
+                for (int i = 0, id = 0; i < pos_h; i++){
+                    for (int j = 0; j < pos_w; j++){
+                        positions[id++] = bucket_coords_h[i]*70 + bucket_coords_w[j];
+                    }
+                }
+                set_input_i32("positions", positions);
 
-        {
-            // inspired from resampler of Qwen-VL:
-            //    -> https://huggingface.co/Qwen/Qwen-VL/tree/main
-            //    -> https://huggingface.co/Qwen/Qwen-VL/blob/0547ed36a86561e2e42fecec8fd0c4f6953e33c4/visual.py#L23
-            struct ggml_tensor * pos_embed = ggml_graph_get_tensor(gf, "pos_embed");
-            int embed_dim = 4096;
-            if (ctx->minicpmv_version == 2) {
-                embed_dim = 4096;
-            }
-            else if (ctx->minicpmv_version == 3) {
-                embed_dim = 3584;
-            }
-            else if (ctx->minicpmv_version == 4) {
-                embed_dim = 3584;
-            }
-            else {
-                GGML_ABORT("Unknown minicpmv version");
-            }
+                // inspired from resampler of Qwen-VL:
+                //    -> https://huggingface.co/Qwen/Qwen-VL/tree/main
+                //    -> https://huggingface.co/Qwen/Qwen-VL/blob/0547ed36a86561e2e42fecec8fd0c4f6953e33c4/visual.py#L23
+                int embed_dim = clip_n_mmproj_embd(ctx);
 
-            // TODO @ngxson : this is very inefficient, can we do this using ggml_sin and ggml_cos?
-            auto pos_embed_t = get_2d_sincos_pos_embed(embed_dim, std::make_pair(pos_w, pos_h));
+                // TODO @ngxson : this is very inefficient, can we do this using ggml_sin and ggml_cos?
+                auto pos_embed_t = get_2d_sincos_pos_embed(embed_dim, std::make_pair(pos_w, pos_h));
 
-            std::vector<float> pos_data(ggml_nelements(pos_embed));
-            float * data = pos_data.data();
-            for(int i = 0; i < pos_w * pos_h; ++i){
-                for(int j = 0; j < embed_dim; ++j){
-                    data[i * embed_dim + j] = pos_embed_t[i][j];
+                std::vector<float> pos_embed(embed_dim * pos_w * pos_h);
+                for(int i = 0; i < pos_w * pos_h; ++i){
+                    for(int j = 0; j < embed_dim; ++j){
+                        pos_embed[i * embed_dim + j] = pos_embed_t[i][j];
+                    }
                 }
-            }
 
-            ggml_backend_tensor_set(pos_embed, data, 0, ggml_nbytes(pos_embed));
-        }
-    }
-    else {
-        // non-minicpmv models
-
-        if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL) {
-            // pw * ph = number of tokens output by ViT after apply patch merger
-            // ipw * ipw = number of vision token been processed inside ViT
-            const int merge_ratio = 2;
-            const int pw  = image_size_width  / patch_size / merge_ratio;
-            const int ph  = image_size_height / patch_size / merge_ratio;
-            const int ipw = image_size_width  / patch_size;
-            const int iph = image_size_height / patch_size;
-
-            std::vector<int> idx    (ph * pw);
-            std::vector<int> inv_idx(ph * pw);
-
-            if (use_window_attn) {
-                const int attn_window_size = 112;
-                struct ggml_tensor * window_idx     = ggml_graph_get_tensor(gf, "window_idx");
-                struct ggml_tensor * inv_window_idx = ggml_graph_get_tensor(gf, "inv_window_idx");
-                struct ggml_tensor * window_mask    = ggml_graph_get_tensor(gf, "window_mask");
-
-                const int grid_window = attn_window_size / patch_size / merge_ratio;
-                int dst = 0;
-                // [num_vision_tokens, num_vision_tokens] attention mask tensor
-                std::vector<float> mask(pow(ipw * iph, 2), std::numeric_limits<float>::lowest());
-                int mask_row = 0;
-
-                for (int y = 0; y < ph; y += grid_window)
-                {
-                    for (int x = 0; x < pw; x += grid_window)
-                    {
-                        const int win_h = std::min(grid_window, ph - y);
-                        const int win_w = std::min(grid_window, pw - x);
-                        const int dst_0 = dst;
-                        // group all tokens belong to the same window togather (to a continue range)
-                        for (int dy = 0; dy < win_h; dy++) {
-                            for (int dx = 0; dx < win_w; dx++) {
-                                const int src = (y + dy) * pw + (x + dx);
-                                assert(src < (int)idx.size());
-                                assert(dst < (int)inv_idx.size());
-                                idx    [src] = dst;
-                                inv_idx[dst] = src;
-                                dst++;
+                set_input_f32("pos_embed", pos_embed);
+            } break;
+        case PROJECTOR_TYPE_QWEN2VL:
+            {
+                const int merge_ratio = 2;
+                const int pw = image_size_width  / patch_size;
+                const int ph = image_size_height / patch_size;
+                std::vector<int> positions(num_positions * 4);
+                int ptr = 0;
+                for (int y = 0; y < ph; y += merge_ratio) {
+                    for (int x = 0; x < pw; x += merge_ratio) {
+                        for (int dy = 0; dy < 2; dy++) {
+                            for (int dx = 0; dx < 2; dx++) {
+                                positions[                  ptr] = y + dy;
+                                positions[    num_patches + ptr] = x + dx;
+                                positions[2 * num_patches + ptr] = y + dy;
+                                positions[3 * num_patches + ptr] = x + dx;
+                                ptr++;
                             }
                         }
-
-                        for (int r=0; r < win_h * win_w * merge_ratio * merge_ratio; r++) {
-                            int row_offset = mask_row * (ipw * iph);
-                            std::fill(
-                                mask.begin() + row_offset + (dst_0 * merge_ratio * merge_ratio),
-                                mask.begin() + row_offset + (dst   * merge_ratio * merge_ratio),
-                                0.0);
-                            mask_row++;
-                        }
                     }
                 }
 
-                ggml_backend_tensor_set(window_idx,     idx.data(),     0, ggml_nbytes(window_idx));
-                ggml_backend_tensor_set(inv_window_idx, inv_idx.data(), 0, ggml_nbytes(inv_window_idx));
-                ggml_backend_tensor_set(window_mask,    mask.data(),    0, ggml_nbytes(window_mask));
-            } else {
-                std::iota(idx.begin(), idx.end(), 0);
-                std::iota(inv_idx.begin(), inv_idx.end(), 0);
-            }
+                set_input_i32("positions", positions);
+            } break;
+        case PROJECTOR_TYPE_QWEN25VL:
+            {
+                // pw * ph = number of tokens output by ViT after apply patch merger
+                // ipw * ipw = number of vision token been processed inside ViT
+                const int merge_ratio = 2;
+                const int pw  = image_size_width  / patch_size / merge_ratio;
+                const int ph  = image_size_height / patch_size / merge_ratio;
+                const int ipw = image_size_width  / patch_size;
+                const int iph = image_size_height / patch_size;
+
+                std::vector<int> idx    (ph * pw);
+                std::vector<int> inv_idx(ph * pw);
+
+                if (use_window_attn) {
+                    const int attn_window_size = 112;
+                    const int grid_window = attn_window_size / patch_size / merge_ratio;
+                    int dst = 0;
+                    // [num_vision_tokens, num_vision_tokens] attention mask tensor
+                    std::vector<float> mask(pow(ipw * iph, 2), std::numeric_limits<float>::lowest());
+                    int mask_row = 0;
+
+                    for (int y = 0; y < ph; y += grid_window) {
+                        for (int x = 0; x < pw; x += grid_window) {
+                            const int win_h = std::min(grid_window, ph - y);
+                            const int win_w = std::min(grid_window, pw - x);
+                            const int dst_0 = dst;
+                            // group all tokens belong to the same window togather (to a continue range)
+                            for (int dy = 0; dy < win_h; dy++) {
+                                for (int dx = 0; dx < win_w; dx++) {
+                                    const int src = (y + dy) * pw + (x + dx);
+                                    GGML_ASSERT(src < (int)idx.size());
+                                    GGML_ASSERT(dst < (int)inv_idx.size());
+                                    idx    [src] = dst;
+                                    inv_idx[dst] = src;
+                                    dst++;
+                                }
+                            }
 
-            struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
-            const int mpow = merge_ratio * merge_ratio;
-            std::vector<int> positions_data(ggml_nelements(positions));
-            int * data = positions_data.data();
+                            for (int r=0; r < win_h * win_w * merge_ratio * merge_ratio; r++) {
+                                int row_offset = mask_row * (ipw * iph);
+                                std::fill(
+                                    mask.begin() + row_offset + (dst_0 * merge_ratio * merge_ratio),
+                                    mask.begin() + row_offset + (dst   * merge_ratio * merge_ratio),
+                                    0.0);
+                                mask_row++;
+                            }
+                        }
+                    }
 
-            int ptr = 0;
-            for (int y = 0; y < iph; y += merge_ratio)
-            {
-                for (int x = 0; x < ipw; x += merge_ratio)
-                {
-                    for (int dy = 0; dy < 2; dy++) {
-                        for (int dx = 0; dx < 2; dx++) {
-                            auto remap = idx[ptr / mpow];
-                            remap = remap * mpow + (ptr % mpow);
-
-                            data[                  remap] = y + dy;
-                            data[    num_patches + remap] = x + dx;
-                            data[2 * num_patches + remap] = y + dy;
-                            data[3 * num_patches + remap] = x + dx;
-                            ptr++;
+                    set_input_i32("window_idx",     idx);
+                    set_input_i32("inv_window_idx", inv_idx);
+                    set_input_f32("window_mask",    mask);
+                } else {
+                    for (int i = 0; i < ph * pw; i++) {
+                        idx[i] = i;
+                    }
+                }
+
+                const int mpow = merge_ratio * merge_ratio;
+                std::vector<int> positions(num_positions * 4);
+
+                int ptr = 0;
+                for (int y = 0; y < iph; y += merge_ratio) {
+                    for (int x = 0; x < ipw; x += merge_ratio) {
+                        for (int dy = 0; dy < 2; dy++) {
+                            for (int dx = 0; dx < 2; dx++) {
+                                auto remap = idx[ptr / mpow];
+                                remap = (remap * mpow) + (ptr % mpow);
+
+                                positions[                  remap] = y + dy;
+                                positions[    num_patches + remap] = x + dx;
+                                positions[2 * num_patches + remap] = y + dy;
+                                positions[3 * num_patches + remap] = x + dx;
+                                ptr++;
+                            }
                         }
                     }
                 }
-            }
 
-            ggml_backend_tensor_set(positions, data, 0, ggml_nbytes(positions));
-        }
-        else if (ctx->proj_type == PROJECTOR_TYPE_GEMMA3) {
-            // do nothing
-        }
-        else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
-            // do nothing
-        }
-        else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) {
-            // set the 2D positions
-            int n_patches_per_col = image_size_width / patch_size;
-            std::vector<int> pos_data(num_positions);
-            struct ggml_tensor * pos;
-            // dimension H
-            pos = ggml_graph_get_tensor(gf, "pos_h");
-            for (int i = 0; i < num_positions; i++) {
-                pos_data[i] = i / n_patches_per_col;
-            }
-            ggml_backend_tensor_set(pos, pos_data.data(), 0, ggml_nbytes(pos));
-            // dimension W
-            pos = ggml_graph_get_tensor(gf, "pos_w");
-            for (int i = 0; i < num_positions; i++) {
-                pos_data[i] = i % n_patches_per_col;
-            }
-            ggml_backend_tensor_set(pos, pos_data.data(), 0, ggml_nbytes(pos));
-        }
-        else {
+                set_input_i32("positions", positions);
+            } break;
+        case PROJECTOR_TYPE_PIXTRAL:
+            {
+                // set the 2D positions
+                int n_patches_per_col = image_size_width / patch_size;
+                std::vector<int> pos_data(num_positions);
+                // dimension H
+                for (int i = 0; i < num_positions; i++) {
+                    pos_data[i] = i / n_patches_per_col;
+                }
+                set_input_i32("pos_h", pos_data);
+                // dimension W
+                for (int i = 0; i < num_positions; i++) {
+                    pos_data[i] = i % n_patches_per_col;
+                }
+                set_input_i32("pos_w", pos_data);
+            } break;
+        case PROJECTOR_TYPE_GLM_EDGE:
+        {
             // llava and other models
-            struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
-
-            int* positions_data = (int*)malloc(ggml_nbytes(positions));
+            std::vector<int32_t> positions(num_positions);
             for (int i = 0; i < num_positions; i++) {
-                positions_data[i] = i;
+                positions[i] = i;
             }
-            ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
-            free(positions_data);
+            set_input_i32("positions", positions);
+        } break;
+        case PROJECTOR_TYPE_MLP:
+        case PROJECTOR_TYPE_MLP_NORM:
+        case PROJECTOR_TYPE_LDP:
+        case PROJECTOR_TYPE_LDPV2:
+            {
+                // llava and other models
+                std::vector<int32_t> positions(num_positions);
+                for (int i = 0; i < num_positions; i++) {
+                    positions[i] = i;
+                }
+                set_input_i32("positions", positions);
 
-            if (ctx->proj_type != PROJECTOR_TYPE_GLM_EDGE) {
-                struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "patches");
                 // The patches vector is used to get rows to index into the embeds with;
                 // we should skip dim 0 only if we have CLS to avoid going out of bounds
                 // when retrieving the rows.
                 int patch_offset = model.class_embedding ? 1 : 0;
-                int* patches_data = (int*)malloc(ggml_nbytes(patches));
+                std::vector<int32_t> patches(num_patches);
                 for (int i = 0; i < num_patches; i++) {
-                    patches_data[i] = i + patch_offset;
+                    patches[i] = i + patch_offset;
                 }
-                ggml_backend_tensor_set(patches, patches_data, 0, ggml_nbytes(patches));
-                free(patches_data);
-            }
-        }
-    }
-
-    if (use_window_attn && (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL)) {
-        struct ggml_tensor * window_idx = ggml_graph_get_tensor(gf, "window_idx");
-        struct ggml_tensor * inv_window_idx = ggml_graph_get_tensor(gf, "inv_window_idx");
-        struct ggml_tensor * window_mask = ggml_graph_get_tensor(gf, "window_mask");
-
-        const int merge_ratio = 2;
-        const int attn_window_size = 112;
-        const int pw = image_size_width / patch_size / merge_ratio;
-        const int ph = image_size_height / patch_size / merge_ratio;
-        const int grid_window = attn_window_size / patch_size / merge_ratio;
-        const int ipw = image_size_width / patch_size;
-        const int iph = image_size_height / patch_size;
-        /*
-        pw * ph = number of tokens output by ViT after apply patch merger
-        ipw * ipw = number of vision token been processed inside ViT
-        */
-
-        std::vector<int> idx(ph * pw);
-        std::vector<int> inv_idx(ph * pw);
-        int dst = 0;
-        // [num_vision_tokens, num_vision_tokens] attention mask tensor
-        std::vector<float> mask(pow(ipw * iph, 2), std::numeric_limits<float>::lowest());
-        int mask_row = 0;
-
-        for (int y = 0; y < ph; y+=grid_window)
-        {
-            for (int x = 0; x < pw; x+=grid_window)
+                set_input_i32("patches", patches);
+            } break;
+        case PROJECTOR_TYPE_GEMMA3:
+        case PROJECTOR_TYPE_IDEFICS3:
             {
-                const int win_h = std::min(grid_window, ph - y);
-                const int win_w = std::min(grid_window, pw - x);
-                const int dst_0 = dst;
-                // group all tokens belong to the same window togather (to a continue range)
-                for (int dy = 0; dy < win_h; dy++) {
-                    for (int dx = 0; dx < win_w; dx++) {
-                        const int src = (y + dy) * pw + (x + dx);
-                        assert(src < (int)idx.size());
-                        assert(dst < (int)inv_idx.size());
-                        idx[src] = dst;
-                        inv_idx[dst] = src;
-                        dst++;
-                    }
-                }
-
-                for (int r=0; r < win_h * win_w * merge_ratio * merge_ratio; r++) {
-                    int row_offset = mask_row * (ipw * iph);
-                    std::fill(
-                        mask.begin() + row_offset + (dst_0 * merge_ratio * merge_ratio),
-                        mask.begin() + row_offset + (dst   * merge_ratio * merge_ratio),
-                        0.0);
-                    mask_row++;
-                }
-            }
-        }
-
-        ggml_backend_tensor_set(window_idx, idx.data(), 0, ggml_nbytes(window_idx));
-        ggml_backend_tensor_set(inv_window_idx, inv_idx.data(), 0, ggml_nbytes(inv_window_idx));
-        ggml_backend_tensor_set(window_mask, mask.data(), 0, ggml_nbytes(window_mask));
+                // do nothing
+            } break;
+        default:
+            GGML_ABORT("Unknown projector type");
     }
 
     ggml_backend_cpu_set_n_threads(ctx->backend_cpu, n_threads);
@@ -3537,7 +3494,7 @@ bool clip_is_glm(const struct clip_ctx * ctx) {
 }
 
 bool clip_is_qwen2vl(const struct clip_ctx * ctx) {
-    return ctx->proj_type == PROJECTOR_TYPE_QWEN2VL;
+    return ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL;
 }
 
 bool clip_is_llava(const struct clip_ctx * ctx) {