git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
mtmd : add vision support for llama 4 (#13282)
author Xuan-Son Nguyen <redacted>
Mon, 19 May 2025 11:04:14 +0000 (13:04 +0200)
committer GitHub <redacted>
Mon, 19 May 2025 11:04:14 +0000 (13:04 +0200)
* wip llama 4 conversion

* rm redundant __init__

* fix conversion

* fix conversion

* test impl

* try this

* reshape patch_embeddings_0

* fix view

* rm ffn_post_norm

* cgraph ok

* f32 for pos embd

* add image marker tokens

* Llama4UnfoldConvolution

* correct pixel shuffle

* fix merge conflicts

* correct

* add debug_graph

* logits matched, but it still perceives the image incorrectly

* fix style

* add image_grid_pinpoints

* handle llama 4 preprocessing

* rm load_image_size

* rm unused line

* fix

* small fix 2

* add test & docs

* fix llava-1.6 test

* test: add notion of huge models

* add comment

* add warn about degraded quality

convert_hf_to_gguf.py
docs/multimodal.md
gguf-py/gguf/constants.py
gguf-py/gguf/tensor_mapping.py
tools/mtmd/clip-impl.h
tools/mtmd/clip.cpp
tools/mtmd/clip.h
tools/mtmd/mtmd.cpp
tools/mtmd/tests.sh

index b5402cbead880aab3e89908c36c3e698031101c0..15e019a10f253fa24c071c066254b81a629e0134 100755 (executable)
@@ -308,6 +308,7 @@ class ModelBase:
                             gguf.MODEL_TENSOR.TIME_MIX_LERP_FUSED,
                             gguf.MODEL_TENSOR.POSNET_NORM1,
                             gguf.MODEL_TENSOR.POSNET_NORM2,
+                            gguf.MODEL_TENSOR.V_ENC_EMBD_POS,
                         )
                     )
                     or not new_name.endswith(".weight")
@@ -2092,6 +2093,26 @@ class Llama4Model(LlamaModel):
         return super().modify_tensors(data_torch, name, bid)
 
 
+@ModelBase.register("Llama4ForConditionalGeneration")
+class Llama4VisionModel(VisionModel):
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.LLAMA4)
+        self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams["norm_eps"])
+        self.gguf_writer.add_vision_projector_scale_factor(int(1.0 / self.hparams["pixel_shuffle_ratio"]))
+        assert self.hparams["hidden_act"] == "gelu"
+        self.gguf_writer.add_vision_use_gelu(True)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid # unused
+        if "multi_modal_projector" in name or "vision_model" in name:
+            # process vision tensors
+            if "positional_embedding_vlm" in name and ".weight" not in name:
+                name += ".weight"
+            return [(self.map_tensor_name(name), data_torch)]
+        return []
+
+
 @ModelBase.register("Mistral3ForConditionalGeneration")
 class Mistral3Model(LlamaModel):
     model_arch = gguf.MODEL_ARCH.LLAMA
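
For reference, a minimal standalone sketch of the conversion logic added in Llama4VisionModel above. The pixel_shuffle_ratio value is an assumption for illustration only; the converter reads it from the HF config.

    # hypothetical hparams dict standing in for the HF vision config
    hparams = {"pixel_shuffle_ratio": 0.5, "hidden_act": "gelu"}
    scale_factor = int(1.0 / hparams["pixel_shuffle_ratio"])  # 0.5 -> 2
    assert hparams["hidden_act"] == "gelu"

    def keep_vision_tensor(name: str) -> str | None:
        # mirrors Llama4VisionModel.modify_tensors: only vision / projector tensors are kept
        if "multi_modal_projector" in name or "vision_model" in name:
            if "positional_embedding_vlm" in name and ".weight" not in name:
                name += ".weight"  # GGUF expects an explicit .weight suffix here
            return name
        return None  # text-model tensors are handled by Llama4Model instead

    print(scale_factor)                                                  # 2
    print(keep_vision_tensor("vision_model.positional_embedding_vlm"))   # ....weight
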
index 80014ba1cef6da04db5619d203e989d92d1131c7..054778e91f2979198a47e7bb45f744ada3d0ae53 100644 (file)
@@ -74,4 +74,7 @@ NOTE: some models may require large context window, for example: `-c 8192`
 (tool_name) -hf ggml-org/InternVL3-2B-Instruct-GGUF
 (tool_name) -hf ggml-org/InternVL3-8B-Instruct-GGUF
 (tool_name) -hf ggml-org/InternVL3-14B-Instruct-GGUF
+
+# Llama 4 Scout
+(tool_name) -hf ggml-org/Llama-4-Scout-17B-16E-Instruct-GGUF
 ```
index 21af0a9a2693f89dc794d98f49776ee93fa6fcc7..7aed8b83ec3784a83d04ed06621184004c4267d0 100644 (file)
@@ -482,14 +482,15 @@ class MODEL_TENSOR(IntEnum):
     V_ENC_EMBD_CLS       = auto()
     V_ENC_EMBD_PATCH     = auto()
     V_ENC_EMBD_POS       = auto()
+    V_ENC_INPUT_NORM     = auto()
     V_ENC_ATTN_Q         = auto()
     V_ENC_ATTN_Q_NORM    = auto()
     V_ENC_ATTN_K         = auto()
     V_ENC_ATTN_K_NORM    = auto()
     V_ENC_ATTN_V         = auto()
-    V_ENC_INPUT_NORM     = auto()
-    V_ENC_OUTPUT         = auto()
-    V_ENC_OUTPUT_NORM    = auto()
+    V_ENC_ATTN_O         = auto()
+    V_ENC_ATTN_O_NORM    = auto()
+    V_ENC_POST_ATTN_NORM = auto()
     V_ENC_FFN_UP         = auto()
     V_ENC_FFN_GATE       = auto()
     V_ENC_FFN_DOWN       = auto()
@@ -749,8 +750,9 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
     MODEL_TENSOR.V_ENC_ATTN_K_NORM:         "v.blk.{bid}.attn_k_norm",
     MODEL_TENSOR.V_ENC_ATTN_V:              "v.blk.{bid}.attn_v",
     MODEL_TENSOR.V_ENC_INPUT_NORM:          "v.blk.{bid}.ln1",
-    MODEL_TENSOR.V_ENC_OUTPUT:              "v.blk.{bid}.attn_out",
-    MODEL_TENSOR.V_ENC_OUTPUT_NORM:         "v.blk.{bid}.ln2",
+    MODEL_TENSOR.V_ENC_ATTN_O:              "v.blk.{bid}.attn_out",
+    MODEL_TENSOR.V_ENC_ATTN_O_NORM:         "v.blk.{bid}.attn_out_norm",
+    MODEL_TENSOR.V_ENC_POST_ATTN_NORM:      "v.blk.{bid}.ln2",
     MODEL_TENSOR.V_ENC_FFN_UP:              "v.blk.{bid}.ffn_up",
     MODEL_TENSOR.V_ENC_FFN_GATE:            "v.blk.{bid}.ffn_gate",
     MODEL_TENSOR.V_ENC_FFN_DOWN:            "v.blk.{bid}.ffn_down",
@@ -785,14 +787,15 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.V_ENC_EMBD_CLS,
         MODEL_TENSOR.V_ENC_EMBD_PATCH,
         MODEL_TENSOR.V_ENC_EMBD_POS,
+        MODEL_TENSOR.V_ENC_INPUT_NORM,
         MODEL_TENSOR.V_ENC_ATTN_Q,
         MODEL_TENSOR.V_ENC_ATTN_Q_NORM,
         MODEL_TENSOR.V_ENC_ATTN_K,
         MODEL_TENSOR.V_ENC_ATTN_K_NORM,
         MODEL_TENSOR.V_ENC_ATTN_V,
-        MODEL_TENSOR.V_ENC_INPUT_NORM,
-        MODEL_TENSOR.V_ENC_OUTPUT,
-        MODEL_TENSOR.V_ENC_OUTPUT_NORM,
+        MODEL_TENSOR.V_ENC_ATTN_O,
+        MODEL_TENSOR.V_ENC_ATTN_O_NORM,
+        MODEL_TENSOR.V_ENC_POST_ATTN_NORM,
         MODEL_TENSOR.V_ENC_FFN_UP,
         MODEL_TENSOR.V_ENC_FFN_GATE,
         MODEL_TENSOR.V_ENC_FFN_DOWN,
@@ -2180,6 +2183,7 @@ class VisionProjectorType:
     GEMMA3 = "gemma3"
     IDEFICS3 = "idefics3"
     PIXTRAL = "pixtral"
+    LLAMA4 = "llama4"
     QWEN2VL = "qwen2vl_merger"
     QWEN25VL = "qwen2.5vl_merger"
     INTERNVL = "internvl"
index dadb4fd94934eb8f5f221b1f74db08f87f01e0cd..0298f8b460a3f7694b2f674f5b97223d96af77ca 100644 (file)
@@ -902,10 +902,12 @@ class TensorNameMap:
 
         MODEL_TENSOR.V_MMPROJ_FC: (
             "model.connector.modality_projection.proj", # SmolVLM
+            "multi_modal_projector.linear_1", # llama 4
         ),
 
         MODEL_TENSOR.V_MMPROJ_MLP: (
             "model.mm_projector.mlp.mlp.{bid}",
+            "vision_model.vision_adapter.mlp.fc{bid}", # llama 4
             "mlp1.{bid}", # InternVL
         ),
 
@@ -915,6 +917,7 @@ class TensorNameMap:
 
         MODEL_TENSOR.V_ENC_EMBD_CLS: (
             "vision_tower.vision_model.embeddings.class_embedding",
+            "vision_model.class_embedding", # llama 4
         ),
 
         MODEL_TENSOR.V_ENC_EMBD_PATCH: (
@@ -922,6 +925,7 @@ class TensorNameMap:
             "vpm.embeddings.patch_embedding",
             "model.vision_model.embeddings.patch_embedding", # SmolVLM
             "vision_tower.patch_conv", # pixtral
+            "vision_model.patch_embedding.linear", # llama 4
             "visual.patch_embed.proj", # qwen2vl
         ),
 
@@ -929,12 +933,14 @@ class TensorNameMap:
             "vision_tower.vision_model.embeddings.position_embedding",
             "vpm.embeddings.position_embedding",
             "model.vision_model.embeddings.position_embedding", # SmolVLM
+            "vision_model.positional_embedding_vlm", # llama 4
         ),
 
         MODEL_TENSOR.V_ENC_ATTN_Q: (
             "vision_tower.vision_model.encoder.layers.{bid}.self_attn.q_proj",
             "vpm.encoder.layers.{bid}.self_attn.q_proj",
             "model.vision_model.encoder.layers.{bid}.self_attn.q_proj", # SmolVLM
+            "vision_model.model.layers.{bid}.self_attn.q_proj", # llama4
             "vision_tower.transformer.layers.{bid}.attention.q_proj", # pixtral
             "visual.blocks.{bid}.attn.q", # qwen2vl, generated
         ),
@@ -947,6 +953,7 @@ class TensorNameMap:
             "vision_tower.vision_model.encoder.layers.{bid}.self_attn.k_proj",
             "vpm.encoder.layers.{bid}.self_attn.k_proj",
             "model.vision_model.encoder.layers.{bid}.self_attn.k_proj", # SmolVLM
+            "vision_model.model.layers.{bid}.self_attn.k_proj", # llama4
             "vision_tower.transformer.layers.{bid}.attention.k_proj", # pixtral
             "visual.blocks.{bid}.attn.k", # qwen2vl, generated
         ),
@@ -959,6 +966,7 @@ class TensorNameMap:
             "vision_tower.vision_model.encoder.layers.{bid}.self_attn.v_proj",
             "vpm.encoder.layers.{bid}.self_attn.v_proj",
             "model.vision_model.encoder.layers.{bid}.self_attn.v_proj", # SmolVLM
+            "vision_model.model.layers.{bid}.self_attn.v_proj", # llama4
             "vision_tower.transformer.layers.{bid}.attention.v_proj", # pixtral
             "visual.blocks.{bid}.attn.v", # qwen2vl, generated
         ),
@@ -969,23 +977,26 @@ class TensorNameMap:
             "vpm.encoder.layers.{bid}.layer_norm1",
             "model.vision_model.encoder.layers.{bid}.layer_norm1", # SmolVLM
             "vision_tower.transformer.layers.{bid}.attention_norm", # pixtral
+            "vision_model.model.layers.{bid}.input_layernorm", # llama4
             "visual.blocks.{bid}.norm1", # qwen2vl
         ),
 
-        MODEL_TENSOR.V_ENC_OUTPUT: (
+        MODEL_TENSOR.V_ENC_ATTN_O: (
             "vision_tower.vision_model.encoder.layers.{bid}.self_attn.out_proj",
             "vision_tower.vision_model.encoder.layers.{bid}.attn.proj", # InternVL
             "vpm.encoder.layers.{bid}.self_attn.out_proj",
             "model.vision_model.encoder.layers.{bid}.self_attn.out_proj", # SmolVLM
+            "vision_model.model.layers.{bid}.self_attn.o_proj", # llama4
             "vision_tower.transformer.layers.{bid}.attention.o_proj", # pixtral
             "visual.blocks.{bid}.attn.proj", # qwen2vl
         ),
 
-        MODEL_TENSOR.V_ENC_OUTPUT_NORM: (
+        MODEL_TENSOR.V_ENC_POST_ATTN_NORM: (
             "vision_tower.vision_model.encoder.layers.{bid}.layer_norm2",
             "vision_tower.vision_model.encoder.layers.{bid}.norm2", # InternVL
             "vpm.encoder.layers.{bid}.layer_norm2",
             "model.vision_model.encoder.layers.{bid}.layer_norm2", # SmolVLM
+            "vision_model.model.layers.{bid}.post_attention_layernorm", # llama4
             "vision_tower.transformer.layers.{bid}.ffn_norm", # pixtral
             "visual.blocks.{bid}.norm2", # qwen2vl
         ),
@@ -995,6 +1006,7 @@ class TensorNameMap:
             "vpm.encoder.layers.{bid}.mlp.fc1",
             "model.vision_model.encoder.layers.{bid}.mlp.fc1", # SmolVLM, gemma3
             "vision_tower.transformer.layers.{bid}.feed_forward.up_proj", # pixtral
+            "vision_model.model.layers.{bid}.mlp.fc1", # llama4
             "visual.blocks.{bid}.mlp.fc1", # qwen2vl
             "visual.blocks.{bid}.mlp.up_proj", # qwen2.5vl
         ),
@@ -1009,6 +1021,7 @@ class TensorNameMap:
             "vpm.encoder.layers.{bid}.mlp.fc2",
             "model.vision_model.encoder.layers.{bid}.mlp.fc2", # SmolVLM, gemma3
             "vision_tower.transformer.layers.{bid}.feed_forward.down_proj", # pixtral
+            "vision_model.model.layers.{bid}.mlp.fc2", # llama4
             "visual.blocks.{bid}.mlp.fc2", # qwen2vl
             "visual.blocks.{bid}.mlp.down_proj", # qwen2.5vl
         ),
@@ -1024,11 +1037,13 @@ class TensorNameMap:
         MODEL_TENSOR.V_PRE_NORM: (
             "vision_tower.vision_model.pre_layrnorm",
             "vision_tower.ln_pre", # pixtral
+            "vision_model.layernorm_pre", # llama4
         ),
 
         MODEL_TENSOR.V_POST_NORM: (
             "vision_tower.vision_model.post_layernorm",
             "model.vision_model.post_layernorm", # SmolVLM
+            "vision_model.layernorm_post", # llama4
             "visual.merger.ln_q", # qwen2vl
         ),
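
As an illustration of how the llama 4 entries added above are meant to resolve, here is a simplified, hypothetical stand-in for the TensorNameMap lookup; the real class also handles the .weight/.bias suffixes and many more source patterns.

    MAPPING = {
        "vision_model.model.layers.{bid}.self_attn.o_proj":         "v.blk.{bid}.attn_out",
        "vision_model.model.layers.{bid}.post_attention_layernorm": "v.blk.{bid}.ln2",
        "vision_model.model.layers.{bid}.mlp.fc1":                  "v.blk.{bid}.ffn_up",
        "vision_model.model.layers.{bid}.mlp.fc2":                  "v.blk.{bid}.ffn_down",
    }

    def resolve(hf_name: str, bid: int) -> str | None:
        # simplified lookup: format each known HF pattern with the block id and compare
        for src, dst in MAPPING.items():
            if hf_name == src.format(bid=bid):
                return dst.format(bid=bid)
        return None

    print(resolve("vision_model.model.layers.3.self_attn.o_proj", 3))  # v.blk.3.attn_out
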
 
index 23036ba72f1c1abe18beca652da801a98cb3d04f..7b7d2df39622cb844ba8183a494fddddeb3737ba 100644 (file)
@@ -4,6 +4,7 @@
 
 #include <climits>
 #include <cstdarg>
+#include <cinttypes>
 #include <string>
 #include <map>
 #include <sstream>
@@ -44,7 +45,7 @@
 // tensor name constants
 //
 
-#define TN_POS_EMBD        "%s.position_embd.weight"
+#define TN_POS_EMBD        "v.position_embd.weight"
 #define TN_CLASS_EMBD      "v.class_embd"
#define TN_PATCH_EMBD      "v.patch_embd.weight"  // do not rename tensor with ".0" postfix for backward compat
 #define TN_PATCH_EMBD_1    "v.patch_embd.weight.1"
@@ -110,6 +111,7 @@ enum projector_type {
     PROJECTOR_TYPE_PIXTRAL,
     PROJECTOR_TYPE_QWEN25VL,
     PROJECTOR_TYPE_INTERNVL,
+    PROJECTOR_TYPE_LLAMA4,
     PROJECTOR_TYPE_UNKNOWN,
 };
 
@@ -125,6 +127,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
     { PROJECTOR_TYPE_IDEFICS3,  "idefics3"},
     { PROJECTOR_TYPE_PIXTRAL,   "pixtral"},
     { PROJECTOR_TYPE_INTERNVL,  "internvl"},
+    { PROJECTOR_TYPE_LLAMA4,    "llama4"},
 };
 
 static projector_type clip_projector_type_from_string(const std::string & str) {
@@ -240,6 +243,11 @@ struct clip_image_u8_batch {
 struct clip_image_f32_batch {
     std::vector<clip_image_f32_ptr> entries;
 
+    // for llava-uhd style models, we need to know the grid size
+    // note: entries.size() == grid_x * grid_y + 1 (one overview image)
+    int grid_x = 0;
+    int grid_y = 0;
+
     clip_image_f32_batch clone() const {
         clip_image_f32_batch new_batch;
         new_batch.entries.reserve(entries.size());
@@ -358,6 +366,70 @@ static std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i) {
     }
 }
 
+//
+// debugging
+//
+
+static void print_tensor_shape(ggml_tensor * t) {
+    printf("%s.shape = [", t->name);
+    for (int i = 0; i < ggml_n_dims(t); ++i) {
+        printf("%" PRId64, t->ne[i]);
+        if (i < ggml_n_dims(t) - 1) {
+            printf(", ");
+        }
+    }
+    printf("]\n");
+}
+
+static void print_tensor_data(ggml_tensor * t, uint8_t * data, int64_t n) {
+    ggml_type type = t->type;
+    int64_t * ne = t->ne;
+    size_t * nb = t->nb;
+    for (int64_t i3 = 0; i3 < ne[3]; i3++) {
+        printf("%s.data: [\n", t->name);
+        for (int64_t i2 = 0; i2 < ne[2]; i2++) {
+            if (i2 == n && ne[2] > 2*n) {
+                printf("     ..., \n");
+                i2 = ne[2] - n;
+            }
+            printf("     [\n");
+            for (int64_t i1 = 0; i1 < ne[1]; i1++) {
+                if (i1 == n && ne[1] > 2*n) {
+                    printf("      ..., \n");
+                    i1 = ne[1] - n;
+                }
+                printf("      [");
+                for (int64_t i0 = 0; i0 < ne[0]; i0++) {
+                    if (i0 == n && ne[0] > 2*n) {
+                        printf("..., ");
+                        i0 = ne[0] - n;
+                    }
+                    size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0];
+                    float v;
+                    if (type == GGML_TYPE_F16) {
+                        v = ggml_fp16_to_fp32(*(ggml_fp16_t *) &data[i]);
+                    } else if (type == GGML_TYPE_F32) {
+                        v = *(float *) &data[i];
+                    } else if (type == GGML_TYPE_I32) {
+                        v = (float) *(int32_t *) &data[i];
+                    } else if (type == GGML_TYPE_I16) {
+                        v = (float) *(int16_t *) &data[i];
+                    } else if (type == GGML_TYPE_I8) {
+                        v = (float) *(int8_t *) &data[i];
+                    } else {
+                        GGML_ABORT("fatal error");
+                    }
+                    printf("%8.4f", v);
+                    if (i0 < ne[0] - 1) printf(", ");
+                }
+                printf("],\n");
+            }
+            printf("     ],\n");
+        }
+        printf("    ]\n");
+    }
+}
+
 //
 // API used internally with mtmd
 //
index 128a95cc11f13231b42e113b843f2050b8fc7a84..eba07f6c82eba675029f43ba9f2e93aade0c7d19 100644 (file)
@@ -359,9 +359,12 @@ struct clip_ctx {
     int max_nodes = 8192;
     ggml_backend_sched_ptr sched;
 
-    clip_image_size load_image_size;
+    // for debugging
+    bool debug_graph = false;
+    std::vector<ggml_tensor *> debug_print_tensors;
 
     clip_ctx(clip_context_params & ctx_params) {
+        debug_graph = std::getenv("MTMD_DEBUG_GRAPH") != nullptr;
         backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
         if (!backend_cpu) {
             throw std::runtime_error("failed to initialize CPU backend");
@@ -440,7 +443,7 @@ struct clip_graph {
         };
         ctx0_ptr.reset(ggml_init(params));
         ctx0 = ctx0_ptr.get();
-        gf = ggml_new_graph(ctx0);
+        gf = ggml_new_graph_custom(ctx0, ctx->max_nodes, false);
     }
 
     ggml_cgraph * build_siglip() {
@@ -522,7 +525,7 @@ struct clip_graph {
         ggml_set_input(pos_w);
 
         auto add_pos = [&](ggml_tensor * cur, const clip_layer &) {
-            return build_rope_2d(ctx0, cur, pos_h, pos_w, hparams.rope_theta);
+            return build_rope_2d(ctx0, cur, pos_h, pos_w, hparams.rope_theta, true);
         };
 
         ggml_tensor * inp = build_inp();
@@ -936,6 +939,101 @@ struct clip_graph {
         return gf;
     }
 
+    ggml_cgraph * build_llama4() {
+        GGML_ASSERT(model.class_embedding != nullptr);
+        GGML_ASSERT(model.position_embeddings != nullptr);
+
+        const int n_pos = n_patches + 1; // +1 for [CLS]
+
+        // 2D input positions
+        ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos);
+        ggml_set_name(pos_h, "pos_h");
+        ggml_set_input(pos_h);
+
+        ggml_tensor * pos_w = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos);
+        ggml_set_name(pos_w, "pos_w");
+        ggml_set_input(pos_w);
+
+        ggml_tensor * inp = build_inp_raw();
+
+        // Llama4UnfoldConvolution
+        {
+            ggml_tensor * kernel = ggml_reshape_4d(ctx0, model.patch_embeddings_0,
+                                                    patch_size, patch_size, 3, n_embd);
+            inp = ggml_im2col(ctx0, kernel, inp, patch_size, patch_size, 0, 0, 1, 1, true, inp->type);
+            inp = ggml_mul_mat(ctx0, model.patch_embeddings_0, inp);
+            inp = ggml_reshape_2d(ctx0, inp, n_embd, n_patches);
+            cb(inp, "patch_conv", -1);
+        }
+
+        // add CLS token
+        inp = ggml_concat(ctx0, inp, model.class_embedding, 1);
+
+        // build ViT with 2D position embeddings
+        auto add_pos = [&](ggml_tensor * cur, const clip_layer &) {
+            // first half is X axis and second half is Y axis
+            // ref: https://github.com/huggingface/transformers/blob/40a493c7ed4f19f08eadb0639cf26d49bfa5e180/src/transformers/models/llama4/modeling_llama4.py#L1312
+            // ref: https://github.com/Blaizzy/mlx-vlm/blob/a57156aa87b33cca6e5ee6cfc14dd4ef8f611be6/mlx_vlm/models/llama4/vision.py#L441
+            return build_rope_2d(ctx0, cur, pos_w, pos_h, hparams.rope_theta, false);
+        };
+        ggml_tensor * cur = build_vit(
+                                inp, n_pos,
+                                NORM_TYPE_NORMAL,
+                                hparams.ffn_op,
+                                model.position_embeddings,
+                                add_pos);
+
+        // remove CLS token
+        cur = ggml_view_2d(ctx0, cur,
+            n_embd, n_patches,
+            ggml_row_size(cur->type, n_embd), 0);
+
+        // pixel shuffle
+        // based on Llama4VisionPixelShuffleMLP
+        // https://github.com/huggingface/transformers/blob/2932f318a20d9e54cc7aea052e040164d85de7d6/src/transformers/models/llama4/modeling_llama4.py#L1151
+        {
+            const int scale_factor = model.hparams.proj_scale_factor;
+            const int bsz = 1; // batch size, always 1 for now since we don't support batching
+            GGML_ASSERT(scale_factor > 0);
+            GGML_ASSERT(n_patches_x == n_patches_y); // llama4 only supports square images
+            cur = ggml_reshape_4d(ctx0, cur,
+                n_embd * scale_factor,
+                n_patches_x / scale_factor,
+                n_patches_y,
+                bsz);
+            cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+            cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur),
+                n_embd * scale_factor * scale_factor,
+                n_patches_x / scale_factor,
+                n_patches_y / scale_factor,
+                bsz);
+            cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+            // flatten to 2D
+            cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, cur),
+                n_embd * scale_factor * scale_factor,
+                n_patches / scale_factor / scale_factor);
+            cb(cur, "pixel_shuffle", -1);
+        }
+
+        // based on Llama4VisionMLP2 (always uses GELU activation, no bias)
+        {
+            cur = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w, cur);
+            cur = ggml_gelu(ctx0, cur);
+            cur = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, cur);
+            cur = ggml_gelu(ctx0, cur);
+            cb(cur, "adapter_mlp", -1);
+        }
+
+        // Llama4MultiModalProjector
+        cur = ggml_mul_mat(ctx0, model.mm_model_proj, cur);
+        cb(cur, "projected", -1);
+
+        // build the graph
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
+
     // this graph is used by llava, granite and glm
     // due to having embedding_stack (used by granite), we cannot reuse build_vit
     ggml_cgraph * build_llava() {
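
To make the shape bookkeeping of the pixel-shuffle block in build_llama4() concrete, here is a shape-only sketch. The 336px / 14px / 1408-dim numbers are assumptions for illustration, not values read from the model.

    n_patches_x = n_patches_y = 336 // 14       # assumed image_size / patch_size = 24
    n_embd       = 1408                         # assumed ViT embedding size
    scale_factor = 2                            # proj_scale_factor = 1 / pixel_shuffle_ratio

    n_patches  = n_patches_x * n_patches_y                   # 576 tokens out of the ViT
    out_tokens = n_patches // (scale_factor * scale_factor)  # 144 tokens after the shuffle
    out_embd   = n_embd * scale_factor * scale_factor        # 5632 channels per token
    assert out_tokens * out_embd == n_patches * n_embd       # pure rearrangement, no data loss
    print(out_tokens, out_embd)                              # 144 5632
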
@@ -1315,11 +1413,15 @@ private:
     // utility functions
     //
 
-    void cb(ggml_tensor * cur, const char * name, int il) const {
-        // TODO: implement this
-        GGML_UNUSED(cur);
-        GGML_UNUSED(name);
-        GGML_UNUSED(il);
+    void cb(ggml_tensor * cur0, const char * name, int il) const {
+        if (ctx->debug_graph) {
+            ggml_tensor * cur = ggml_cpy(ctx0, cur0, ggml_dup_tensor(ctx0, cur0));
+            std::string cur_name = il >= 0 ? std::string(name) + "_" + std::to_string(il) : name;
+            ggml_set_name(cur, cur_name.c_str());
+            ggml_set_output(cur);
+            ggml_build_forward_expand(gf, cur);
+            ctx->debug_print_tensors.push_back(cur);
+        }
     }
 
     // build vision transformer (ViT) cgraph
@@ -1630,9 +1732,10 @@ private:
     static ggml_tensor * build_rope_2d(
         ggml_context * ctx0,
         ggml_tensor * cur,
-        ggml_tensor * pos_h,
-        ggml_tensor * pos_w,
-        const float freq_base
+        ggml_tensor * pos_a, // first half
+        ggml_tensor * pos_b, // second half
+        const float freq_base,
+        const bool interleave_freq
     ) {
         const int64_t n_dim  = cur->ne[0];
         const int64_t n_head = cur->ne[1];
@@ -1646,7 +1749,9 @@ private:
         //  ^ don't ask me why, it's math! -2(2i) / n_dim == -2i / (n_dim/2)
         // then for the second half, we use freq_scale to shift the inv_freq
         //  ^ why? replace (2i) with (2i+1) in the above equation
-        const float freq_scale_odd = std::pow(freq_base, (float)-2/n_dim);
+        const float freq_scale_odd = interleave_freq
+                                    ? std::pow(freq_base, (float)-2/n_dim)
+                                    : 1.0;
 
         // first half
         ggml_tensor * first;
@@ -1659,7 +1764,7 @@ private:
             first = ggml_rope_ext(
                 ctx0,
                 first,
-                pos_h,      // positions
+                pos_a,      // positions
                 nullptr,    // freq factors
                 n_dim/2,    // n_dims
                 0, 0, freq_base,
@@ -1679,7 +1784,7 @@ private:
             second = ggml_rope_ext(
                 ctx0,
                 second,
-                pos_w,      // positions
+                pos_b,      // positions
                 nullptr,    // freq factors
                 n_dim/2,    // n_dims
                 0, 0, freq_base,
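
A small numeric sketch of what the new interleave_freq flag toggles, under the assumption (as in the existing build_rope_2d helper) that freq_scale_odd ends up scaling the rotation of the second half of dimensions: with interleaving (pixtral) the second half gets the odd-indexed frequencies, without it (llama 4) both halves share the same schedule and only the driving coordinate (pos_w vs pos_h) differs.

    def half_freqs(n_dim: int, freq_base: float, interleave_freq: bool):
        # RoPE over n_dim/2 dims -> n_dim/4 rotary frequencies per half
        base = [freq_base ** (-2.0 * j / (n_dim // 2)) for j in range(n_dim // 4)]
        scale_odd = freq_base ** (-2.0 / n_dim) if interleave_freq else 1.0
        return base, [f * scale_odd for f in base]   # (first half, second half)

    print(half_freqs(8, 10000.0, True))    # pixtral-style: second half shifted
    print(half_freqs(8, 10000.0, False))   # llama 4 style: both halves identical
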
@@ -1723,6 +1828,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
             {
                 res = graph.build_internvl();
             } break;
+        case PROJECTOR_TYPE_LLAMA4:
+            {
+                res = graph.build_llama4();
+            } break;
         default:
             {
                 res = graph.build_llava();
@@ -1926,6 +2035,21 @@ struct clip_model_loader {
                         hparams.warmup_image_size = hparams.patch_size * 8;
                         get_u32(KEY_WIN_ATTN_PATTERN, hparams.n_wa_pattern);
                     } break;
+                case PROJECTOR_TYPE_LLAMA4:
+                    {
+                        hparams.rope_theta = 10000.0f;
+                        get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor);
+
+                        // borrowed from llava-1.6
+                        const int isize = hparams.image_size;
+                        hparams.image_grid_pinpoints = {
+                            isize,   isize*2, // 336, 672
+                            isize*2, isize,   // 672, 336
+                            isize*2, isize*2, // 672, 672
+                            isize*3, isize,   // 1008, 336
+                            isize,   isize*3, // 336, 1008
+                        };
+                    } break;
                 default:
                     break;
             }
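
Expanded for the assumed image_size of 336 and treating each pair as width x height (as clip.cpp's select_best_resolution does), the hard-coded pinpoints above describe these candidate tile layouts; each slicing produces grid_x * grid_y tiles plus one overview image, per the note in clip-impl.h.

    isize = 336   # assumed hparams.image_size
    pinpoints = [(isize, isize*2), (isize*2, isize), (isize*2, isize*2),
                 (isize*3, isize), (isize, isize*3)]
    for w, h in pinpoints:
        grid_x, grid_y = w // isize, h // isize
        print(f"{w}x{h}: {grid_x}x{grid_y} grid -> {grid_x*grid_y} tiles + 1 overview")
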
@@ -1946,6 +2070,10 @@ struct clip_model_loader {
             LOG_INF("%s: ffn_op:             %s\n", __func__, log_ffn_op.c_str());
             LOG_INF("%s: model size:         %.2f MiB\n", __func__, model_size / 1024.0 / 1024.0);
             LOG_INF("%s: metadata size:      %.2f MiB\n", __func__, ggml_get_mem_size(ctx_meta.get()) / 1024.0 / 1024.0);
+
+            if (ctx_clip.proj_type == PROJECTOR_TYPE_LLAMA4) {
+                LOG_WRN("%s: llama 4 vision is known to have degraded quality: https://github.com/ggml-org/llama.cpp/pull/13282\n", __func__);
+            }
         }
     }
 
@@ -2001,7 +2129,7 @@ struct clip_model_loader {
         vision_model.patch_embeddings_0 = get_tensor(TN_PATCH_EMBD,   false);
         vision_model.patch_embeddings_1 = get_tensor(TN_PATCH_EMBD_1, false);
 
-        vision_model.position_embeddings = get_tensor(string_format(TN_POS_EMBD, "v"), false);
+        vision_model.position_embeddings = get_tensor(TN_POS_EMBD, false);
 
         // layers
         vision_model.layers.resize(hparams.n_layer);
@@ -2182,6 +2310,12 @@ struct clip_model_loader {
                     vision_model.mm_3_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "weight"));
                     vision_model.mm_3_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "bias"));
                 } break;
+            case PROJECTOR_TYPE_LLAMA4:
+                {
+                    vision_model.mm_model_proj    = get_tensor(TN_MM_PROJECTOR);
+                    vision_model.mm_model_mlp_1_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "weight"));
+                    vision_model.mm_model_mlp_2_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 2, "weight"));
+                } break;
             default:
                 GGML_ASSERT(false && "unknown projector type");
         }
@@ -2328,14 +2462,6 @@ struct clip_ctx * clip_init(const char * fname, struct clip_context_params ctx_p
     return ctx_clip;
 }
 
-void clip_add_load_image_size(struct clip_ctx * ctx_clip, struct clip_image_size * load_image_size) {
-    ctx_clip->load_image_size = *load_image_size; // copy
-}
-
-struct clip_image_size * clip_get_load_image_size(struct clip_ctx * ctx_clip) {
-    return &ctx_clip->load_image_size;
-}
-
 struct clip_image_size * clip_image_size_init() {
     struct clip_image_size * load_image_size = new struct clip_image_size();
     load_image_size->width = 448;
@@ -2849,7 +2975,7 @@ private:
 
     // used by llava 1.6 with custom list of pinpoints
     static clip_image_size select_best_resolution(const std::vector<int32_t> & pinpoints, const clip_image_size & original_size) {
-        std::vector<clip_image_size> possible_resolutions;
+        std::vector<clip_image_size> possible_resolutions; // TODO @ngxson : construct this inside hparams, not here
         for (size_t i = 0; i < pinpoints.size(); i += 2) {
             possible_resolutions.push_back(clip_image_size{pinpoints[i], pinpoints[i+1]});
         }
@@ -2916,12 +3042,6 @@ private:
     }
 };
 
-// TODO @ngxson : decprecate the load_image_size singleton pattern
-int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip) {
-    const auto inst = llava_uhd::get_slice_instructions(ctx_clip, ctx_clip->load_image_size);
-    return inst.grid_size.width;
-}
-
 // returns the normalized float tensor for llava-1.5, for spatial_unpad with anyres processing for llava-1.6 it returns the normalized image patch tensors as a vector
 // res_imgs memory is being allocated here, previous allocations will be freed if found
 bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, struct clip_image_f32_batch * res_imgs) {
@@ -2943,9 +3063,12 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
             normalize_image_u8_to_f32(*imgs[i], *res, ctx->image_mean, ctx->image_std);
             res_imgs->entries.push_back(std::move(res));
         }
+
+        res_imgs->grid_x = inst.grid_size.width;
+        res_imgs->grid_y = inst.grid_size.height;
         return true;
-    }
-    else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL) {
+
+    } else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL) {
         clip_image_u8 resized;
         auto patch_size = params.patch_size * 2;
         auto new_size = image_manipulation::calc_size_preserved_ratio(original_size, patch_size, params.image_size);
@@ -2971,8 +3094,8 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
         normalize_image_u8_to_f32(resized_image, *img_f32, ctx->image_mean, ctx->image_std);
         res_imgs->entries.push_back(std::move(img_f32));
         return true;
-    }
-    else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) {
+
+    } else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) {
         clip_image_u8 resized_image;
         auto new_size = image_manipulation::calc_size_preserved_ratio(original_size, params.patch_size, params.image_size);
         image_manipulation::bilinear_resize(*img, resized_image, new_size.width, new_size.height);
@@ -2980,6 +3103,22 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str
         normalize_image_u8_to_f32(resized_image, *img_f32, ctx->image_mean, ctx->image_std);
         res_imgs->entries.push_back(std::move(img_f32));
         return true;
+
+    } else if (ctx->proj_type == PROJECTOR_TYPE_LLAMA4) {
+        GGML_ASSERT(!params.image_grid_pinpoints.empty());
+        auto const inst = llava_uhd::get_slice_instructions(ctx, original_size);
+        std::vector<clip_image_u8_ptr> imgs = llava_uhd::slice_image(img, inst);
+
+        for (size_t i = 0; i < imgs.size(); ++i) {
+            clip_image_f32_ptr res(clip_image_f32_init());
+            normalize_image_u8_to_f32(*imgs[i], *res, ctx->image_mean, ctx->image_std);
+            res_imgs->entries.push_back(std::move(res));
+        }
+
+        res_imgs->grid_x = inst.grid_size.width;
+        res_imgs->grid_y = inst.grid_size.height;
+        return true;
+
     }
 
     // the logic below is to pad the shorter side to the longer side with a background color: rgb(122, 116, 104)
@@ -3098,6 +3237,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
     const auto & params = ctx->vision_model.hparams;
 
     int n_patches = (params.image_size / params.patch_size) * (params.image_size / params.patch_size);
+    int scale_factor = ctx->vision_model.hparams.proj_scale_factor;
 
     if (ctx->proj_type == PROJECTOR_TYPE_LDP
             || ctx->proj_type == PROJECTOR_TYPE_LDPV2
@@ -3136,6 +3276,8 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
         int n_patches_x = img->nx / params.patch_size / (n_merge > 0 ? n_merge : 1);
         int n_patches_y = img->ny / params.patch_size / (n_merge > 0 ? n_merge : 1);
         n_patches = n_patches_y*n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row
+    } else if (ctx->proj_type == PROJECTOR_TYPE_LLAMA4) {
+        n_patches /= (scale_factor * scale_factor);
     }
 
     return n_patches;
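
A worked example of the llama 4 branch above, with assumed hparams (336px tiles, 14px patches, scale factor 2):

    image_size, patch_size, scale_factor = 336, 14, 2     # assumed values for illustration
    n_patches = (image_size // patch_size) ** 2           # 24 * 24 = 576
    n_patches //= scale_factor * scale_factor             # 576 / 4  = 144 tokens per tile
    print(n_patches)                                      # a 2x2 grid + overview would add 5 * 144 tokens
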
@@ -3247,6 +3389,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
     }
 
     // build the inference graph
+    ctx->debug_print_tensors.clear();
     ggml_backend_sched_reset(ctx->sched.get());
     ggml_cgraph * gf = clip_image_build_graph(ctx, imgs);
     ggml_backend_sched_alloc_graph(ctx->sched.get(), gf);
@@ -3261,8 +3404,8 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
     const int patch_size    = hparams.patch_size;
     const int num_patches   = ((image_size_width / patch_size) * (image_size_height / patch_size));
     const int n_pos = num_patches + (model.class_embedding ? 1 : 0);
-    const int pos_w = ctx->load_image_size.width  / patch_size;
-    const int pos_h = ctx->load_image_size.height / patch_size;
+    const int pos_w = image_size_width  / patch_size;
+    const int pos_h = image_size_height / patch_size;
 
     const bool use_window_attn = hparams.n_wa_pattern > 0; // for qwen2.5vl
 
@@ -3528,6 +3671,23 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
             {
                 // do nothing
             } break;
+        case PROJECTOR_TYPE_LLAMA4:
+            {
+                // set the 2D positions
+                int n_patches_per_col = image_size_width / patch_size;
+                std::vector<int> pos_data(num_patches + 1, 0); // +1 for the [CLS] token
+                // last pos is always kept 0, it's for CLS
+                // dimension H
+                for (int i = 0; i < num_patches; i++) {
+                    pos_data[i] = (i / n_patches_per_col) + 1;
+                }
+                set_input_i32("pos_h", pos_data);
+                // dimension W
+                for (int i = 0; i < num_patches; i++) {
+                    pos_data[i] = (i % n_patches_per_col) + 1;
+                }
+                set_input_i32("pos_w", pos_data);
+            } break;
         default:
             GGML_ABORT("Unknown projector type");
     }
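
The 2D position inputs set above, replayed for a tiny 3x3 patch grid; n_patches_per_col mirrors the variable name in the code and the trailing 0 is the [CLS] slot, which the loops leave untouched.

    n_patches_per_col = 3
    num_patches = n_patches_per_col * n_patches_per_col
    pos_h = [i // n_patches_per_col + 1 for i in range(num_patches)] + [0]
    pos_w = [i %  n_patches_per_col + 1 for i in range(num_patches)] + [0]
    print(pos_h)   # [1, 1, 1, 2, 2, 2, 3, 3, 3, 0]
    print(pos_w)   # [1, 2, 3, 1, 2, 3, 1, 2, 3, 0]
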
@@ -3548,6 +3708,18 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
         return false;
     }
 
+    // print debug nodes
+    if (ctx->debug_graph) {
+        LOG_INF("\n\n---\n\n");
+        LOG_INF("\n\nDebug graph:\n\n");
+        for (ggml_tensor * t : ctx->debug_print_tensors) {
+            std::vector<uint8_t> data(ggml_nbytes(t));
+            ggml_backend_tensor_get(t, data.data(), 0, ggml_nbytes(t));
+            print_tensor_shape(t);
+            print_tensor_data(t, data.data(), 3);
+        }
+    }
+
     // the last node is the embedding tensor
     ggml_tensor * embeddings = ggml_graph_node(gf, -1);
 
@@ -3596,6 +3768,8 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
             return ctx->vision_model.projection->ne[1];
         case PROJECTOR_TYPE_INTERNVL:
             return ctx->vision_model.mm_3_w->ne[1];
+        case PROJECTOR_TYPE_LLAMA4:
+            return ctx->vision_model.mm_model_proj->ne[1];
         default:
             GGML_ABORT("Unknown projector type");
     }
index 2d70eec94736f7c263d7bf273f0ee50d0bb9f61f..e7a1c0782dd6ae92e6a81bf155e5ab33cd5463ce 100644 (file)
@@ -47,10 +47,6 @@ int clip_n_output_tokens_y(const struct clip_ctx * ctx, struct clip_image_f32 *
 // this should be equal to the embedding dimension of the text model
 int clip_n_mmproj_embd(const struct clip_ctx * ctx);
 
-int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip);
-void clip_add_load_image_size(struct clip_ctx * ctx_clip, struct clip_image_size * load_image_size);
-struct clip_image_size * clip_get_load_image_size(struct clip_ctx * ctx_clip);
-
 struct clip_image_size      * clip_image_size_init(void);
 struct clip_image_u8        * clip_image_u8_init (void);
 struct clip_image_f32       * clip_image_f32_init(void);
index 2a852d9c19bd29fb469db74e73527d40a9c93dad..1234dbb468767fd2bc43161fadd265f3aacdcf32 100644 (file)
@@ -42,6 +42,7 @@ enum mtmd_slice_tmpl {
     MTMD_SLICE_TMPL_NONE,
     MTMD_SLICE_TMPL_MINICPMV_2_5,
     MTMD_SLICE_TMPL_MINICPMV_2_6,
+    MTMD_SLICE_TMPL_LLAMA4,
     // TODO @ngxson : add support for idefics (SmolVLM)
 };
 
@@ -64,15 +65,19 @@ struct mtmd_context {
     int n_threads;
     std::string image_marker;
 
-    // for minicpmv, we need special tokens in-between slices
+    // for llava-uhd style models, we need special tokens in-between slices
+    // minicpmv calls them "slices", llama 4 calls them "tiles"
     mtmd_slice_tmpl slice_tmpl    = MTMD_SLICE_TMPL_NONE;
     llama_token tok_ov_img_start  = LLAMA_TOKEN_NULL; // overview image
     llama_token tok_ov_img_end    = LLAMA_TOKEN_NULL; // overview image
     llama_token tok_slices_start  = LLAMA_TOKEN_NULL; // start of all slices
     llama_token tok_slices_end    = LLAMA_TOKEN_NULL; // end of all slices
-    llama_token tok_sli_img_start = LLAMA_TOKEN_NULL; // single slice
-    llama_token tok_sli_img_end   = LLAMA_TOKEN_NULL; // single slice
+    llama_token tok_sli_img_start = LLAMA_TOKEN_NULL; // single slice start
+    llama_token tok_sli_img_end   = LLAMA_TOKEN_NULL; // single slice end
+    llama_token tok_sli_img_mid   = LLAMA_TOKEN_NULL; // between 2 slices
     llama_token tok_row_end       = LLAMA_TOKEN_NULL; // end of row
+    bool        tok_row_end_trail = false;
+    bool        ov_img_first      = false;
 
     bool use_mrope = false; // for Qwen2VL, we need to use M-RoPE
 
@@ -96,6 +101,7 @@ struct mtmd_context {
 
         use_mrope = clip_is_qwen2vl(ctx_clip);
 
+        projector_type proj = clip_get_projector_type(ctx_clip);
         int minicpmv_version = clip_is_minicpmv(ctx_clip);
         if (minicpmv_version == 2) {
             // minicpmv 2.5 format:
@@ -108,6 +114,8 @@ struct mtmd_context {
             tok_sli_img_start = tok_ov_img_start;
             tok_sli_img_end   = tok_ov_img_end;
             tok_row_end       = lookup_token("\n");
+            tok_row_end_trail = false; // no trailing end-of-row token
+            ov_img_first      = true;
 
         } else if (minicpmv_version == 3 || minicpmv_version == 4) {
             // minicpmv 2.6 format:
@@ -118,9 +126,25 @@ struct mtmd_context {
             tok_sli_img_start = lookup_token("<slice>");
             tok_sli_img_end   = lookup_token("</slice>");
             tok_row_end       = lookup_token("\n");
+            tok_row_end_trail = false; // no trailing end-of-row token
+            ov_img_first      = true;
 
         } else if (minicpmv_version != 0) {
             GGML_ASSERT(false && "unsupported minicpmv version");
+        } else if (proj == PROJECTOR_TYPE_LLAMA4) {
+            // llama 4 format:
+            // <|image_start|>
+            //     (slice) <|tile_x_separator|> (slice) <|tile_x_separator|> ... <|tile_y_separator|>
+            //     (slice) <|tile_x_separator|> (slice) <|tile_x_separator|> ... <|tile_y_separator|>
+            //     ... <|tile_y_separator|>   <-- trailing end-of-row token
+            // <|image|> (overview)           <-- overview image is last
+            // <|image_end|>
+            slice_tmpl        = MTMD_SLICE_TMPL_LLAMA4;
+            tok_ov_img_start  = lookup_token("<|image|>");
+            tok_sli_img_mid   = lookup_token("<|tile_x_separator|>");
+            tok_row_end       = lookup_token("<|tile_y_separator|>");
+            tok_row_end_trail = true; // add trailing end-of-row token
+            ov_img_first      = false; // overview image is last
         }
     }
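
A sketch of the marker sequence the llama 4 template above produces for a sliced image; the surrounding <|image_start|>/<|image_end|> come from the marker rewrite in mtmd_tokenize further below, and "(tile ...)" / "(overview)" stand in for the image-embedding chunks.

    def llama4_markers(grid_x: int, grid_y: int) -> list[str]:
        seq = ["<|image_start|>"]
        for y in range(grid_y):
            for x in range(grid_x):
                seq.append(f"(tile {y},{x})")
                if x != grid_x - 1:
                    seq.append("<|tile_x_separator|>")
            seq.append("<|tile_y_separator|>")   # trailing end-of-row token (tok_row_end_trail)
        seq += ["<|image|>", "(overview)", "<|image_end|>"]   # overview image comes last
        return seq

    print(" ".join(llama4_markers(2, 2)))
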
 
@@ -243,16 +267,18 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
         // https://github.com/huggingface/transformers/blob/1cd110c6cb6a6237614130c470e9a902dbc1a4bd/docs/source/en/model_doc/pixtral.md
         marker_modified = ctx->image_marker + "[IMG_END]";
         string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
-    }
 
-    else if (proj_type == PROJECTOR_TYPE_QWEN2VL || proj_type == PROJECTOR_TYPE_QWEN25VL) {
+    } else if (proj_type == PROJECTOR_TYPE_QWEN2VL || proj_type == PROJECTOR_TYPE_QWEN25VL) {
         // <|vision_start|> ... (image embeddings) ... <|vision_end|>
         marker_modified = "<|vision_start|>" + ctx->image_marker + "<|vision_end|>";
         string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
 
-    }
+    } else if (proj_type == PROJECTOR_TYPE_LLAMA4) {
+        // (more details in mtmd_context constructor)
+        marker_modified = "<|image_start|>" + ctx->image_marker + "<|image_end|>";
+        string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
 
-    else if (proj_type == PROJECTOR_TYPE_INTERNVL) {
+    } else if (proj_type == PROJECTOR_TYPE_INTERNVL) {
         // <img> ... (image embeddings) ... </img>
         marker_modified = "<img>" + ctx->image_marker + "</img>";
         string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
@@ -328,7 +354,6 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
             img_u8->ny = bitmaps[i_img]->ny;
             img_u8->buf.resize(bitmaps[i_img]->data.size());
             std::memcpy(img_u8->buf.data(), bitmaps[i_img]->data.data(), img_u8->nx * img_u8->ny * 3);
-            clip_image_size img_u8_size{img_u8->nx, img_u8->ny};
 
             // preprocess image
             clip_image_f32_batch batch_f32;
@@ -338,28 +363,40 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
                 return 2;
             }
 
-            if (ctx->slice_tmpl == MTMD_SLICE_TMPL_MINICPMV_2_5 || ctx->slice_tmpl == MTMD_SLICE_TMPL_MINICPMV_2_6) {
+            // handle llava-uhd style preprocessing
+            if (
+                ctx->slice_tmpl == MTMD_SLICE_TMPL_MINICPMV_2_5
+                || ctx->slice_tmpl == MTMD_SLICE_TMPL_MINICPMV_2_6
+                || ctx->slice_tmpl == MTMD_SLICE_TMPL_LLAMA4
+            ) {
                 // split batch into chunks of single images
                 auto chunks = split_batch_to_chunk(std::move(batch_f32), bitmaps[i_img]->id);
                 GGML_ASSERT(chunks.size() > 0);
 
-                // add overview image
-                add_text_chunk({ctx->tok_ov_img_start});
-                output->entries.emplace_back(std::move(chunks.front()));
+                auto ov_chunk = std::move(chunks.front());
                 chunks.erase(chunks.begin());
-                add_text_chunk({ctx->tok_ov_img_end});
 
-                // add slices
+                // add overview image (first)
+                if (ctx->ov_img_first) {
+                    if (ctx->tok_ov_img_start != LLAMA_TOKEN_NULL) {
+                        add_text_chunk({ctx->tok_ov_img_start});
+                    }
+                    output->entries.emplace_back(std::move(ov_chunk));
+                    if (ctx->tok_ov_img_end != LLAMA_TOKEN_NULL) {
+                        add_text_chunk({ctx->tok_ov_img_end});
+                    }
+                }
+
+                // add slices (or tiles)
                 if (!chunks.empty()) {
-                    clip_add_load_image_size(ctx->ctx_clip, &img_u8_size);
-                    int n_col = clip_uhd_num_image_embeds_col(ctx->ctx_clip);
-                    int n_row = (int)chunks.size() / n_col;
-                    GGML_ASSERT(n_row * n_col == (int)chunks.size());
+                    const int n_col = batch_f32.grid_x;
+                    const int n_row = batch_f32.grid_y;
                     if (ctx->tok_slices_start != LLAMA_TOKEN_NULL) {
                         add_text_chunk({ctx->tok_slices_start});
                     }
                     for (int y = 0; y < n_row; y++) {
                         for (int x = 0; x < n_col; x++) {
+                            const bool is_last_in_row = (x == n_col - 1);
                             if (ctx->tok_sli_img_start != LLAMA_TOKEN_NULL) {
                                 add_text_chunk({ctx->tok_sli_img_start});
                             }
@@ -367,8 +404,11 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
                             if (ctx->tok_sli_img_end != LLAMA_TOKEN_NULL) {
                                 add_text_chunk({ctx->tok_sli_img_end});
                             }
+                            if (!is_last_in_row && ctx->tok_sli_img_mid != LLAMA_TOKEN_NULL) {
+                                add_text_chunk({ctx->tok_sli_img_mid});
+                            }
                         }
-                        if (ctx->tok_row_end != LLAMA_TOKEN_NULL && y != n_row - 1) {
+                        if ((y != n_row - 1 || ctx->tok_row_end_trail) && ctx->tok_row_end != LLAMA_TOKEN_NULL) {
                             add_text_chunk({ctx->tok_row_end});
                         }
                     }
@@ -377,6 +417,17 @@ int32_t mtmd_tokenize(mtmd_context * ctx,
                     }
                 }
 
+                // add overview image (last)
+                if (!ctx->ov_img_first) {
+                    if (ctx->tok_ov_img_start != LLAMA_TOKEN_NULL) {
+                        add_text_chunk({ctx->tok_ov_img_start});
+                    }
+                    output->entries.emplace_back(std::move(ov_chunk));
+                    if (ctx->tok_ov_img_end != LLAMA_TOKEN_NULL) {
+                        add_text_chunk({ctx->tok_ov_img_end});
+                    }
+                }
+
             } else {
                 size_t n_tokens = 0;
                 for (const auto & entry : batch_f32.entries) {
@@ -427,14 +478,6 @@ int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens)
     ctx->image_embd_v.resize(image_tokens->n_tokens() * n_mmproj_embd);
     bool ok = false;
 
-    // only effective for minicpmv and qwen2vl, other models will ignore load_image_size
-    {
-        clip_image_size slice_size{
-            image_tokens->batch_f32.entries[0]->nx,
-            image_tokens->batch_f32.entries[0]->ny};
-        clip_add_load_image_size(ctx->ctx_clip, &slice_size);
-    }
-
     if (clip_is_llava(ctx->ctx_clip) || clip_is_minicpmv(ctx->ctx_clip) || clip_is_glm(ctx->ctx_clip)) {
         // TODO @ngxson : llava does not support batched encoding ; this should be fixed inside clip_image_batch_encode()
         const auto & entries = image_tokens->batch_f32.entries;
index 05ac7a04d8fce36780747baa2e1bb7d3e8475b17..15a37b0d22bb46ae0bed8a06630e1b4e8df87d30 100755 (executable)
@@ -21,6 +21,13 @@ if [ "${1:-}" = "big" ]; then
     echo "Include BIG models..."
 fi
 
+RUN_HUGE_TESTS=false
+if [ "${1:-}" = "huge" ]; then
+    RUN_HUGE_TESTS=true
+    RUN_BIG_TESTS=true
+    echo "Include BIG models..."
+fi
+
 ###############
 
 arr_bin=()
@@ -42,7 +49,7 @@ add_test "llama-mtmd-cli"  "ggml-org/SmolVLM2-500M-Video-Instruct-GGUF:Q8_0"
 add_test "llama-mtmd-cli"  "ggml-org/gemma-3-4b-it-GGUF:Q4_K_M"
 add_test "llama-mtmd-cli"  "THUDM/glm-edge-v-5b-gguf:Q4_K_M"
 add_test "llama-mtmd-cli"  "second-state/Llava-v1.5-7B-GGUF:Q2_K"            "vicuna"
-add_test "llama-mtmd-cli"  "cjpais/llava-1.6-mistral-7b-gguf:Q3_K"           "vicuna"
+add_test "llama-mtmd-cli"  "cjpais/llava-1.6-mistral-7b-gguf:Q3_K_M"         "vicuna"
 add_test "llama-mtmd-cli"  "ibm-research/granite-vision-3.2-2b-GGUF:Q4_K_M"
 add_test "llama-mtmd-cli"  "second-state/MiniCPM-Llama3-V-2_5-GGUF:Q2_K"  # model from openbmb is corrupted
 add_test "llama-mtmd-cli"  "openbmb/MiniCPM-V-2_6-gguf:Q2_K"
@@ -60,10 +67,17 @@ if [ "$RUN_BIG_TESTS" = true ]; then
     add_test "llama-mtmd-cli" "ggml-org/Qwen2-VL-7B-Instruct-GGUF:Q4_K_M"
     add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-3B-Instruct-GGUF:Q4_K_M"
     add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-7B-Instruct-GGUF:Q4_K_M"
-    add_test "llama-mtmd-cli"  "ggml-org/InternVL3-8B-Instruct-GGUF:Q4_K_M"
-    add_test "llama-mtmd-cli"  "ggml-org/InternVL3-14B-Instruct-GGUF:Q4_K_M"
+    add_test "llama-mtmd-cli" "ggml-org/InternVL3-8B-Instruct-GGUF:Q4_K_M"
+    add_test "llama-mtmd-cli" "ggml-org/InternVL3-14B-Instruct-GGUF:Q4_K_M"
     # add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-32B-Instruct-GGUF:Q4_K_M" # does not work on my mac M3 Ultra
-    # add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-72B-Instruct-GGUF:Q4_K_M" # too big
+fi
+
+# to test the huge models, run: ./tests.sh huge
+# this will run both the big and huge models
+# huge models are > 32B parameters
+if [ "$RUN_HUGE_TESTS" = true ]; then
+    add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-72B-Instruct-GGUF:Q4_K_M"
+    add_test "llama-mtmd-cli" "ggml-org/Llama-4-Scout-17B-16E-Instruct-GGUF:IQ1_S"
 fi
 
 # these models always give the wrong answer, not sure why