]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
mtmd: refactor preprocessing + support max/min pixels (#16878)
authorXuan-Son Nguyen <redacted>
Sat, 1 Nov 2025 14:51:36 +0000 (15:51 +0100)
committerGitHub <redacted>
Sat, 1 Nov 2025 14:51:36 +0000 (15:51 +0100)
* mtmd: refactor preprocessing + support max/min pixels

* fix mlp type

* implement mix/max pixels

* improve hparams

* better image preproc for qwen

* fix

* fix out of bound composite

* fix (2)

* fix token calculation

* get_merge_kernel_size()

* fix llama4 and lfm2

* gonna fix them all

* use simple resize for qwen

* qwen: increase min tokens

* no resize if dst size == src size

* restore to initial min/max tokens value for qwen

tools/mtmd/clip-impl.h
tools/mtmd/clip.cpp

index 311a4c9086a537e00ea2c9b29e47340a4e663a32..c7e9498349c1b5476793dbec4628e43883564e37 100644 (file)
@@ -154,8 +154,8 @@ enum projector_type {
     PROJECTOR_TYPE_LFM2,
     PROJECTOR_TYPE_KIMIVL,
     PROJECTOR_TYPE_LIGHTONOCR,
-    PROJECTOR_TYPE_UNKNOWN,
     PROJECTOR_TYPE_COGVLM,
+    PROJECTOR_TYPE_UNKNOWN,
 };
 
 static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
index b312fda637f3bff6e85046f371b892a2c736f9ed..dcfdb49600b6c3399fd796347fd0477424211234 100644 (file)
@@ -171,8 +171,10 @@ struct clip_hparams {
     int32_t n_head;
     int32_t n_layer;
     // idefics3
-    int32_t preproc_image_size = 0; // aka max_dimension
-    int32_t proj_scale_factor = 0;
+    int32_t image_longest_edge = 0;
+    int32_t image_min_pixels = 0;
+    int32_t image_max_pixels = 0;
+    int32_t n_merge = 0; // number of patch merges **per-side**
 
     float image_mean[3];
     float image_std[3];
@@ -194,7 +196,6 @@ struct clip_hparams {
     std::unordered_set<int32_t> vision_feature_layer;
     int32_t attn_window_size = 0;
     int32_t n_wa_pattern = 0;
-    int32_t spatial_merge_size = 0;
 
     // audio
     int32_t n_mel_bins = 0; // whisper preprocessor
@@ -204,6 +205,21 @@ struct clip_hparams {
     bool has_llava_projector = false;
     int minicpmv_version = 0;
     int32_t minicpmv_query_num = 0;         // MiniCPM-V query number
+
+    void set_limit_image_tokens(int n_tokens_min, int n_tokens_max) {
+        const int cur_merge = n_merge == 0 ? 1 : n_merge;
+        const int patch_area = patch_size * patch_size * cur_merge * cur_merge;
+        image_min_pixels = n_tokens_min * patch_area;
+        image_max_pixels = n_tokens_max * patch_area;
+        warmup_image_size = static_cast<int>(std::sqrt(image_max_pixels));
+    }
+
+    void set_warmup_n_tokens(int n_tokens) {
+        int n_tok_per_side = static_cast<int>(std::sqrt(n_tokens));
+        GGML_ASSERT(n_tok_per_side * n_tok_per_side == n_tokens && "n_tokens must be n*n");
+        const int cur_merge = n_merge == 0 ? 1 : n_merge;
+        warmup_image_size = n_tok_per_side * patch_size * cur_merge;
+    }
 };
 
 struct clip_layer {
@@ -532,7 +548,7 @@ struct clip_graph {
             const int batch_size = 1;
             GGML_ASSERT(n_patches_x == n_patches_y);
             const int patches_per_image = n_patches_x;
-            const int kernel_size = hparams.proj_scale_factor;
+            const int kernel_size = hparams.n_merge;
 
             cur = ggml_transpose(ctx0, cur);
             cur = ggml_cont_4d(ctx0, cur, patches_per_image, patches_per_image, n_embd, batch_size);
@@ -554,13 +570,13 @@ struct clip_graph {
         } else if (ctx->proj_type() == PROJECTOR_TYPE_IDEFICS3) {
             // pixel_shuffle
             // https://github.com/huggingface/transformers/blob/0a950e0bbe1ed58d5401a6b547af19f15f0c195e/src/transformers/models/idefics3/modeling_idefics3.py#L578
-            const int scale_factor = model.hparams.proj_scale_factor;
+            const int scale_factor = model.hparams.n_merge;
             cur = build_patch_merge_permute(cur, scale_factor);
             cur = ggml_mul_mat(ctx0, model.projection, cur);
 
         } else if (ctx->proj_type() == PROJECTOR_TYPE_LFM2) {
             // pixel unshuffle block
-            const int scale_factor = model.hparams.proj_scale_factor;
+            const int scale_factor = model.hparams.n_merge;
             cur = build_patch_merge_permute(cur, scale_factor);
 
             // projection
@@ -584,7 +600,7 @@ struct clip_graph {
     }
 
     ggml_cgraph * build_pixtral() {
-        const int n_merge = hparams.spatial_merge_size;
+        const int n_merge = hparams.n_merge;
 
         // 2D input positions
         ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches);
@@ -610,7 +626,7 @@ struct clip_graph {
         // mistral small 3.1 patch merger
         // ref: https://github.com/huggingface/transformers/blob/7a3e208892c06a5e278144eaf38c8599a42f53e7/src/transformers/models/mistral3/modeling_mistral3.py#L67
         if (model.mm_patch_merger_w) {
-            GGML_ASSERT(hparams.spatial_merge_size > 0);
+            GGML_ASSERT(hparams.n_merge > 0);
 
             cur = ggml_mul(ctx0, ggml_rms_norm(ctx0, cur, eps), model.mm_input_norm_w);
 
@@ -926,7 +942,7 @@ struct clip_graph {
 
         // deepstack features (stack along the feature dimension), [n_embd * len(deepstack_layers), n_patches_x * n_patches_y, batch_size]
         ggml_tensor * deepstack_features = nullptr;
-        const int merge_factor = hparams.spatial_merge_size > 0 ? hparams.spatial_merge_size * hparams.spatial_merge_size : 4; // default 2x2=4 for qwen3vl
+        const int merge_factor = hparams.n_merge > 0 ? hparams.n_merge * hparams.n_merge : 4; // default 2x2=4 for qwen3vl
 
         // loop over layers
         for (int il = 0; il < n_layer; il++) {
@@ -1149,7 +1165,7 @@ struct clip_graph {
 
         // pixel shuffle
         {
-            const int scale_factor = model.hparams.proj_scale_factor;
+            const int scale_factor = model.hparams.n_merge;
             const int bsz    = 1; // batch size, always 1 for now since we don't support batching
             const int height = n_patches_y;
             const int width  = n_patches_x;
@@ -1239,7 +1255,7 @@ struct clip_graph {
         // based on Llama4VisionPixelShuffleMLP
         // https://github.com/huggingface/transformers/blob/2932f318a20d9e54cc7aea052e040164d85de7d6/src/transformers/models/llama4/modeling_llama4.py#L1151
         {
-            const int scale_factor = model.hparams.proj_scale_factor;
+            const int scale_factor = model.hparams.n_merge;
             const int bsz = 1; // batch size, always 1 for now since we don't support batching
             GGML_ASSERT(scale_factor > 0);
             GGML_ASSERT(n_patches_x == n_patches_y); // llama4 only supports square images
@@ -1311,7 +1327,7 @@ struct clip_graph {
 
         {
             // patch_merger
-            const int scale_factor = model.hparams.proj_scale_factor;
+            const int scale_factor = model.hparams.n_merge;
             cur = build_patch_merge_permute(cur, scale_factor);
 
             // projection norm
@@ -2577,7 +2593,6 @@ struct clip_model_loader {
 
             if (is_vision) {
                 get_u32(KEY_IMAGE_SIZE, hparams.image_size);
-                get_u32(KEY_PREPROC_IMAGE_SIZE, hparams.preproc_image_size, false);
                 get_u32(KEY_PATCH_SIZE, hparams.patch_size);
                 get_u32(KEY_IMAGE_CROP_RESOLUTION, hparams.image_crop_resolution, false);
                 get_i32(KEY_MINICPMV_VERSION, hparams.minicpmv_version, false); // legacy
@@ -2686,65 +2701,68 @@ struct clip_model_loader {
                             hparams.minicpmv_version = 2; // default to 2 if not set
                         }
                     } break;
+                case PROJECTOR_TYPE_INTERNVL:
+                    {
+                        get_u32(KEY_PROJ_SCALE_FACTOR, hparams.n_merge, false);
+                    } break;
                 case PROJECTOR_TYPE_IDEFICS3:
+                    {
+                        get_u32(KEY_PROJ_SCALE_FACTOR, hparams.n_merge, false);
+                        get_u32(KEY_PREPROC_IMAGE_SIZE, hparams.image_longest_edge, false);
+                    } break;
                 case PROJECTOR_TYPE_LFM2:
-                case PROJECTOR_TYPE_INTERNVL:
                     {
-                        get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor, false);
+                        get_u32(KEY_PROJ_SCALE_FACTOR, hparams.n_merge, false);
+                        // ref: https://huggingface.co/LiquidAI/LFM2-VL-3B/blob/main/preprocessor_config.json
+                        hparams.set_limit_image_tokens(64, 256);
                     } break;
                 case PROJECTOR_TYPE_PIXTRAL:
                 case PROJECTOR_TYPE_LIGHTONOCR:
                     {
+                        // ref: https://huggingface.co/mistral-community/pixtral-12b/blob/main/preprocessor_config.json
+                        // TODO: verify the image_min_tokens
                         hparams.rope_theta = 10000.0f;
-                        hparams.warmup_image_size = hparams.patch_size * 8;
-                        // Mistral Small 2506 needs 1024x1024 image size cap to prevent OOM
-                        // ref: https://github.com/ggml-org/llama.cpp/issues/14310
-                        hparams.image_size = 1024;
-                        get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.spatial_merge_size, false);
+                        get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.n_merge, false);
+                        hparams.set_limit_image_tokens(8, 1024);
+                        hparams.set_warmup_n_tokens(256); // avoid OOM on warmup
                     } break;
                 case PROJECTOR_TYPE_KIMIVL:
                     {
                         hparams.rope_theta = 10000.0f;
-                        hparams.warmup_image_size = hparams.patch_size * 8;
-                        get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor, false);
+                        get_u32(KEY_PROJ_SCALE_FACTOR, hparams.n_merge, false);
+                        // TODO: check kimivl preprocessor for exact values
+                        hparams.set_limit_image_tokens(8, 1024);
+                        hparams.set_warmup_n_tokens(256); // avoid OOM on warmup
                     } break;
                 case PROJECTOR_TYPE_GEMMA3:
                     {
                         // default value (used by all model sizes in gemma 3 family)
                         // number of patches for each **side** is reduced by a factor of 4
-                        hparams.proj_scale_factor = 4;
+                        hparams.n_merge = 4;
                         // test model (tinygemma3) has a different value, we optionally read it
-                        get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor, false);
+                        get_u32(KEY_PROJ_SCALE_FACTOR, hparams.n_merge, false);
                     } break;
                 case PROJECTOR_TYPE_QWEN2VL:
-                    {
-                        // max image size = sqrt(max_pixels) = 3584
-                        // ref: https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/blob/main/preprocessor_config.json
-                        // however, the model use unreasonable memory past 1024 size, we force it to 1024 otherwise it's unusable
-                        // ref: https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct/discussions/10
-                        hparams.image_size = 1024;
-                        hparams.warmup_image_size = hparams.patch_size * 8;
-                    } break;
                 case PROJECTOR_TYPE_QWEN25VL:
-                    {
-                        // max image size = sqrt(max_pixels)
-                        // https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct/blob/main/preprocessor_config.json
-                        // however, the model use unreasonable memory past 1024 size, we force it to 1024 otherwise it's unusable
-                        // ref: https://huggingface.co/Qwen/Qwen2-VL-2B-Instruct/discussions/10
-                        hparams.image_size = 1024;
-                        hparams.warmup_image_size = hparams.patch_size * 8;
-                        get_u32(KEY_WIN_ATTN_PATTERN, hparams.n_wa_pattern);
-                    } break;
                 case PROJECTOR_TYPE_QWEN3VL:
                     {
-                        hparams.image_size = 1024; // still need this?
-                        hparams.warmup_image_size = hparams.patch_size * 8;
-                        get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.spatial_merge_size, false);
+                        hparams.n_merge = 2; // default value for Qwen 2 and 2.5
+                        get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.n_merge, false);
+                        get_u32(KEY_WIN_ATTN_PATTERN, hparams.n_wa_pattern, model.proj_type == PROJECTOR_TYPE_QWEN25VL); // only 2.5 requires it
+                        // ref: https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct/blob/main/preprocessor_config.json
+                        // the actual max limit is 12845056/14/14/2/2/4 = 4096 tokens
+                        // but we set a lower value to avoid OOM
+                        // TODO: make it configurable by user
+                        // TODO (2): bbox coordinates become inaccurate with small number of tokens,
+                        //           therefore we need to increase the min_tokens
+                        //           see: https://github.com/ggml-org/llama.cpp/issues/16842#issuecomment-3475144858
+                        hparams.set_limit_image_tokens(8, 2048);
+                        hparams.set_warmup_n_tokens(256); // avoid OOM on warmup
                     } break;
                 case PROJECTOR_TYPE_LLAMA4:
                     {
                         hparams.rope_theta = 10000.0f;
-                        get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor);
+                        get_u32(KEY_PROJ_SCALE_FACTOR, hparams.n_merge, false);
                         set_llava_uhd_res_candidates(model, 3);
                     } break;
                 case PROJECTOR_TYPE_ULTRAVOX:
@@ -2777,10 +2795,13 @@ struct clip_model_loader {
                 LOG_INF("%s: patch_size:         %d\n", __func__, hparams.patch_size);
                 LOG_INF("%s: has_llava_proj:     %d\n", __func__, hparams.has_llava_projector);
                 LOG_INF("%s: minicpmv_version:   %d\n", __func__, hparams.minicpmv_version);
-                LOG_INF("%s: proj_scale_factor:  %d\n", __func__, hparams.proj_scale_factor);
+                LOG_INF("%s: n_merge:            %d\n", __func__, hparams.n_merge);
                 LOG_INF("%s: n_wa_pattern:       %d\n", __func__, hparams.n_wa_pattern);
-                if (hparams.spatial_merge_size > 0) {
-                    LOG_INF("%s: spatial_merge_size: %d\n", __func__, hparams.spatial_merge_size);
+                if (hparams.image_min_pixels > 0) {
+                    LOG_INF("%s: image_min_pixels:   %d\n", __func__, hparams.image_min_pixels);
+                }
+                if (hparams.image_max_pixels > 0) {
+                    LOG_INF("%s: image_max_pixels:   %d\n", __func__, hparams.image_max_pixels);
                 }
             } else if (is_audio) {
                 LOG_INF("\n--- audio hparams ---\n");
@@ -3181,9 +3202,11 @@ struct clip_model_loader {
         if (ctx_clip.model.modality == CLIP_MODALITY_VISION) {
             img->nx = hparams.warmup_image_size;
             img->ny = hparams.warmup_image_size;
+            LOG_INF("%s: warmup with image size = %d x %d\n", __func__, img->nx, img->ny);
         } else {
             img->nx = hparams.warmup_audio_size;
             img->ny = hparams.n_mel_bins;
+            LOG_INF("%s: warmup with audio size = %d\n", __func__, img->nx);
         }
         batch.entries.push_back(std::move(img));
 
@@ -3399,9 +3422,169 @@ static void normalize_image_u8_to_f32(const clip_image_u8 & src, clip_image_f32
 
 // set of tools to manupulate images
 // in the future, we can have HW acceleration by allowing this struct to access 3rd party lib like imagick or opencv
-struct image_manipulation {
+struct img_tool {
+    enum resize_algo {
+        RESIZE_ALGO_BILINEAR,
+        RESIZE_ALGO_BICUBIC,
+        // RESIZE_ALGO_LANCZOS, // TODO
+    };
+
+    static void resize(
+            const clip_image_u8 & src,
+            clip_image_u8 & dst,
+            const clip_image_size & target_resolution,
+            resize_algo algo,
+            bool add_padding = true, // TODO: define the behavior for add_padding = false
+            std::array<uint8_t, 3> pad_color = {0, 0, 0}) {
+        dst.nx = target_resolution.width;
+        dst.ny = target_resolution.height;
+        dst.buf.resize(3 * dst.nx * dst.ny);
+
+        if (dst.nx == src.nx && dst.ny == src.ny) {
+            // no resize needed, simple copy
+            dst.buf = src.buf;
+            return;
+        }
+
+        if (!add_padding) {
+            // direct resize
+            switch (algo) {
+                case RESIZE_ALGO_BILINEAR:
+                    resize_bilinear(src, dst, target_resolution.width, target_resolution.height);
+                    break;
+                case RESIZE_ALGO_BICUBIC:
+                    resize_bicubic(src, dst, target_resolution.width, target_resolution.height);
+                    break;
+                default:
+                    throw std::runtime_error("Unsupported resize algorithm");
+            }
+        } else {
+            // resize with padding
+            clip_image_u8 resized_image;
+            float scale_w = static_cast<float>(target_resolution.width) / src.nx;
+            float scale_h = static_cast<float>(target_resolution.height) / src.ny;
+            float scale = std::min(scale_w, scale_h);
+            int new_width  = std::min(static_cast<int>(std::ceil(src.nx * scale)), target_resolution.width);
+            int new_height = std::min(static_cast<int>(std::ceil(src.ny * scale)), target_resolution.height);
+
+            switch (algo) {
+                case RESIZE_ALGO_BILINEAR:
+                    resize_bilinear(src, resized_image, new_width, new_height);
+                    break;
+                case RESIZE_ALGO_BICUBIC:
+                    resize_bicubic(src, resized_image, new_width, new_height);
+                    break;
+                default:
+                    throw std::runtime_error("Unsupported resize algorithm");
+            }
+
+            // fill dst with pad_color
+            fill(dst, pad_color);
+
+            int offset_x = (target_resolution.width  - new_width)  / 2;
+            int offset_y = (target_resolution.height - new_height) / 2;
+
+            composite(dst, resized_image, offset_x, offset_y);
+        }
+    }
+
+    static void crop(const clip_image_u8 & image, clip_image_u8 & dst, int x, int y, int w, int h) {
+        dst.nx = w;
+        dst.ny = h;
+        dst.buf.resize(3 * w * h);
+
+        for (int i = 0; i < h; ++i) {
+            for (int j = 0; j < w; ++j) {
+                int src_idx = 3 * ((y + i)*image.nx + (x + j));
+                int dst_idx = 3 * (i*w + j);
+                dst.buf[dst_idx]     = image.buf[src_idx];
+                dst.buf[dst_idx + 1] = image.buf[src_idx + 1];
+                dst.buf[dst_idx + 2] = image.buf[src_idx + 2];
+            }
+        }
+    }
+
+    // calculate the size of the **resized** image, while preserving the aspect ratio
+    // the calculated size will be aligned to the nearest multiple of align_size
+    // if H or W size is larger than longest_edge, it will be resized to longest_edge
+    static clip_image_size calc_size_preserved_ratio(const clip_image_size & inp_size, const int align_size, const int longest_edge) {
+        GGML_ASSERT(align_size > 0);
+        if (inp_size.width <= 0 || inp_size.height <= 0 || longest_edge <= 0) {
+            return {0, 0};
+        }
+
+        float scale = std::min(static_cast<float>(longest_edge) / inp_size.width,
+                               static_cast<float>(longest_edge) / inp_size.height);
+
+        float target_width_f  = static_cast<float>(inp_size.width)  * scale;
+        float target_height_f = static_cast<float>(inp_size.height) * scale;
+
+        auto ceil_by_factor = [f = align_size](float x) { return static_cast<int>(std::ceil(x / static_cast<float>(f))) * f; };
+        int aligned_width  = ceil_by_factor(target_width_f);
+        int aligned_height = ceil_by_factor(target_height_f);
+
+        return {aligned_width, aligned_height};
+    }
+
+    // calculate the size of the **resized** image, while preserving the aspect ratio
+    // the calculated size will have min_pixels <= W*H <= max_pixels
+    // this is referred as "smart_resize" in transformers code
+    static clip_image_size calc_size_preserved_ratio(const clip_image_size & inp_size, const int align_size, const int min_pixels, const int max_pixels) {
+        GGML_ASSERT(align_size > 0);
+        const int width  = inp_size.width;
+        const int height = inp_size.height;
+
+        auto ceil_by_factor  = [f = align_size](float x) { return static_cast<int>(std::ceil(x / static_cast<float>(f))) * f; };
+        auto floor_by_factor = [f = align_size](float x) { return static_cast<int>(std::floor(x / static_cast<float>(f))) * f; };
+
+        // always align up first
+        int h_bar = std::max(align_size, ceil_by_factor(height));
+        int w_bar = std::max(align_size, ceil_by_factor(width));
+
+        if (h_bar * w_bar > max_pixels) {
+            const auto beta = std::sqrt(static_cast<float>(height * width) / max_pixels);
+            h_bar = std::max(align_size, floor_by_factor(height / beta));
+            w_bar = std::max(align_size, floor_by_factor(width  / beta));
+        } else if (h_bar * w_bar < min_pixels) {
+            const auto beta = std::sqrt(static_cast<float>(min_pixels) / (height * width));
+            h_bar = ceil_by_factor(height * beta);
+            w_bar = ceil_by_factor(width * beta);
+        }
+
+        return {w_bar, h_bar};
+    }
+
+    // draw src image into dst image at offset (offset_x, offset_y)
+    static void composite(clip_image_u8 & dst, const clip_image_u8 & src, int offset_x, int offset_y) {
+        for (int y = 0; y < src.ny; ++y) {
+            for (int x = 0; x < src.nx; ++x) {
+                int dx = x + offset_x;
+                int dy = y + offset_y;
+                // skip pixels that would be out of bounds in the destination
+                if (dx < 0 || dy < 0 || dx >= dst.nx || dy >= dst.ny) {
+                    continue;
+                }
+                size_t dst_idx = 3 * (static_cast<size_t>(dy) * dst.nx + static_cast<size_t>(dx));
+                size_t src_idx = 3 * (static_cast<size_t>(y) * src.nx + static_cast<size_t>(x));
+                dst.buf[dst_idx + 0] = src.buf[src_idx + 0];
+                dst.buf[dst_idx + 1] = src.buf[src_idx + 1];
+                dst.buf[dst_idx + 2] = src.buf[src_idx + 2];
+            }
+        }
+    }
+
+    // fill the image with a solid color
+    static void fill(clip_image_u8 & img, const std::array<uint8_t, 3> & color) {
+        for (size_t i = 0; i < img.buf.size(); i += 3) {
+            img.buf[i]     = color[0];
+            img.buf[i + 1] = color[1];
+            img.buf[i + 2] = color[2];
+        }
+    }
+
+private:
     // Bilinear resize function
-    static void bilinear_resize(const clip_image_u8& src, clip_image_u8& dst, int target_width, int target_height) {
+    static void resize_bilinear(const clip_image_u8 & src, clip_image_u8 & dst, int target_width, int target_height) {
         dst.nx = target_width;
         dst.ny = target_height;
         dst.buf.resize(3 * target_width * target_height);
@@ -3437,7 +3620,7 @@ struct image_manipulation {
 
     // Bicubic resize function
     // part of image will be cropped if the aspect ratio is different
-    static bool bicubic_resize(const clip_image_u8 & img, clip_image_u8 & dst, int target_width, int target_height) {
+    static bool resize_bicubic(const clip_image_u8 & img, clip_image_u8 & dst, int target_width, int target_height) {
         const int nx = img.nx;
         const int ny = img.ny;
 
@@ -3500,93 +3683,6 @@ struct image_manipulation {
         return true;
     }
 
-    // llava-1.6 type of resize_and_pad
-    // if the ratio is not 1:1, padding with pad_color will be applied
-    // pad_color is single channel, default is 0 (black)
-    static void resize_and_pad_image(const clip_image_u8 & image, clip_image_u8 & dst, const clip_image_size & target_resolution, std::array<uint8_t, 3> pad_color = {0, 0, 0}) {
-        int target_width  = target_resolution.width;
-        int target_height = target_resolution.height;
-
-        float scale_w = static_cast<float>(target_width) / image.nx;
-        float scale_h = static_cast<float>(target_height) / image.ny;
-
-        int new_width, new_height;
-
-        if (scale_w < scale_h) {
-            new_width  = target_width;
-            new_height = std::min(static_cast<int>(std::ceil(image.ny * scale_w)), target_height);
-        } else {
-            new_height = target_height;
-            new_width  = std::min(static_cast<int>(std::ceil(image.nx * scale_h)), target_width);
-        }
-
-        clip_image_u8 resized_image;
-        bicubic_resize(image, resized_image, new_width, new_height);
-
-        clip_image_u8 padded_image;
-        padded_image.nx = target_width;
-        padded_image.ny = target_height;
-        padded_image.buf.resize(3 * target_width * target_height);
-
-        // Fill the padded image with the fill color
-        for (size_t i = 0; i < padded_image.buf.size(); i += 3) {
-            padded_image.buf[i]     = pad_color[0];
-            padded_image.buf[i + 1] = pad_color[1];
-            padded_image.buf[i + 2] = pad_color[2];
-        }
-
-        // Calculate padding offsets
-        int pad_x = (target_width  - new_width)  / 2;
-        int pad_y = (target_height - new_height) / 2;
-
-        // Copy the resized image into the center of the padded buffer
-        for (int y = 0; y < new_height; ++y) {
-            for (int x = 0; x < new_width; ++x) {
-                for (int c = 0; c < 3; ++c) {
-                    padded_image.buf[3 * ((y + pad_y) * target_width + (x + pad_x)) + c] = resized_image.buf[3 * (y * new_width + x) + c];
-                }
-            }
-        }
-        dst = std::move(padded_image);
-    }
-
-    static void crop_image(const clip_image_u8 & image, clip_image_u8 & dst, int x, int y, int w, int h) {
-        dst.nx = w;
-        dst.ny = h;
-        dst.buf.resize(3 * w * h);
-
-        for (int i = 0; i < h; ++i) {
-            for (int j = 0; j < w; ++j) {
-                int src_idx = 3 * ((y + i)*image.nx + (x + j));
-                int dst_idx = 3 * (i*w + j);
-                dst.buf[dst_idx]     = image.buf[src_idx];
-                dst.buf[dst_idx + 1] = image.buf[src_idx + 1];
-                dst.buf[dst_idx + 2] = image.buf[src_idx + 2];
-            }
-        }
-    }
-
-    // calculate the size of the **resized** image, while preserving the aspect ratio
-    // the calculated size will be aligned to the nearest multiple of align_size
-    // if H or W size is larger than max_dimension, it will be resized to max_dimension
-    static clip_image_size calc_size_preserved_ratio(const clip_image_size & inp_size, const int align_size, const int max_dimension) {
-        if (inp_size.width <= 0 || inp_size.height <= 0 || align_size <= 0 || max_dimension <= 0) {
-            return {0, 0};
-        }
-
-        float scale = std::min(static_cast<float>(max_dimension) / inp_size.width,
-                               static_cast<float>(max_dimension) / inp_size.height);
-
-        float target_width_f  = static_cast<float>(inp_size.width)  * scale;
-        float target_height_f = static_cast<float>(inp_size.height) * scale;
-
-        int aligned_width  = CLIP_ALIGN((int)target_width_f,  align_size);
-        int aligned_height = CLIP_ALIGN((int)target_height_f, align_size);
-
-        return {aligned_width, aligned_height};
-    }
-
-private:
     static inline int clip(int x, int lower, int upper) {
         return std::max(lower, std::min(x, upper));
     }
@@ -3735,10 +3831,11 @@ struct llava_uhd {
 
     static std::vector<clip_image_u8_ptr> slice_image(const clip_image_u8 * img, const slice_instructions & inst) {
         std::vector<clip_image_u8_ptr> output;
+        img_tool::resize_algo interpolation = img_tool::RESIZE_ALGO_BILINEAR; // TODO: make it configurable
 
         // resize to overview size
         clip_image_u8_ptr resized_img(clip_image_u8_init());
-        image_manipulation::resize_and_pad_image(*img, *resized_img, inst.overview_size);
+        img_tool::resize(*img, *resized_img, inst.overview_size, interpolation);
         output.push_back(std::move(resized_img));
         if (inst.slices.empty()) {
             // no slices, just return the resized image
@@ -3748,9 +3845,11 @@ struct llava_uhd {
         // resize to refined size
         clip_image_u8_ptr refined_img(clip_image_u8_init());
         if (inst.padding_refined) {
-            image_manipulation::resize_and_pad_image(*img, *refined_img, inst.refined_size);
+            img_tool::resize(*img, *refined_img, inst.refined_size, interpolation);
         } else {
-            image_manipulation::bilinear_resize(*img, *refined_img, inst.refined_size.width, inst.refined_size.height);
+            // only algo bicubic preserves the ratio; old models rely on this behavior
+            // TODO: do we need to support other algos here?
+            img_tool::resize(*img, *refined_img, inst.refined_size, img_tool::RESIZE_ALGO_BICUBIC, false);
         }
 
         // create slices
@@ -3761,7 +3860,7 @@ struct llava_uhd {
             int h = slice.size.height;
 
             clip_image_u8_ptr img_slice(clip_image_u8_init());
-            image_manipulation::crop_image(*refined_img, *img_slice, x, y, w, h);
+            img_tool::crop(*refined_img, *img_slice, x, y, w, h);
             output.push_back(std::move(img_slice));
         }
 
@@ -3896,208 +3995,211 @@ private:
 // res_imgs memory is being allocated here, previous allocations will be freed if found
 bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, struct clip_image_f32_batch * res_imgs) {
     clip_image_size original_size{img->nx, img->ny};
-    bool pad_to_square = true;
     auto & params = ctx->model.hparams;
-    // The model config actually contains all we need to decide on how to preprocess, here we automatically switch to the new llava-1.6 preprocessing
-    if (params.mm_patch_merge_type == PATCH_MERGE_SPATIAL_UNPAD) {
-        pad_to_square = false;
-    }
 
-    if (clip_is_minicpmv(ctx)) {
-        auto const inst = llava_uhd::get_slice_instructions(ctx, original_size);
-        std::vector<clip_image_u8_ptr> imgs = llava_uhd::slice_image(img, inst);
+    switch (ctx->proj_type()) {
+        case PROJECTOR_TYPE_MINICPMV:
+            {
+                auto const inst = llava_uhd::get_slice_instructions(ctx, original_size);
+                std::vector<clip_image_u8_ptr> imgs = llava_uhd::slice_image(img, inst);
+
+                for (size_t i = 0; i < imgs.size(); ++i) {
+                    // clip_image_save_to_bmp(*imgs[i], "slice_" + std::to_string(i) + ".bmp");
+                    clip_image_f32_ptr res(clip_image_f32_init());
+                    normalize_image_u8_to_f32(*imgs[i], *res, params.image_mean, params.image_std);
+                    res_imgs->entries.push_back(std::move(res));
+                }
 
-        for (size_t i = 0; i < imgs.size(); ++i) {
-            // clip_image_save_to_bmp(*imgs[i], "slice_" + std::to_string(i) + ".bmp");
-            clip_image_f32_ptr res(clip_image_f32_init());
-            normalize_image_u8_to_f32(*imgs[i], *res, params.image_mean, params.image_std);
-            res_imgs->entries.push_back(std::move(res));
-        }
+                res_imgs->grid_x = inst.grid_size.width;
+                res_imgs->grid_y = inst.grid_size.height;
+            } break;
 
-        res_imgs->grid_x = inst.grid_size.width;
-        res_imgs->grid_y = inst.grid_size.height;
-        return true;
+        case PROJECTOR_TYPE_QWEN2VL:
+        case PROJECTOR_TYPE_QWEN25VL:
+        case PROJECTOR_TYPE_QWEN3VL:
+            {
+                // step 1: make a blank canvas which aligns to the grid
+                clip_image_u8 resized;
+                const clip_image_size new_size = img_tool::calc_size_preserved_ratio(
+                    original_size,
+                    params.patch_size * 2,
+                    params.image_min_pixels,
+                    params.image_max_pixels);
+                img_tool::resize(*img, resized, new_size, img_tool::RESIZE_ALGO_BILINEAR, false);
+                // clip_image_save_to_bmp(resized, "preproc.bmp");
+                clip_image_f32_ptr img_f32(clip_image_f32_init());
+                // clip_image_f32_ptr res(clip_image_f32_init());
+                normalize_image_u8_to_f32(resized, *img_f32, params.image_mean, params.image_std);
+                // res_imgs->data[0] = *res;
+                res_imgs->entries.push_back(std::move(img_f32));
+            } break;
 
-    } else if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN3VL) {
-        clip_image_u8 resized;
-        auto patch_size = params.patch_size * 2;
-        auto new_size = image_manipulation::calc_size_preserved_ratio(original_size, patch_size, params.image_size);
-        image_manipulation::bicubic_resize(*img, resized, new_size.width, new_size.height);
-
-        clip_image_f32_ptr img_f32(clip_image_f32_init());
-        // clip_image_f32_ptr res(clip_image_f32_init());
-        normalize_image_u8_to_f32(resized, *img_f32, params.image_mean, params.image_std);
-        // res_imgs->data[0] = *res;
-        res_imgs->entries.push_back(std::move(img_f32));
-        return true;
-    } else if (ctx->proj_type() == PROJECTOR_TYPE_IDEFICS3) {
-        // The refined size has two steps:
-        // 1. Resize w/ aspect-ratio preserving such that the longer side is
-        //      the preprocessor longest size
-        // 2. Resize w/out preserving aspect ratio such that both sides are
-        //      multiples of image_size (always rounding up)
-        //
-        // CITE: https://github.com/huggingface/transformers/blob/main/src/transformers/models/idefics3/image_processing_idefics3.py#L737
-        const clip_image_size refined_size = image_manipulation::calc_size_preserved_ratio(
-            original_size, params.image_size, params.preproc_image_size);
-        // LOG_INF("%s: original size: %d x %d, refined size: %d x %d\n",
-        //         __func__, original_size.width, original_size.height,
-        //         refined_size.width, refined_size.height);
-
-        llava_uhd::slice_instructions instructions;
-        instructions.overview_size = clip_image_size{params.image_size, params.image_size};
-        instructions.refined_size = refined_size;
-        instructions.grid_size = clip_image_size{
-            static_cast<int>(std::ceil(static_cast<float>(refined_size.width) / params.image_size)),
-            static_cast<int>(std::ceil(static_cast<float>(refined_size.height) / params.image_size)),
-        };
-        for (int y = 0; y < refined_size.height; y += params.image_size) {
-            for (int x = 0; x < refined_size.width; x += params.image_size) {
-                // LOG_INF("%s: adding slice at x=%d, y=%d\n", __func__, x, y);
-                instructions.slices.push_back(llava_uhd::slice_coordinates{
-                    /* x    */x,
-                    /* y    */y,
-                    /* size */clip_image_size{
-                        std::min(params.image_size, refined_size.width - x),
-                        std::min(params.image_size, refined_size.height - y)
+        case PROJECTOR_TYPE_IDEFICS3:
+            {
+                // The refined size has two steps:
+                // 1. Resize w/ aspect-ratio preserving such that the longer side is
+                //      the preprocessor longest size
+                // 2. Resize w/out preserving aspect ratio such that both sides are
+                //      multiples of image_size (always rounding up)
+                //
+                // CITE: https://github.com/huggingface/transformers/blob/main/src/transformers/models/idefics3/image_processing_idefics3.py#L737
+                const clip_image_size refined_size = img_tool::calc_size_preserved_ratio(
+                    original_size, params.image_size, params.image_longest_edge);
+                // LOG_INF("%s: original size: %d x %d, refined size: %d x %d\n",
+                //         __func__, original_size.width, original_size.height,
+                //         refined_size.width, refined_size.height);
+
+                llava_uhd::slice_instructions instructions;
+                instructions.overview_size = clip_image_size{params.image_size, params.image_size};
+                instructions.refined_size = refined_size;
+                instructions.grid_size = clip_image_size{
+                    static_cast<int>(std::ceil(static_cast<float>(refined_size.width) / params.image_size)),
+                    static_cast<int>(std::ceil(static_cast<float>(refined_size.height) / params.image_size)),
+                };
+                for (int y = 0; y < refined_size.height; y += params.image_size) {
+                    for (int x = 0; x < refined_size.width; x += params.image_size) {
+                        // LOG_INF("%s: adding slice at x=%d, y=%d\n", __func__, x, y);
+                        instructions.slices.push_back(llava_uhd::slice_coordinates{
+                            /* x    */x,
+                            /* y    */y,
+                            /* size */clip_image_size{
+                                std::min(params.image_size, refined_size.width - x),
+                                std::min(params.image_size, refined_size.height - y)
+                            }
+                        });
                     }
-                });
-            }
-        }
-        auto imgs = llava_uhd::slice_image(img, instructions);
-
-        // cast and normalize to f32
-        for (size_t i = 0; i < imgs.size(); ++i) {
-            // clip_image_save_to_bmp(*imgs[i], "slice_" + std::to_string(i) + ".bmp");
-            clip_image_f32_ptr res(clip_image_f32_init());
-            normalize_image_u8_to_f32(*imgs[i], *res, params.image_mean, params.image_std);
-            res_imgs->entries.push_back(std::move(res));
-        }
-
-        res_imgs->grid_x = instructions.grid_size.width;
-        res_imgs->grid_y = instructions.grid_size.height;
-        return true;
-    } else if (ctx->proj_type() == PROJECTOR_TYPE_GLM_EDGE
-            || ctx->proj_type() == PROJECTOR_TYPE_GEMMA3
-            || ctx->proj_type() == PROJECTOR_TYPE_INTERNVL // TODO @ngxson : support dynamic resolution
-    ) {
-        clip_image_u8 resized_image;
-        int sz = params.image_size;
-        image_manipulation::resize_and_pad_image(*img, resized_image, {sz, sz});
-        clip_image_f32_ptr img_f32(clip_image_f32_init());
-        //clip_image_save_to_bmp(resized_image, "resized.bmp");
-        normalize_image_u8_to_f32(resized_image, *img_f32, params.image_mean, params.image_std);
-        res_imgs->entries.push_back(std::move(img_f32));
-        return true;
-
-    } else if (ctx->proj_type() == PROJECTOR_TYPE_PIXTRAL
-            || ctx->proj_type() == PROJECTOR_TYPE_LIGHTONOCR
-    ) {
-        clip_image_u8 resized_image;
-        auto new_size = image_manipulation::calc_size_preserved_ratio(original_size, params.patch_size, params.image_size);
-        image_manipulation::bilinear_resize(*img, resized_image, new_size.width, new_size.height);
-        clip_image_f32_ptr img_f32(clip_image_f32_init());
-        normalize_image_u8_to_f32(resized_image, *img_f32, params.image_mean, params.image_std);
-        res_imgs->entries.push_back(std::move(img_f32));
-        return true;
-
-    } else if (ctx->proj_type() == PROJECTOR_TYPE_LLAMA4) {
-        GGML_ASSERT(!params.image_res_candidates.empty());
-        auto const inst = llava_uhd::get_slice_instructions(ctx, original_size);
-        std::vector<clip_image_u8_ptr> imgs = llava_uhd::slice_image(img, inst);
-
-        for (size_t i = 0; i < imgs.size(); ++i) {
-            clip_image_f32_ptr res(clip_image_f32_init());
-            normalize_image_u8_to_f32(*imgs[i], *res, params.image_mean, params.image_std);
-            res_imgs->entries.push_back(std::move(res));
-        }
-
-        res_imgs->grid_x = inst.grid_size.width;
-        res_imgs->grid_y = inst.grid_size.height;
-        return true;
-
-    } else if ( ctx->proj_type() == PROJECTOR_TYPE_LFM2
-             || ctx->proj_type() == PROJECTOR_TYPE_KIMIVL
-    ) {
-        GGML_ASSERT(params.proj_scale_factor);
+                }
+                auto imgs = llava_uhd::slice_image(img, instructions);
+
+                // cast and normalize to f32
+                for (size_t i = 0; i < imgs.size(); ++i) {
+                    // clip_image_save_to_bmp(*imgs[i], "slice_" + std::to_string(i) + ".bmp");
+                    clip_image_f32_ptr res(clip_image_f32_init());
+                    normalize_image_u8_to_f32(*imgs[i], *res, params.image_mean, params.image_std);
+                    res_imgs->entries.push_back(std::move(res));
+                }
 
-        // smart resize
-        const int width = img->nx;
-        const int height = img->ny;
-        const int total_factor = params.patch_size * params.proj_scale_factor;
-        constexpr int min_image_tokens = 64;
-        constexpr int max_image_tokens = 1024;
-        const float min_pixels = min_image_tokens * total_factor * total_factor;
-        const float max_pixels = max_image_tokens * total_factor * total_factor;
+                res_imgs->grid_x = instructions.grid_size.width;
+                res_imgs->grid_y = instructions.grid_size.height;
+            } break;
 
-        auto round_by_factor = [f = total_factor](float x) { return static_cast<int>(std::nearbyintf(x / static_cast<float>(f))) * f; };
-        auto ceil_by_factor  = [f = total_factor](float x) { return static_cast<int>(std::ceil(x / static_cast<float>(f))) * f; };
-        auto floor_by_factor = [f = total_factor](float x) { return static_cast<int>(std::floor(x / static_cast<float>(f))) * f; };
+        case PROJECTOR_TYPE_GLM_EDGE:
+        case PROJECTOR_TYPE_GEMMA3:
+        case PROJECTOR_TYPE_INTERNVL: // TODO @ngxson : support dynamic resolution
+            {
+                clip_image_u8 resized_image;
+                int sz = params.image_size;
+                img_tool::resize(*img, resized_image, {sz, sz}, img_tool::RESIZE_ALGO_BILINEAR);
+                clip_image_f32_ptr img_f32(clip_image_f32_init());
+                //clip_image_save_to_bmp(resized_image, "resized.bmp");
+                normalize_image_u8_to_f32(resized_image, *img_f32, params.image_mean, params.image_std);
+                res_imgs->entries.push_back(std::move(img_f32));
+            } break;
 
-        int h_bar = std::max(total_factor, round_by_factor(height));
-        int w_bar = std::max(total_factor, round_by_factor(width));
+        case PROJECTOR_TYPE_PIXTRAL:
+        case PROJECTOR_TYPE_LIGHTONOCR:
+            {
+                GGML_ASSERT(params.image_min_pixels && params.image_max_pixels);
+                clip_image_u8 resized_image;
+                // the original pixtral model doesn't have n_merge
+                const int cur_merge = params.n_merge == 0 ? 1 : params.n_merge;
+                const clip_image_size target_size = img_tool::calc_size_preserved_ratio(
+                    original_size,
+                    params.patch_size * cur_merge,
+                    params.image_min_pixels,
+                    params.image_max_pixels);
+                img_tool::resize(*img, resized_image, target_size, img_tool::RESIZE_ALGO_BILINEAR);
+                clip_image_f32_ptr img_f32(clip_image_f32_init());
+                normalize_image_u8_to_f32(resized_image, *img_f32, params.image_mean, params.image_std);
+                res_imgs->entries.push_back(std::move(img_f32));
+            } break;
 
-        if (h_bar * w_bar > max_pixels) {
-            const auto beta = std::sqrt((height * width) / max_pixels);
-            h_bar = std::max(total_factor, floor_by_factor(height / beta));
-            w_bar = std::max(total_factor, floor_by_factor(width / beta));
-        } else if (h_bar * w_bar < min_pixels) {
-            const auto beta = std::sqrt(min_pixels / (height * width));
-            h_bar = ceil_by_factor(height * beta);
-            w_bar = ceil_by_factor(width * beta);
-        }
+        case PROJECTOR_TYPE_LLAMA4:
+            {
+                GGML_ASSERT(!params.image_res_candidates.empty());
+                auto const inst = llava_uhd::get_slice_instructions(ctx, original_size);
+                std::vector<clip_image_u8_ptr> imgs = llava_uhd::slice_image(img, inst);
+
+                for (size_t i = 0; i < imgs.size(); ++i) {
+                    clip_image_f32_ptr res(clip_image_f32_init());
+                    normalize_image_u8_to_f32(*imgs[i], *res, params.image_mean, params.image_std);
+                    res_imgs->entries.push_back(std::move(res));
+                }
 
-        const std::array<uint8_t, 3> pad_color = {122, 116, 104};
+                res_imgs->grid_x = inst.grid_size.width;
+                res_imgs->grid_y = inst.grid_size.height;
+            } break;
 
-        clip_image_u8 resized_img;
-        image_manipulation::resize_and_pad_image(*img, resized_img, clip_image_size{w_bar, h_bar}, pad_color);
-        clip_image_f32_ptr res(clip_image_f32_init());
-        normalize_image_u8_to_f32(resized_img, *res, params.image_mean, params.image_std);
-        res_imgs->entries.push_back(std::move(res));
-        return true;
-    }
+        case PROJECTOR_TYPE_LFM2:
+        case PROJECTOR_TYPE_KIMIVL:
+            {
+                GGML_ASSERT(params.image_min_pixels && params.image_max_pixels);
+                const clip_image_size target_size = img_tool::calc_size_preserved_ratio(
+                    original_size,
+                    params.patch_size * params.n_merge,
+                    params.image_min_pixels,
+                    params.image_max_pixels);
+                const std::array<uint8_t, 3> pad_color = {122, 116, 104};
+
+                clip_image_u8 resized_img;
+                img_tool::resize(*img, resized_img, target_size, img_tool::RESIZE_ALGO_BILINEAR, true, pad_color);
+                clip_image_f32_ptr res(clip_image_f32_init());
+                normalize_image_u8_to_f32(resized_img, *res, params.image_mean, params.image_std);
+                res_imgs->entries.push_back(std::move(res));
+            } break;
 
-    // the logic below is to pad the shorter side to the longer side with a background color: rgb(122, 116, 104)
-    // see https://github.com/haotian-liu/LLaVA/blob/e854a2bf85118c504f6f16bf5c3c7c92f8fa8c6b/llava/conversation.py#L113-L156
+        case PROJECTOR_TYPE_MLP:
+        case PROJECTOR_TYPE_MLP_NORM:
+        case PROJECTOR_TYPE_LDP:
+        case PROJECTOR_TYPE_LDPV2:
+        case PROJECTOR_TYPE_COGVLM: // TODO @ngxson : is this correct for cogvlm?
+            {
+                // TODO @ngxson : refactor the code below to avoid duplicated logic
 
-    clip_image_u8_ptr temp(clip_image_u8_init()); // we will keep the input image data here temporarily
+                // the logic below is to pad the shorter side to the longer side with a background color: rgb(122, 116, 104)
+                // see https://github.com/haotian-liu/LLaVA/blob/e854a2bf85118c504f6f16bf5c3c7c92f8fa8c6b/llava/conversation.py#L113-L156
 
-    if (pad_to_square) {
-        // for llava-1.5, we resize image to a square, and pad the shorter side with a background color
-        // see https://github.com/haotian-liu/LLaVA/blob/e854a2bf85118c504f6f16bf5c3c7c92f8fa8c6b/llava/conversation.py#L113-L156
-        const int longer_side = std::max(img->nx, img->ny);
-        temp->nx = longer_side;
-        temp->ny = longer_side;
-        temp->buf.resize(3 * longer_side * longer_side);
+                clip_image_u8_ptr temp(clip_image_u8_init()); // we will keep the input image data here temporarily
 
-        // background color in RGB from LLaVA (this is the mean rgb color * 255)
-        const std::array<uint8_t, 3> pad_color = {122, 116, 104};
+                // The model config actually contains all we need to decide on how to preprocess, here we automatically switch to the new llava-1.6 preprocessing
+                if (params.image_res_candidates.empty()) { // pad_to_square
+                    // for llava-1.5, we resize image to a square, and pad the shorter side with a background color
+                    // see https://github.com/haotian-liu/LLaVA/blob/e854a2bf85118c504f6f16bf5c3c7c92f8fa8c6b/llava/conversation.py#L113-L156
+                    const int longer_side = std::max(img->nx, img->ny);
+                    temp->nx = longer_side;
+                    temp->ny = longer_side;
+                    temp->buf.resize(3 * longer_side * longer_side);
 
-        // resize the image to the target_size
-        image_manipulation::resize_and_pad_image(*img, *temp, clip_image_size{params.image_size, params.image_size}, pad_color);
+                    // background color in RGB from LLaVA (this is the mean rgb color * 255)
+                    const std::array<uint8_t, 3> pad_color = {122, 116, 104};
 
-        clip_image_f32_ptr res(clip_image_f32_init());
-        normalize_image_u8_to_f32(*temp, *res, params.image_mean, params.image_std);
-        res_imgs->entries.push_back(std::move(res));
-        return true;
+                    // resize the image to the target_size
+                    img_tool::resize(*img, *temp, clip_image_size{params.image_size, params.image_size}, img_tool::RESIZE_ALGO_BILINEAR, true, pad_color);
 
-    } else if (!params.image_res_candidates.empty()) {
-        // "spatial_unpad" with "anyres" processing for llava-1.6
-        auto const inst = llava_uhd::get_slice_instructions(ctx, original_size);
-        std::vector<clip_image_u8_ptr> imgs = llava_uhd::slice_image(img, inst);
+                    clip_image_f32_ptr res(clip_image_f32_init());
+                    normalize_image_u8_to_f32(*temp, *res, params.image_mean, params.image_std);
+                    res_imgs->entries.push_back(std::move(res));
 
-        for (size_t i = 0; i < imgs.size(); ++i) {
-            // clip_image_save_to_bmp(*imgs[i], "slice_" + std::to_string(i) + ".bmp");
-            clip_image_f32_ptr res(clip_image_f32_init());
-            normalize_image_u8_to_f32(*imgs[i], *res, params.image_mean, params.image_std);
-            res_imgs->entries.push_back(std::move(res));
-        }
+                } else {
+                    // "spatial_unpad" with "anyres" processing for llava-1.6
+                    auto const inst = llava_uhd::get_slice_instructions(ctx, original_size);
+                    std::vector<clip_image_u8_ptr> imgs = llava_uhd::slice_image(img, inst);
+
+                    for (size_t i = 0; i < imgs.size(); ++i) {
+                        // clip_image_save_to_bmp(*imgs[i], "slice_" + std::to_string(i) + ".bmp");
+                        clip_image_f32_ptr res(clip_image_f32_init());
+                        normalize_image_u8_to_f32(*imgs[i], *res, params.image_mean, params.image_std);
+                        res_imgs->entries.push_back(std::move(res));
+                    }
+                }
+            } break;
 
-        return true;
-    } else {
-        GGML_ABORT("Unknown image preprocessing type");
+        default:
+            LOG_ERR("%s: unsupported projector type %d\n", __func__, ctx->proj_type());
+            return false;
     }
 
+    return true;
 }
 
 ggml_tensor * clip_get_newline_tensor(const struct clip_ctx * ctx) {
@@ -4145,7 +4247,7 @@ int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 *
     const auto & params = ctx->model.hparams;
     const int n_total = clip_n_output_tokens(ctx, img);
     if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN3VL) {
-        return img->nx / (params.patch_size * 2) + (int)(img->nx % params.patch_size > 0);
+        return img->nx / (params.patch_size * 2);
     }
     return n_total;
 }
@@ -4153,7 +4255,7 @@ int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 *
 int clip_n_output_tokens_y(const struct clip_ctx * ctx, struct clip_image_f32 * img) {
     const auto & params = ctx->model.hparams;
     if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN25VL || ctx->proj_type() == PROJECTOR_TYPE_QWEN3VL) {
-        return img->ny / (params.patch_size * 2) + (int)(img->ny % params.patch_size > 0);
+        return img->ny / (params.patch_size * 2);
     }
     return 1;
 }
@@ -4211,9 +4313,8 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
         case PROJECTOR_TYPE_QWEN3VL:
             {
                 // dynamic size (2 conv, so double patch size)
-                int patch_size = params.patch_size * 2;
-                int x_patch = img->nx / patch_size + (int)(img->nx % patch_size > 0);
-                int y_patch = img->ny / patch_size + (int)(img->ny % patch_size > 0);
+                int x_patch = img->nx / (params.patch_size * 2);
+                int y_patch = img->ny / (params.patch_size * 2);
                 n_patches = x_patch * y_patch;
             } break;
         case PROJECTOR_TYPE_GEMMA3:
@@ -4222,15 +4323,14 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
         case PROJECTOR_TYPE_LLAMA4:
             {
                 // both X and Y are downscaled by the scale factor
-                int scale_factor = ctx->model.hparams.proj_scale_factor;
+                int scale_factor = ctx->model.hparams.n_merge;
                 n_patches /= (scale_factor * scale_factor);
             } break;
         case PROJECTOR_TYPE_LFM2:
         case PROJECTOR_TYPE_KIMIVL:
             {
                 // dynamic size
-                int scale_factor = ctx->model.hparams.proj_scale_factor;
-                int out_patch_size = params.patch_size * scale_factor;
+                int out_patch_size = params.patch_size * ctx->model.hparams.n_merge;
                 int x_patch = CLIP_ALIGN(img->nx, out_patch_size) / out_patch_size;
                 int y_patch = CLIP_ALIGN(img->ny, out_patch_size) / out_patch_size;
                 n_patches = x_patch * y_patch;
@@ -4239,7 +4339,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
         case PROJECTOR_TYPE_LIGHTONOCR:
             {
                 // dynamic size
-                int n_merge = params.spatial_merge_size;
+                int n_merge = ctx->model.hparams.n_merge;
                 int n_patches_x = img->nx / patch_size / (n_merge > 0 ? n_merge : 1);
                 int n_patches_y = img->ny / patch_size / (n_merge > 0 ? n_merge : 1);
                 if (ctx->model.token_embd_img_break) {