]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
mtmd: add --image-min/max-tokens (#16921)
authorXuan-Son Nguyen <redacted>
Mon, 3 Nov 2025 10:11:18 +0000 (11:11 +0100)
committerGitHub <redacted>
Mon, 3 Nov 2025 10:11:18 +0000 (11:11 +0100)
common/arg.cpp
common/common.h
tools/mtmd/clip.cpp
tools/mtmd/clip.h
tools/mtmd/mtmd-cli.cpp
tools/mtmd/mtmd.cpp
tools/mtmd/mtmd.h
tools/server/server.cpp

index d8f9bbd24301fcc261dd7c4c2760e56511ce9a65..4316917d7459548d61860497c8846f74e9610208 100644 (file)
@@ -2768,6 +2768,20 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.image.emplace_back(value);
         }
     ).set_examples({LLAMA_EXAMPLE_MTMD}));
+    add_opt(common_arg(
+        {"--image-min-tokens"}, "N",
+        "minimum number of tokens each image can take, only used by vision models with dynamic resolution (default: read from model)",
+        [](common_params & params, int value) {
+            params.image_min_tokens = value;
+        }
+    ).set_examples(mmproj_examples).set_env("LLAMA_ARG_IMAGE_MIN_TOKENS"));
+    add_opt(common_arg(
+        {"--image-max-tokens"}, "N",
+        "maximum number of tokens each image can take, only used by vision models with dynamic resolution (default: read from model)",
+        [](common_params & params, int value) {
+            params.image_max_tokens = value;
+        }
+    ).set_examples(mmproj_examples).set_env("LLAMA_ARG_IMAGE_MAX_TOKENS"));
     if (llama_supports_rpc()) {
         add_opt(common_arg(
             {"--rpc"}, "SERVERS",
index a8cb630ea5805af828e6aeeab2dfdc75e27586d6..78c568a7bc62e270ac577463da6020f689d433dd 100644 (file)
@@ -406,6 +406,8 @@ struct common_params {
     bool mmproj_use_gpu = true;     // use GPU for multimodal model
     bool no_mmproj = false;         // explicitly disable multimodal model
     std::vector<std::string> image; // path to image file(s)
+    int image_min_tokens = -1;
+    int image_max_tokens = -1;
 
     // finetune
     struct lr_opt lr;
index 60516d582a5f36fcda47b1ca42adb484db8da76e..99775cb3e351c49186b22fb6ecbaafbddd480078 100644 (file)
@@ -169,8 +169,8 @@ struct clip_hparams {
     int32_t n_layer;
     // idefics3
     int32_t image_longest_edge = 0;
-    int32_t image_min_pixels = 0;
-    int32_t image_max_pixels = 0;
+    int32_t image_min_pixels = -1;
+    int32_t image_max_pixels = -1;
     int32_t n_merge = 0; // number of patch merges **per-side**
 
     float image_mean[3];
@@ -203,11 +203,15 @@ struct clip_hparams {
     int minicpmv_version = 0;
     int32_t minicpmv_query_num = 0;         // MiniCPM-V query number
 
+    // custom value provided by user, can be undefined if not set
+    int32_t custom_image_min_tokens = -1;
+    int32_t custom_image_max_tokens = -1;
+
     void set_limit_image_tokens(int n_tokens_min, int n_tokens_max) {
         const int cur_merge = n_merge == 0 ? 1 : n_merge;
         const int patch_area = patch_size * patch_size * cur_merge * cur_merge;
-        image_min_pixels = n_tokens_min * patch_area;
-        image_max_pixels = n_tokens_max * patch_area;
+        image_min_pixels = (custom_image_min_tokens > 0 ? custom_image_min_tokens : n_tokens_min) * patch_area;
+        image_max_pixels = (custom_image_max_tokens > 0 ? custom_image_max_tokens : n_tokens_max) * patch_area;
         warmup_image_size = static_cast<int>(std::sqrt(image_max_pixels));
     }
 
@@ -216,6 +220,7 @@ struct clip_hparams {
         GGML_ASSERT(n_tok_per_side * n_tok_per_side == n_tokens && "n_tokens must be n*n");
         const int cur_merge = n_merge == 0 ? 1 : n_merge;
         warmup_image_size = n_tok_per_side * patch_size * cur_merge;
+        // TODO: support warmup size for custom token numbers
     }
 };
 
@@ -459,6 +464,13 @@ struct clip_ctx {
             LOG_INF("%s: CLIP using CPU backend\n", __func__);
         }
 
+        if (ctx_params.image_min_tokens > 0) {
+            model.hparams.custom_image_min_tokens = ctx_params.image_min_tokens;
+        }
+        if (ctx_params.image_max_tokens > 0) {
+            model.hparams.custom_image_max_tokens = ctx_params.image_max_tokens;
+        }
+
         backend_ptrs.push_back(backend_cpu);
         backend_buft.push_back(ggml_backend_get_default_buffer_type(backend_cpu));
 
@@ -2786,6 +2798,12 @@ struct clip_model_loader {
                         //           see: https://github.com/ggml-org/llama.cpp/issues/16842#issuecomment-3475144858
                         hparams.set_limit_image_tokens(8, 2048);
                         hparams.set_warmup_n_tokens(256); // avoid OOM on warmup
+                        const int warn_min_pixels = 1024 * hparams.n_merge * hparams.n_merge * hparams.patch_size * hparams.patch_size;
+                        if (hparams.image_min_pixels < warn_min_pixels) {
+                            LOG_WRN("%s: Qwen-VL models require at minimum 1024 image tokens to function correctly on grounding tasks\n", __func__);
+                            LOG_WRN("%s: if you encounter problems with accuracy, try adding --image-min-tokens 1024\n", __func__);
+                            LOG_WRN("%s: more info: https://github.com/ggml-org/llama.cpp/issues/16842\n\n", __func__);
+                        }
                     } break;
                 case PROJECTOR_TYPE_LLAMA4:
                     {
@@ -2810,6 +2828,13 @@ struct clip_model_loader {
                     break;
             }
 
+            // sanity check
+            {
+                if (hparams.image_max_pixels < hparams.image_min_pixels) {
+                    throw std::runtime_error(string_format("%s: image_max_pixels (%d) is less than image_min_pixels (%d)\n", __func__, hparams.image_max_pixels, hparams.image_min_pixels));
+                }
+            }
+
             LOG_INF("%s: projector:          %s\n", __func__, proj_type.c_str());
             LOG_INF("%s: n_embd:             %d\n", __func__, hparams.n_embd);
             LOG_INF("%s: n_head:             %d\n", __func__, hparams.n_head);
@@ -2826,10 +2851,10 @@ struct clip_model_loader {
                 LOG_INF("%s: n_merge:            %d\n", __func__, hparams.n_merge);
                 LOG_INF("%s: n_wa_pattern:       %d\n", __func__, hparams.n_wa_pattern);
                 if (hparams.image_min_pixels > 0) {
-                    LOG_INF("%s: image_min_pixels:   %d\n", __func__, hparams.image_min_pixels);
+                    LOG_INF("%s: image_min_pixels:   %d%s\n", __func__, hparams.image_min_pixels, hparams.custom_image_min_tokens > 0 ? " (custom value)" : "");
                 }
                 if (hparams.image_max_pixels > 0) {
-                    LOG_INF("%s: image_max_pixels:   %d\n", __func__, hparams.image_max_pixels);
+                    LOG_INF("%s: image_max_pixels:   %d%s\n", __func__, hparams.image_max_pixels, hparams.custom_image_max_tokens > 0 ? " (custom value)" : "");
                 }
             } else if (is_audio) {
                 LOG_INF("\n--- audio hparams ---\n");
@@ -4169,7 +4194,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
         case PROJECTOR_TYPE_QWEN25VL:
         case PROJECTOR_TYPE_QWEN3VL:
             {
-                // step 1: make a blank canvas which aligns to the grid
+                GGML_ASSERT(params.image_min_pixels > 0 && params.image_max_pixels > 0);
                 clip_image_u8 resized;
                 const clip_image_size new_size = img_tool::calc_size_preserved_ratio(
                     original_size,
@@ -4262,7 +4287,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
         case PROJECTOR_TYPE_PIXTRAL:
         case PROJECTOR_TYPE_LIGHTONOCR:
             {
-                GGML_ASSERT(params.image_min_pixels && params.image_max_pixels);
+                GGML_ASSERT(params.image_min_pixels > 0 && params.image_max_pixels > 0);
                 clip_image_u8 resized_image;
                 // the original pixtral model doesn't have n_merge
                 const int cur_merge = params.n_merge == 0 ? 1 : params.n_merge;
@@ -4296,7 +4321,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
         case PROJECTOR_TYPE_LFM2:
         case PROJECTOR_TYPE_KIMIVL:
             {
-                GGML_ASSERT(params.image_min_pixels && params.image_max_pixels);
+                GGML_ASSERT(params.image_min_pixels > 0 && params.image_max_pixels > 0);
                 const clip_image_size target_size = img_tool::calc_size_preserved_ratio(
                     original_size,
                     params.patch_size * params.n_merge,
index 6384e2adaf77535a60275886fb360882d6d243e9..3e4c985f117b905eee87d4a096b818923d26c896 100644 (file)
@@ -33,6 +33,8 @@ struct clip_context_params {
     bool use_gpu;
     enum ggml_log_level verbosity;
     enum clip_flash_attn_type flash_attn_type;
+    int image_min_tokens;
+    int image_max_tokens;
 };
 
 struct clip_init_result {
index 17aea1472b3c6b2826ec6a137f7a1ca905b899ab..3e19e95958a2f413995e86212e7d8923a799b9a8 100644 (file)
@@ -132,11 +132,13 @@ struct mtmd_cli_context {
     void init_vision_context(common_params & params) {
         const char * clip_path = params.mmproj.path.c_str();
         mtmd_context_params mparams = mtmd_context_params_default();
-        mparams.use_gpu = params.mmproj_use_gpu;
-        mparams.print_timings = true;
-        mparams.n_threads = params.cpuparams.n_threads;
-        mparams.verbosity = params.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO;
-        mparams.flash_attn_type = params.flash_attn_type;
+        mparams.use_gpu          = params.mmproj_use_gpu;
+        mparams.print_timings    = true;
+        mparams.n_threads        = params.cpuparams.n_threads;
+        mparams.verbosity        = params.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO;
+        mparams.flash_attn_type  = params.flash_attn_type;
+        mparams.image_min_tokens = params.image_min_tokens;
+        mparams.image_max_tokens = params.image_max_tokens;
         ctx_vision.reset(mtmd_init_from_file(clip_path, model, mparams));
         if (!ctx_vision.get()) {
             LOG_ERR("Failed to load vision model from %s\n", clip_path);
index 297eef437ab912e1bf0cdb22f537a337949364bc..325f7ff995e362d2a25d2b06d44bc36dcf594adb 100644 (file)
@@ -109,6 +109,8 @@ mtmd_context_params mtmd_context_params_default() {
     params.image_marker = MTMD_DEFAULT_IMAGE_MARKER;
     params.media_marker = mtmd_default_marker();
     params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO;
+    params.image_min_tokens = -1;
+    params.image_max_tokens = -1;
     return params;
 }
 
@@ -171,9 +173,13 @@ struct mtmd_context {
         }
 
         clip_context_params ctx_clip_params;
-        ctx_clip_params.use_gpu   = ctx_params.use_gpu;
-        ctx_clip_params.verbosity = ctx_params.verbosity;
-        ctx_clip_params.flash_attn_type = mtmd_get_clip_flash_attn_type(ctx_params.flash_attn_type);
+        ctx_clip_params.use_gpu          = ctx_params.use_gpu;
+        ctx_clip_params.verbosity        = ctx_params.verbosity;
+        ctx_clip_params.flash_attn_type  = mtmd_get_clip_flash_attn_type(ctx_params.flash_attn_type);
+        // custom image token limits
+        ctx_clip_params.image_min_tokens = ctx_params.image_min_tokens;
+        ctx_clip_params.image_max_tokens = ctx_params.image_max_tokens;
+
         auto res = clip_init(mmproj_fname, ctx_clip_params);
         ctx_v = res.ctx_v;
         ctx_a = res.ctx_a;
index 4ae1925bcdfb64775ababd8f32d0a85a219b5724..775fba6215c7ca72aaad29ead9e9948aaf3f10d8 100644 (file)
@@ -83,6 +83,10 @@ struct mtmd_context_params {
     const char * image_marker; // deprecated, use media_marker instead
     const char * media_marker;
     enum llama_flash_attn_type flash_attn_type;
+
+    // limit number of image tokens, only for vision models with dynamic resolution
+    int image_min_tokens; // minimum number of tokens for image input (default: read from metadata)
+    int image_max_tokens; // maximum number of tokens for image input (default: read from metadata)
 };
 
 MTMD_API const char * mtmd_default_marker(void);
index a9bef35189b3ab877b82153366fccf51e5e2885c..a8d7773c96809ff88d2855b36d49e064eca62439 100644 (file)
@@ -2452,11 +2452,13 @@ struct server_context {
         std::string & mmproj_path = params_base.mmproj.path;
         if (!mmproj_path.empty()) {
             mtmd_context_params mparams = mtmd_context_params_default();
-            mparams.use_gpu       = params_base.mmproj_use_gpu;
-            mparams.print_timings = false;
-            mparams.n_threads     = params_base.cpuparams.n_threads;
-            mparams.verbosity     = params_base.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO;
-            mparams.flash_attn_type = params_base.flash_attn_type;
+            mparams.use_gpu          = params_base.mmproj_use_gpu;
+            mparams.print_timings    = false;
+            mparams.n_threads        = params_base.cpuparams.n_threads;
+            mparams.verbosity        = params_base.verbosity > 0 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_INFO;
+            mparams.flash_attn_type  = params_base.flash_attn_type;
+            mparams.image_min_tokens = params_base.image_min_tokens;
+            mparams.image_max_tokens = params_base.image_max_tokens;
             mctx = mtmd_init_from_file(mmproj_path.c_str(), model, mparams);
             if (mctx == nullptr) {
                 SRV_ERR("failed to load multimodal model, '%s'\n", mmproj_path.c_str());