gguf.MODEL_TENSOR.TIME_MIX_LERP_FUSED,
gguf.MODEL_TENSOR.POSNET_NORM1,
gguf.MODEL_TENSOR.POSNET_NORM2,
+ gguf.MODEL_TENSOR.V_ENC_EMBD_POS,
)
)
or not new_name.endswith(".weight")
return super().modify_tensors(data_torch, name, bid)
+@ModelBase.register("Llama4ForConditionalGeneration")
+class Llama4VisionModel(VisionModel):
+ def set_gguf_parameters(self):
+ super().set_gguf_parameters()
+ self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.LLAMA4)
+ self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams["norm_eps"])
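+ # the GGUF projector scale factor is stored as the integer inverse of the HF pixel_shuffle_ratio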
+ self.gguf_writer.add_vision_projector_scale_factor(int(1.0 / self.hparams["pixel_shuffle_ratio"]))
+ assert self.hparams["hidden_act"] == "gelu"
+ self.gguf_writer.add_vision_use_gelu(True)
+
+ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+ del bid # unused
+ if "multi_modal_projector" in name or "vision_model" in name:
+ # process vision tensors
+ if "positional_embedding_vlm" in name and ".weight" not in name:
+ name += ".weight"
+ return [(self.map_tensor_name(name), data_torch)]
+ return []
+
+
@ModelBase.register("Mistral3ForConditionalGeneration")
class Mistral3Model(LlamaModel):
model_arch = gguf.MODEL_ARCH.LLAMA
(tool_name) -hf ggml-org/InternVL3-2B-Instruct-GGUF
(tool_name) -hf ggml-org/InternVL3-8B-Instruct-GGUF
(tool_name) -hf ggml-org/InternVL3-14B-Instruct-GGUF
+
+# Llama 4 Scout
+(tool_name) -hf ggml-org/Llama-4-Scout-17B-16E-Instruct-GGUF
```
V_ENC_EMBD_CLS = auto()
V_ENC_EMBD_PATCH = auto()
V_ENC_EMBD_POS = auto()
+ V_ENC_INPUT_NORM = auto()
V_ENC_ATTN_Q = auto()
V_ENC_ATTN_Q_NORM = auto()
V_ENC_ATTN_K = auto()
V_ENC_ATTN_K_NORM = auto()
V_ENC_ATTN_V = auto()
- V_ENC_INPUT_NORM = auto()
- V_ENC_OUTPUT = auto()
- V_ENC_OUTPUT_NORM = auto()
+ V_ENC_ATTN_O = auto()
+ V_ENC_ATTN_O_NORM = auto()
+ V_ENC_POST_ATTN_NORM = auto()
V_ENC_FFN_UP = auto()
V_ENC_FFN_GATE = auto()
V_ENC_FFN_DOWN = auto()
MODEL_TENSOR.V_ENC_ATTN_K_NORM: "v.blk.{bid}.attn_k_norm",
MODEL_TENSOR.V_ENC_ATTN_V: "v.blk.{bid}.attn_v",
MODEL_TENSOR.V_ENC_INPUT_NORM: "v.blk.{bid}.ln1",
- MODEL_TENSOR.V_ENC_OUTPUT: "v.blk.{bid}.attn_out",
- MODEL_TENSOR.V_ENC_OUTPUT_NORM: "v.blk.{bid}.ln2",
+ MODEL_TENSOR.V_ENC_ATTN_O: "v.blk.{bid}.attn_out",
+ MODEL_TENSOR.V_ENC_ATTN_O_NORM: "v.blk.{bid}.attn_out_norm",
+ MODEL_TENSOR.V_ENC_POST_ATTN_NORM: "v.blk.{bid}.ln2",
MODEL_TENSOR.V_ENC_FFN_UP: "v.blk.{bid}.ffn_up",
MODEL_TENSOR.V_ENC_FFN_GATE: "v.blk.{bid}.ffn_gate",
MODEL_TENSOR.V_ENC_FFN_DOWN: "v.blk.{bid}.ffn_down",
MODEL_TENSOR.V_ENC_EMBD_CLS,
MODEL_TENSOR.V_ENC_EMBD_PATCH,
MODEL_TENSOR.V_ENC_EMBD_POS,
+ MODEL_TENSOR.V_ENC_INPUT_NORM,
MODEL_TENSOR.V_ENC_ATTN_Q,
MODEL_TENSOR.V_ENC_ATTN_Q_NORM,
MODEL_TENSOR.V_ENC_ATTN_K,
MODEL_TENSOR.V_ENC_ATTN_K_NORM,
MODEL_TENSOR.V_ENC_ATTN_V,
- MODEL_TENSOR.V_ENC_INPUT_NORM,
- MODEL_TENSOR.V_ENC_OUTPUT,
- MODEL_TENSOR.V_ENC_OUTPUT_NORM,
+ MODEL_TENSOR.V_ENC_ATTN_O,
+ MODEL_TENSOR.V_ENC_ATTN_O_NORM,
+ MODEL_TENSOR.V_ENC_POST_ATTN_NORM,
MODEL_TENSOR.V_ENC_FFN_UP,
MODEL_TENSOR.V_ENC_FFN_GATE,
MODEL_TENSOR.V_ENC_FFN_DOWN,
GEMMA3 = "gemma3"
IDEFICS3 = "idefics3"
PIXTRAL = "pixtral"
+ LLAMA4 = "llama4"
QWEN2VL = "qwen2vl_merger"
QWEN25VL = "qwen2.5vl_merger"
INTERNVL = "internvl"
MODEL_TENSOR.V_MMPROJ_FC: (
"model.connector.modality_projection.proj", # SmolVLM
+ "multi_modal_projector.linear_1", # llama 4
),
MODEL_TENSOR.V_MMPROJ_MLP: (
"model.mm_projector.mlp.mlp.{bid}",
+ "vision_model.vision_adapter.mlp.fc{bid}", # llama 4
"mlp1.{bid}", # InternVL
),
MODEL_TENSOR.V_ENC_EMBD_CLS: (
"vision_tower.vision_model.embeddings.class_embedding",
+ "vision_model.class_embedding", # llama 4
),
MODEL_TENSOR.V_ENC_EMBD_PATCH: (
"vpm.embeddings.patch_embedding",
"model.vision_model.embeddings.patch_embedding", # SmolVLM
"vision_tower.patch_conv", # pixtral
+ "vision_model.patch_embedding.linear", # llama 4
"visual.patch_embed.proj", # qwen2vl
),
"vision_tower.vision_model.embeddings.position_embedding",
"vpm.embeddings.position_embedding",
"model.vision_model.embeddings.position_embedding", # SmolVLM
+ "vision_model.positional_embedding_vlm", # llama 4
),
MODEL_TENSOR.V_ENC_ATTN_Q: (
"vision_tower.vision_model.encoder.layers.{bid}.self_attn.q_proj",
"vpm.encoder.layers.{bid}.self_attn.q_proj",
"model.vision_model.encoder.layers.{bid}.self_attn.q_proj", # SmolVLM
+ "vision_model.model.layers.{bid}.self_attn.q_proj", # llama4
"vision_tower.transformer.layers.{bid}.attention.q_proj", # pixtral
"visual.blocks.{bid}.attn.q", # qwen2vl, generated
),
"vision_tower.vision_model.encoder.layers.{bid}.self_attn.k_proj",
"vpm.encoder.layers.{bid}.self_attn.k_proj",
"model.vision_model.encoder.layers.{bid}.self_attn.k_proj", # SmolVLM
+ "vision_model.model.layers.{bid}.self_attn.k_proj", # llama4
"vision_tower.transformer.layers.{bid}.attention.k_proj", # pixtral
"visual.blocks.{bid}.attn.k", # qwen2vl, generated
),
"vision_tower.vision_model.encoder.layers.{bid}.self_attn.v_proj",
"vpm.encoder.layers.{bid}.self_attn.v_proj",
"model.vision_model.encoder.layers.{bid}.self_attn.v_proj", # SmolVLM
+ "vision_model.model.layers.{bid}.self_attn.v_proj", # llama4
"vision_tower.transformer.layers.{bid}.attention.v_proj", # pixtral
"visual.blocks.{bid}.attn.v", # qwen2vl, generated
),
"vpm.encoder.layers.{bid}.layer_norm1",
"model.vision_model.encoder.layers.{bid}.layer_norm1", # SmolVLM
"vision_tower.transformer.layers.{bid}.attention_norm", # pixtral
+ "vision_model.model.layers.{bid}.input_layernorm", # llama4
"visual.blocks.{bid}.norm1", # qwen2vl
),
- MODEL_TENSOR.V_ENC_OUTPUT: (
+ MODEL_TENSOR.V_ENC_ATTN_O: (
"vision_tower.vision_model.encoder.layers.{bid}.self_attn.out_proj",
"vision_tower.vision_model.encoder.layers.{bid}.attn.proj", # InternVL
"vpm.encoder.layers.{bid}.self_attn.out_proj",
"model.vision_model.encoder.layers.{bid}.self_attn.out_proj", # SmolVLM
+ "vision_model.model.layers.{bid}.self_attn.o_proj", # llama4
"vision_tower.transformer.layers.{bid}.attention.o_proj", # pixtral
"visual.blocks.{bid}.attn.proj", # qwen2vl
),
- MODEL_TENSOR.V_ENC_OUTPUT_NORM: (
+ MODEL_TENSOR.V_ENC_POST_ATTN_NORM: (
"vision_tower.vision_model.encoder.layers.{bid}.layer_norm2",
"vision_tower.vision_model.encoder.layers.{bid}.norm2", # InternVL
"vpm.encoder.layers.{bid}.layer_norm2",
"model.vision_model.encoder.layers.{bid}.layer_norm2", # SmolVLM
+ "vision_model.model.layers.{bid}.post_attention_layernorm", # llama4
"vision_tower.transformer.layers.{bid}.ffn_norm", # pixtral
"visual.blocks.{bid}.norm2", # qwen2vl
),
"vpm.encoder.layers.{bid}.mlp.fc1",
"model.vision_model.encoder.layers.{bid}.mlp.fc1", # SmolVLM, gemma3
"vision_tower.transformer.layers.{bid}.feed_forward.up_proj", # pixtral
+ "vision_model.model.layers.{bid}.mlp.fc1", # llama4
"visual.blocks.{bid}.mlp.fc1", # qwen2vl
"visual.blocks.{bid}.mlp.up_proj", # qwen2.5vl
),
"vpm.encoder.layers.{bid}.mlp.fc2",
"model.vision_model.encoder.layers.{bid}.mlp.fc2", # SmolVLM, gemma3
"vision_tower.transformer.layers.{bid}.feed_forward.down_proj", # pixtral
+ "vision_model.model.layers.{bid}.mlp.fc2", # llama4
"visual.blocks.{bid}.mlp.fc2", # qwen2vl
"visual.blocks.{bid}.mlp.down_proj", # qwen2.5vl
),
MODEL_TENSOR.V_PRE_NORM: (
"vision_tower.vision_model.pre_layrnorm",
"vision_tower.ln_pre", # pixtral
+ "vision_model.layernorm_pre", # llama4
),
MODEL_TENSOR.V_POST_NORM: (
"vision_tower.vision_model.post_layernorm",
"model.vision_model.post_layernorm", # SmolVLM
+ "vision_model.layernorm_post", # llama4
"visual.merger.ln_q", # qwen2vl
),
#include <climits>
#include <cstdarg>
+#include <cinttypes>
#include <string>
#include <map>
#include <sstream>
// tensor name constants
//
-#define TN_POS_EMBD "%s.position_embd.weight"
+#define TN_POS_EMBD "v.position_embd.weight"
#define TN_CLASS_EMBD "v.class_embd"
#define TN_PATCH_EMBD "v.patch_embd.weight" // do not rename the tensor with a ".0" postfix, for backward compat
#define TN_PATCH_EMBD_1 "v.patch_embd.weight.1"
PROJECTOR_TYPE_PIXTRAL,
PROJECTOR_TYPE_QWEN25VL,
PROJECTOR_TYPE_INTERNVL,
+ PROJECTOR_TYPE_LLAMA4,
PROJECTOR_TYPE_UNKNOWN,
};
{ PROJECTOR_TYPE_IDEFICS3, "idefics3"},
{ PROJECTOR_TYPE_PIXTRAL, "pixtral"},
{ PROJECTOR_TYPE_INTERNVL, "internvl"},
+ { PROJECTOR_TYPE_LLAMA4, "llama4"},
};
static projector_type clip_projector_type_from_string(const std::string & str) {
struct clip_image_f32_batch {
std::vector<clip_image_f32_ptr> entries;
+ // for llava-uhd style models, we need to know the grid size
+ // note: entries.size() == grid_x * grid_y + 1 (one overview image)
+ int grid_x = 0;
+ int grid_y = 0;
+
clip_image_f32_batch clone() const {
clip_image_f32_batch new_batch;
new_batch.entries.reserve(entries.size());
}
}
+//
+// debugging
+//
+
+static void print_tensor_shape(ggml_tensor * t) {
+ printf("%s.shape = [", t->name);
+ for (int i = 0; i < ggml_n_dims(t); ++i) {
+ printf("%" PRId64, t->ne[i]);
+ if (i < ggml_n_dims(t) - 1) {
+ printf(", ");
+ }
+ }
+ printf("]\n");
+}
+
+static void print_tensor_data(ggml_tensor * t, uint8_t * data, int64_t n) {
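+ // n limits how many elements are printed at the start and end of each dimension; the middle is elided with "..."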
+ ggml_type type = t->type;
+ int64_t * ne = t->ne;
+ size_t * nb = t->nb;
+ for (int64_t i3 = 0; i3 < ne[3]; i3++) {
+ printf("%s.data: [\n", t->name);
+ for (int64_t i2 = 0; i2 < ne[2]; i2++) {
+ if (i2 == n && ne[2] > 2*n) {
+ printf(" ..., \n");
+ i2 = ne[2] - n;
+ }
+ printf(" [\n");
+ for (int64_t i1 = 0; i1 < ne[1]; i1++) {
+ if (i1 == n && ne[1] > 2*n) {
+ printf(" ..., \n");
+ i1 = ne[1] - n;
+ }
+ printf(" [");
+ for (int64_t i0 = 0; i0 < ne[0]; i0++) {
+ if (i0 == n && ne[0] > 2*n) {
+ printf("..., ");
+ i0 = ne[0] - n;
+ }
+ size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0];
+ float v;
+ if (type == GGML_TYPE_F16) {
+ v = ggml_fp16_to_fp32(*(ggml_fp16_t *) &data[i]);
+ } else if (type == GGML_TYPE_F32) {
+ v = *(float *) &data[i];
+ } else if (type == GGML_TYPE_I32) {
+ v = (float) *(int32_t *) &data[i];
+ } else if (type == GGML_TYPE_I16) {
+ v = (float) *(int16_t *) &data[i];
+ } else if (type == GGML_TYPE_I8) {
+ v = (float) *(int8_t *) &data[i];
+ } else {
+ GGML_ABORT("fatal error");
+ }
+ printf("%8.4f", v);
+ if (i0 < ne[0] - 1) printf(", ");
+ }
+ printf("],\n");
+ }
+ printf(" ],\n");
+ }
+ printf(" ]\n");
+ }
+}
+
//
// API used internally with mtmd
//
int max_nodes = 8192;
ggml_backend_sched_ptr sched;
- clip_image_size load_image_size;
+ // for debugging
+ bool debug_graph = false;
+ std::vector<ggml_tensor *> debug_print_tensors;
clip_ctx(clip_context_params & ctx_params) {
+ debug_graph = std::getenv("MTMD_DEBUG_GRAPH") != nullptr;
backend_cpu = ggml_backend_init_by_type(GGML_BACKEND_DEVICE_TYPE_CPU, nullptr);
if (!backend_cpu) {
throw std::runtime_error("failed to initialize CPU backend");
};
ctx0_ptr.reset(ggml_init(params));
ctx0 = ctx0_ptr.get();
- gf = ggml_new_graph(ctx0);
+ gf = ggml_new_graph_custom(ctx0, ctx->max_nodes, false);
}
ggml_cgraph * build_siglip() {
ggml_set_input(pos_w);
auto add_pos = [&](ggml_tensor * cur, const clip_layer &) {
- return build_rope_2d(ctx0, cur, pos_h, pos_w, hparams.rope_theta);
+ return build_rope_2d(ctx0, cur, pos_h, pos_w, hparams.rope_theta, true);
};
ggml_tensor * inp = build_inp();
return gf;
}
+ ggml_cgraph * build_llama4() {
+ GGML_ASSERT(model.class_embedding != nullptr);
+ GGML_ASSERT(model.position_embeddings != nullptr);
+
+ const int n_pos = n_patches + 1; // +1 for [CLS]
+
+ // 2D input positions
+ ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos);
+ ggml_set_name(pos_h, "pos_h");
+ ggml_set_input(pos_h);
+
+ ggml_tensor * pos_w = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos);
+ ggml_set_name(pos_w, "pos_w");
+ ggml_set_input(pos_w);
+
+ ggml_tensor * inp = build_inp_raw();
+
+ // Llama4UnfoldConvolution
+ {
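+ // unfold the image into non-overlapping patch_size x patch_size columns (im2col), then project each
+ // column with the patch embedding weight; this matches the HF unfold + linear in Llama4UnfoldConvolution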
+ ggml_tensor * kernel = ggml_reshape_4d(ctx0, model.patch_embeddings_0,
+ patch_size, patch_size, 3, n_embd);
+ inp = ggml_im2col(ctx0, kernel, inp, patch_size, patch_size, 0, 0, 1, 1, true, inp->type);
+ inp = ggml_mul_mat(ctx0, model.patch_embeddings_0, inp);
+ inp = ggml_reshape_2d(ctx0, inp, n_embd, n_patches);
+ cb(inp, "patch_conv", -1);
+ }
+
+ // add CLS token
+ inp = ggml_concat(ctx0, inp, model.class_embedding, 1);
+
+ // build ViT with 2D position embeddings
+ auto add_pos = [&](ggml_tensor * cur, const clip_layer &) {
+ // first half is X axis and second half is Y axis
+ // ref: https://github.com/huggingface/transformers/blob/40a493c7ed4f19f08eadb0639cf26d49bfa5e180/src/transformers/models/llama4/modeling_llama4.py#L1312
+ // ref: https://github.com/Blaizzy/mlx-vlm/blob/a57156aa87b33cca6e5ee6cfc14dd4ef8f611be6/mlx_vlm/models/llama4/vision.py#L441
+ return build_rope_2d(ctx0, cur, pos_w, pos_h, hparams.rope_theta, false);
+ };
+ ggml_tensor * cur = build_vit(
+ inp, n_pos,
+ NORM_TYPE_NORMAL,
+ hparams.ffn_op,
+ model.position_embeddings,
+ add_pos);
+
+ // remove CLS token
+ cur = ggml_view_2d(ctx0, cur,
+ n_embd, n_patches,
+ ggml_row_size(cur->type, n_embd), 0);
+
+ // pixel shuffle
+ // based on Llama4VisionPixelShuffleMLP
+ // https://github.com/huggingface/transformers/blob/2932f318a20d9e54cc7aea052e040164d85de7d6/src/transformers/models/llama4/modeling_llama4.py#L1151
+ {
+ const int scale_factor = model.hparams.proj_scale_factor;
+ const int bsz = 1; // batch size, always 1 for now since we don't support batching
+ GGML_ASSERT(scale_factor > 0);
+ GGML_ASSERT(n_patches_x == n_patches_y); // llama4 only supports square images
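+ // fold each scale_factor x scale_factor block of patches into the embedding dimension:
+ // [n_embd, n_patches] -> [n_embd * scale_factor^2, n_patches / scale_factor^2]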
+ cur = ggml_reshape_4d(ctx0, cur,
+ n_embd * scale_factor,
+ n_patches_x / scale_factor,
+ n_patches_y,
+ bsz);
+ cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+ cur = ggml_reshape_4d(ctx0, ggml_cont(ctx0, cur),
+ n_embd * scale_factor * scale_factor,
+ n_patches_x / scale_factor,
+ n_patches_y / scale_factor,
+ bsz);
+ cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+ // flatten to 2D
+ cur = ggml_reshape_2d(ctx0, ggml_cont(ctx0, cur),
+ n_embd * scale_factor * scale_factor,
+ n_patches / scale_factor / scale_factor);
+ cb(cur, "pixel_shuffle", -1);
+ }
+
+ // based on Llama4VisionMLP2 (always uses GELU activation, no bias)
+ {
+ cur = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w, cur);
+ cur = ggml_gelu(ctx0, cur);
+ cur = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, cur);
+ cur = ggml_gelu(ctx0, cur);
+ cb(cur, "adapter_mlp", -1);
+ }
+
+ // Llama4MultiModalProjector
+ cur = ggml_mul_mat(ctx0, model.mm_model_proj, cur);
+ cb(cur, "projected", -1);
+
+ // build the graph
+ ggml_build_forward_expand(gf, cur);
+
+ return gf;
+ }
+
// this graph is used by llava, granite and glm
// due to having embedding_stack (used by granite), we cannot reuse build_vit
ggml_cgraph * build_llava() {
// utility functions
//
- void cb(ggml_tensor * cur, const char * name, int il) const {
- // TODO: implement this
- GGML_UNUSED(cur);
- GGML_UNUSED(name);
- GGML_UNUSED(il);
+ void cb(ggml_tensor * cur0, const char * name, int il) const {
+ if (ctx->debug_graph) {
+ ggml_tensor * cur = ggml_cpy(ctx0, cur0, ggml_dup_tensor(ctx0, cur0));
+ std::string cur_name = il >= 0 ? std::string(name) + "_" + std::to_string(il) : name;
+ ggml_set_name(cur, cur_name.c_str());
+ ggml_set_output(cur);
+ ggml_build_forward_expand(gf, cur);
+ ctx->debug_print_tensors.push_back(cur);
+ }
}
// build vision transformer (ViT) cgraph
static ggml_tensor * build_rope_2d(
ggml_context * ctx0,
ggml_tensor * cur,
- ggml_tensor * pos_h,
- ggml_tensor * pos_w,
- const float freq_base
+ ggml_tensor * pos_a, // first half
+ ggml_tensor * pos_b, // second half
+ const float freq_base,
+ const bool interleave_freq
) {
const int64_t n_dim = cur->ne[0];
const int64_t n_head = cur->ne[1];
// ^ don't ask me why, it's math! -2(2i) / n_dim == -2i / (n_dim/2)
// then for the second half, we use freq_scale to shift the inv_freq
// ^ why? replace (2i) with (2i+1) in the above equation
- const float freq_scale_odd = std::pow(freq_base, (float)-2/n_dim);
+ const float freq_scale_odd = interleave_freq
+ ? std::pow(freq_base, (float)-2/n_dim)
+ : 1.0;
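+ // with interleave_freq == false (llama 4), the second half reuses the same frequencies as the first half (no (2i+1) shift)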
// first half
ggml_tensor * first;
first = ggml_rope_ext(
ctx0,
first,
- pos_h, // positions
+ pos_a, // positions
nullptr, // freq factors
n_dim/2, // n_dims
0, 0, freq_base,
second = ggml_rope_ext(
ctx0,
second,
- pos_w, // positions
+ pos_b, // positions
nullptr, // freq factors
n_dim/2, // n_dims
0, 0, freq_base,
{
res = graph.build_internvl();
} break;
+ case PROJECTOR_TYPE_LLAMA4:
+ {
+ res = graph.build_llama4();
+ } break;
default:
{
res = graph.build_llava();
hparams.warmup_image_size = hparams.patch_size * 8;
get_u32(KEY_WIN_ATTN_PATTERN, hparams.n_wa_pattern);
} break;
+ case PROJECTOR_TYPE_LLAMA4:
+ {
+ hparams.rope_theta = 10000.0f;
+ get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor);
+
+ // borrowed from llava-1.6
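+ // candidate (width, height) resolutions used by the llava-uhd style slicing below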
+ const int isize = hparams.image_size;
+ hparams.image_grid_pinpoints = {
+ isize, isize*2, // 336, 672
+ isize*2, isize, // 672, 336
+ isize*2, isize*2, // 672, 672
+ isize*3, isize, // 1008, 336
+ isize, isize*3, // 336, 1008
+ };
+ } break;
default:
break;
}
LOG_INF("%s: ffn_op: %s\n", __func__, log_ffn_op.c_str());
LOG_INF("%s: model size: %.2f MiB\n", __func__, model_size / 1024.0 / 1024.0);
LOG_INF("%s: metadata size: %.2f MiB\n", __func__, ggml_get_mem_size(ctx_meta.get()) / 1024.0 / 1024.0);
+
+ if (ctx_clip.proj_type == PROJECTOR_TYPE_LLAMA4) {
+ LOG_WRN("%s: llama 4 vision is known to have degraded quality: https://github.com/ggml-org/llama.cpp/pull/13282\n", __func__);
+ }
}
}
vision_model.patch_embeddings_0 = get_tensor(TN_PATCH_EMBD, false);
vision_model.patch_embeddings_1 = get_tensor(TN_PATCH_EMBD_1, false);
- vision_model.position_embeddings = get_tensor(string_format(TN_POS_EMBD, "v"), false);
+ vision_model.position_embeddings = get_tensor(TN_POS_EMBD, false);
// layers
vision_model.layers.resize(hparams.n_layer);
vision_model.mm_3_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "weight"));
vision_model.mm_3_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "bias"));
} break;
+ case PROJECTOR_TYPE_LLAMA4:
+ {
+ vision_model.mm_model_proj = get_tensor(TN_MM_PROJECTOR);
+ vision_model.mm_model_mlp_1_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "weight"));
+ vision_model.mm_model_mlp_2_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 2, "weight"));
+ } break;
default:
GGML_ASSERT(false && "unknown projector type");
}
return ctx_clip;
}
-void clip_add_load_image_size(struct clip_ctx * ctx_clip, struct clip_image_size * load_image_size) {
- ctx_clip->load_image_size = *load_image_size; // copy
-}
-
-struct clip_image_size * clip_get_load_image_size(struct clip_ctx * ctx_clip) {
- return &ctx_clip->load_image_size;
-}
-
struct clip_image_size * clip_image_size_init() {
struct clip_image_size * load_image_size = new struct clip_image_size();
load_image_size->width = 448;
// used by llava 1.6 with custom list of pinpoints
static clip_image_size select_best_resolution(const std::vector<int32_t> & pinpoints, const clip_image_size & original_size) {
- std::vector<clip_image_size> possible_resolutions;
+ std::vector<clip_image_size> possible_resolutions; // TODO @ngxson : construct this inside hparams, not here
for (size_t i = 0; i < pinpoints.size(); i += 2) {
possible_resolutions.push_back(clip_image_size{pinpoints[i], pinpoints[i+1]});
}
}
};
-// TODO @ngxson : decprecate the load_image_size singleton pattern
-int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip) {
- const auto inst = llava_uhd::get_slice_instructions(ctx_clip, ctx_clip->load_image_size);
- return inst.grid_size.width;
-}
-
// returns the normalized float tensor for llava-1.5, for spatial_unpad with anyres processing for llava-1.6 it returns the normalized image patch tensors as a vector
// res_imgs memory is being allocated here, previous allocations will be freed if found
bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, struct clip_image_f32_batch * res_imgs) {
normalize_image_u8_to_f32(*imgs[i], *res, ctx->image_mean, ctx->image_std);
res_imgs->entries.push_back(std::move(res));
}
+
+ res_imgs->grid_x = inst.grid_size.width;
+ res_imgs->grid_y = inst.grid_size.height;
return true;
- }
- else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL) {
+
+ } else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL) {
clip_image_u8 resized;
auto patch_size = params.patch_size * 2;
auto new_size = image_manipulation::calc_size_preserved_ratio(original_size, patch_size, params.image_size);
normalize_image_u8_to_f32(resized_image, *img_f32, ctx->image_mean, ctx->image_std);
res_imgs->entries.push_back(std::move(img_f32));
return true;
- }
- else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) {
+
+ } else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) {
clip_image_u8 resized_image;
auto new_size = image_manipulation::calc_size_preserved_ratio(original_size, params.patch_size, params.image_size);
image_manipulation::bilinear_resize(*img, resized_image, new_size.width, new_size.height);
normalize_image_u8_to_f32(resized_image, *img_f32, ctx->image_mean, ctx->image_std);
res_imgs->entries.push_back(std::move(img_f32));
return true;
+
+ } else if (ctx->proj_type == PROJECTOR_TYPE_LLAMA4) {
+ GGML_ASSERT(!params.image_grid_pinpoints.empty());
+ auto const inst = llava_uhd::get_slice_instructions(ctx, original_size);
+ std::vector<clip_image_u8_ptr> imgs = llava_uhd::slice_image(img, inst);
+
+ for (size_t i = 0; i < imgs.size(); ++i) {
+ clip_image_f32_ptr res(clip_image_f32_init());
+ normalize_image_u8_to_f32(*imgs[i], *res, ctx->image_mean, ctx->image_std);
+ res_imgs->entries.push_back(std::move(res));
+ }
+
+ res_imgs->grid_x = inst.grid_size.width;
+ res_imgs->grid_y = inst.grid_size.height;
+ return true;
+
}
// the logic below is to pad the shorter side to the longer side with a background color: rgb(122, 116, 104)
const auto & params = ctx->vision_model.hparams;
int n_patches = (params.image_size / params.patch_size) * (params.image_size / params.patch_size);
+ int scale_factor = ctx->vision_model.hparams.proj_scale_factor;
if (ctx->proj_type == PROJECTOR_TYPE_LDP
|| ctx->proj_type == PROJECTOR_TYPE_LDPV2
int n_patches_x = img->nx / params.patch_size / (n_merge > 0 ? n_merge : 1);
int n_patches_y = img->ny / params.patch_size / (n_merge > 0 ? n_merge : 1);
n_patches = n_patches_y*n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row
+ } else if (ctx->proj_type == PROJECTOR_TYPE_LLAMA4) {
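+ // pixel shuffle reduces the number of image tokens by scale_factor^2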
+ n_patches /= (scale_factor * scale_factor);
}
return n_patches;
}
// build the inference graph
+ ctx->debug_print_tensors.clear();
ggml_backend_sched_reset(ctx->sched.get());
ggml_cgraph * gf = clip_image_build_graph(ctx, imgs);
ggml_backend_sched_alloc_graph(ctx->sched.get(), gf);
const int patch_size = hparams.patch_size;
const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size));
const int n_pos = num_patches + (model.class_embedding ? 1 : 0);
- const int pos_w = ctx->load_image_size.width / patch_size;
- const int pos_h = ctx->load_image_size.height / patch_size;
+ const int pos_w = image_size_width / patch_size;
+ const int pos_h = image_size_height / patch_size;
const bool use_window_attn = hparams.n_wa_pattern > 0; // for qwen2.5vl
{
// do nothing
} break;
+ case PROJECTOR_TYPE_LLAMA4:
+ {
+ // set the 2D positions
+ int n_patches_per_col = image_size_width / patch_size;
+ std::vector<int> pos_data(num_patches + 1, 0); // +1 for the [CLS] token
+ // last pos is always kept 0, it's for CLS
+ // dimension H
+ for (int i = 0; i < num_patches; i++) {
+ pos_data[i] = (i / n_patches_per_col) + 1;
+ }
+ set_input_i32("pos_h", pos_data);
+ // dimension W
+ for (int i = 0; i < num_patches; i++) {
+ pos_data[i] = (i % n_patches_per_col) + 1;
+ }
+ set_input_i32("pos_w", pos_data);
+ } break;
default:
GGML_ABORT("Unknown projector type");
}
return false;
}
+ // print debug nodes
+ if (ctx->debug_graph) {
+ LOG_INF("\n\n---\n\n");
+ LOG_INF("\n\nDebug graph:\n\n");
+ for (ggml_tensor * t : ctx->debug_print_tensors) {
+ std::vector<uint8_t> data(ggml_nbytes(t));
+ ggml_backend_tensor_get(t, data.data(), 0, ggml_nbytes(t));
+ print_tensor_shape(t);
+ print_tensor_data(t, data.data(), 3);
+ }
+ }
+
// the last node is the embedding tensor
ggml_tensor * embeddings = ggml_graph_node(gf, -1);
return ctx->vision_model.projection->ne[1];
case PROJECTOR_TYPE_INTERNVL:
return ctx->vision_model.mm_3_w->ne[1];
+ case PROJECTOR_TYPE_LLAMA4:
+ return ctx->vision_model.mm_model_proj->ne[1];
default:
GGML_ABORT("Unknown projector type");
}
// this should be equal to the embedding dimension of the text model
int clip_n_mmproj_embd(const struct clip_ctx * ctx);
-int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip);
-void clip_add_load_image_size(struct clip_ctx * ctx_clip, struct clip_image_size * load_image_size);
-struct clip_image_size * clip_get_load_image_size(struct clip_ctx * ctx_clip);
-
struct clip_image_size * clip_image_size_init(void);
struct clip_image_u8 * clip_image_u8_init (void);
struct clip_image_f32 * clip_image_f32_init(void);
MTMD_SLICE_TMPL_NONE,
MTMD_SLICE_TMPL_MINICPMV_2_5,
MTMD_SLICE_TMPL_MINICPMV_2_6,
+ MTMD_SLICE_TMPL_LLAMA4,
// TODO @ngxson : add support for idefics (SmolVLM)
};
int n_threads;
std::string image_marker;
- // for minicpmv, we need special tokens in-between slices
+ // for llava-uhd style models, we need special tokens in-between slices
+ // minicpmv calls them "slices", llama 4 calls them "tiles"
mtmd_slice_tmpl slice_tmpl = MTMD_SLICE_TMPL_NONE;
llama_token tok_ov_img_start = LLAMA_TOKEN_NULL; // overview image
llama_token tok_ov_img_end = LLAMA_TOKEN_NULL; // overview image
llama_token tok_slices_start = LLAMA_TOKEN_NULL; // start of all slices
llama_token tok_slices_end = LLAMA_TOKEN_NULL; // end of all slices
- llama_token tok_sli_img_start = LLAMA_TOKEN_NULL; // single slice
- llama_token tok_sli_img_end = LLAMA_TOKEN_NULL; // single slice
+ llama_token tok_sli_img_start = LLAMA_TOKEN_NULL; // single slice start
+ llama_token tok_sli_img_end = LLAMA_TOKEN_NULL; // single slice end
+ llama_token tok_sli_img_mid = LLAMA_TOKEN_NULL; // between 2 slices
llama_token tok_row_end = LLAMA_TOKEN_NULL; // end of row
+ bool tok_row_end_trail = false;
+ bool ov_img_first = false;
bool use_mrope = false; // for Qwen2VL, we need to use M-RoPE
use_mrope = clip_is_qwen2vl(ctx_clip);
+ projector_type proj = clip_get_projector_type(ctx_clip);
int minicpmv_version = clip_is_minicpmv(ctx_clip);
if (minicpmv_version == 2) {
// minicpmv 2.5 format:
tok_sli_img_start = tok_ov_img_start;
tok_sli_img_end = tok_ov_img_end;
tok_row_end = lookup_token("\n");
+ tok_row_end_trail = false; // no trailing end-of-row token
+ ov_img_first = true;
} else if (minicpmv_version == 3 || minicpmv_version == 4) {
// minicpmv 2.6 format:
tok_sli_img_start = lookup_token("<slice>");
tok_sli_img_end = lookup_token("</slice>");
tok_row_end = lookup_token("\n");
+ tok_row_end_trail = false; // no trailing end-of-row token
+ ov_img_first = true;
} else if (minicpmv_version != 0) {
GGML_ASSERT(false && "unsupported minicpmv version");
+ } else if (proj == PROJECTOR_TYPE_LLAMA4) {
+ // llama 4 format:
+ // <|image_start|>
+ // (slice) <|tile_x_separator|> (slice) <|tile_x_separator|> ... <|tile_y_separator|>
+ // (slice) <|tile_x_separator|> (slice) <|tile_x_separator|> ... <|tile_y_separator|>
+ // ... <|tile_y_separator|> <-- trailing end-of-row token
+ // <|image|> (overview) <-- overview image is last
+ // <|image_end|>
+ slice_tmpl = MTMD_SLICE_TMPL_LLAMA4;
+ tok_ov_img_start = lookup_token("<|image|>");
+ tok_sli_img_mid = lookup_token("<|tile_x_separator|>");
+ tok_row_end = lookup_token("<|tile_y_separator|>");
+ tok_row_end_trail = true; // add trailing end-of-row token
+ ov_img_first = false; // overview image is last
}
}
// https://github.com/huggingface/transformers/blob/1cd110c6cb6a6237614130c470e9a902dbc1a4bd/docs/source/en/model_doc/pixtral.md
marker_modified = ctx->image_marker + "[IMG_END]";
string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
- }
- else if (proj_type == PROJECTOR_TYPE_QWEN2VL || proj_type == PROJECTOR_TYPE_QWEN25VL) {
+ } else if (proj_type == PROJECTOR_TYPE_QWEN2VL || proj_type == PROJECTOR_TYPE_QWEN25VL) {
// <|vision_start|> ... (image embeddings) ... <|vision_end|>
marker_modified = "<|vision_start|>" + ctx->image_marker + "<|vision_end|>";
string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
- }
+ } else if (proj_type == PROJECTOR_TYPE_LLAMA4) {
+ // (more details in mtmd_context constructor)
+ marker_modified = "<|image_start|>" + ctx->image_marker + "<|image_end|>";
+ string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
- else if (proj_type == PROJECTOR_TYPE_INTERNVL) {
+ } else if (proj_type == PROJECTOR_TYPE_INTERNVL) {
// <img> ... (image embeddings) ... </img>
marker_modified = "<img>" + ctx->image_marker + "</img>";
string_replace_all(prompt_modified, ctx->image_marker, marker_modified);
img_u8->ny = bitmaps[i_img]->ny;
img_u8->buf.resize(bitmaps[i_img]->data.size());
std::memcpy(img_u8->buf.data(), bitmaps[i_img]->data.data(), img_u8->nx * img_u8->ny * 3);
- clip_image_size img_u8_size{img_u8->nx, img_u8->ny};
// preprocess image
clip_image_f32_batch batch_f32;
return 2;
}
- if (ctx->slice_tmpl == MTMD_SLICE_TMPL_MINICPMV_2_5 || ctx->slice_tmpl == MTMD_SLICE_TMPL_MINICPMV_2_6) {
+ // handle llava-uhd style preprocessing
+ if (
+ ctx->slice_tmpl == MTMD_SLICE_TMPL_MINICPMV_2_5
+ || ctx->slice_tmpl == MTMD_SLICE_TMPL_MINICPMV_2_6
+ || ctx->slice_tmpl == MTMD_SLICE_TMPL_LLAMA4
+ ) {
// split batch into chunks of single images
auto chunks = split_batch_to_chunk(std::move(batch_f32), bitmaps[i_img]->id);
GGML_ASSERT(chunks.size() > 0);
- // add overview image
- add_text_chunk({ctx->tok_ov_img_start});
- output->entries.emplace_back(std::move(chunks.front()));
+ auto ov_chunk = std::move(chunks.front());
chunks.erase(chunks.begin());
- add_text_chunk({ctx->tok_ov_img_end});
- // add slices
+ // add overview image (first)
+ if (ctx->ov_img_first) {
+ if (ctx->tok_ov_img_start != LLAMA_TOKEN_NULL) {
+ add_text_chunk({ctx->tok_ov_img_start});
+ }
+ output->entries.emplace_back(std::move(ov_chunk));
+ if (ctx->tok_ov_img_end != LLAMA_TOKEN_NULL) {
+ add_text_chunk({ctx->tok_ov_img_end});
+ }
+ }
+
+ // add slices (or tiles)
if (!chunks.empty()) {
- clip_add_load_image_size(ctx->ctx_clip, &img_u8_size);
- int n_col = clip_uhd_num_image_embeds_col(ctx->ctx_clip);
- int n_row = (int)chunks.size() / n_col;
- GGML_ASSERT(n_row * n_col == (int)chunks.size());
+ const int n_col = batch_f32.grid_x;
+ const int n_row = batch_f32.grid_y;
if (ctx->tok_slices_start != LLAMA_TOKEN_NULL) {
add_text_chunk({ctx->tok_slices_start});
}
for (int y = 0; y < n_row; y++) {
for (int x = 0; x < n_col; x++) {
+ const bool is_last_in_row = (x == n_col - 1);
if (ctx->tok_sli_img_start != LLAMA_TOKEN_NULL) {
add_text_chunk({ctx->tok_sli_img_start});
}
if (ctx->tok_sli_img_end != LLAMA_TOKEN_NULL) {
add_text_chunk({ctx->tok_sli_img_end});
}
+ if (!is_last_in_row && ctx->tok_sli_img_mid != LLAMA_TOKEN_NULL) {
+ add_text_chunk({ctx->tok_sli_img_mid});
+ }
}
- if (ctx->tok_row_end != LLAMA_TOKEN_NULL && y != n_row - 1) {
+ if ((y != n_row - 1 || ctx->tok_row_end_trail) && ctx->tok_row_end != LLAMA_TOKEN_NULL) {
add_text_chunk({ctx->tok_row_end});
}
}
}
}
+ // add overview image (last)
+ if (!ctx->ov_img_first) {
+ if (ctx->tok_ov_img_start != LLAMA_TOKEN_NULL) {
+ add_text_chunk({ctx->tok_ov_img_start});
+ }
+ output->entries.emplace_back(std::move(ov_chunk));
+ if (ctx->tok_ov_img_end != LLAMA_TOKEN_NULL) {
+ add_text_chunk({ctx->tok_ov_img_end});
+ }
+ }
+
} else {
size_t n_tokens = 0;
for (const auto & entry : batch_f32.entries) {
ctx->image_embd_v.resize(image_tokens->n_tokens() * n_mmproj_embd);
bool ok = false;
- // only effective for minicpmv and qwen2vl, other models will ignore load_image_size
- {
- clip_image_size slice_size{
- image_tokens->batch_f32.entries[0]->nx,
- image_tokens->batch_f32.entries[0]->ny};
- clip_add_load_image_size(ctx->ctx_clip, &slice_size);
- }
-
if (clip_is_llava(ctx->ctx_clip) || clip_is_minicpmv(ctx->ctx_clip) || clip_is_glm(ctx->ctx_clip)) {
// TODO @ngxson : llava does not support batched encoding ; this should be fixed inside clip_image_batch_encode()
const auto & entries = image_tokens->batch_f32.entries;
echo "Include BIG models..."
fi
+RUN_HUGE_TESTS=false
+if [ "${1:-}" = "huge" ]; then
+ RUN_HUGE_TESTS=true
+ RUN_BIG_TESTS=true
+ echo "Include BIG models..."
+fi
+
###############
arr_bin=()
add_test "llama-mtmd-cli" "ggml-org/gemma-3-4b-it-GGUF:Q4_K_M"
add_test "llama-mtmd-cli" "THUDM/glm-edge-v-5b-gguf:Q4_K_M"
add_test "llama-mtmd-cli" "second-state/Llava-v1.5-7B-GGUF:Q2_K" "vicuna"
-add_test "llama-mtmd-cli" "cjpais/llava-1.6-mistral-7b-gguf:Q3_K" "vicuna"
+add_test "llama-mtmd-cli" "cjpais/llava-1.6-mistral-7b-gguf:Q3_K_M" "vicuna"
add_test "llama-mtmd-cli" "ibm-research/granite-vision-3.2-2b-GGUF:Q4_K_M"
add_test "llama-mtmd-cli" "second-state/MiniCPM-Llama3-V-2_5-GGUF:Q2_K" # model from openbmb is corrupted
add_test "llama-mtmd-cli" "openbmb/MiniCPM-V-2_6-gguf:Q2_K"
add_test "llama-mtmd-cli" "ggml-org/Qwen2-VL-7B-Instruct-GGUF:Q4_K_M"
add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-3B-Instruct-GGUF:Q4_K_M"
add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-7B-Instruct-GGUF:Q4_K_M"
- add_test "llama-mtmd-cli" "ggml-org/InternVL3-8B-Instruct-GGUF:Q4_K_M"
- add_test "llama-mtmd-cli" "ggml-org/InternVL3-14B-Instruct-GGUF:Q4_K_M"
+ add_test "llama-mtmd-cli" "ggml-org/InternVL3-8B-Instruct-GGUF:Q4_K_M"
+ add_test "llama-mtmd-cli" "ggml-org/InternVL3-14B-Instruct-GGUF:Q4_K_M"
# add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-32B-Instruct-GGUF:Q4_K_M" # does not work on my mac M3 Ultra
- # add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-72B-Instruct-GGUF:Q4_K_M" # too big
+fi
+
+# to test the huge models, run: ./tests.sh huge
+# this will run both the big and huge models
+# huge models are > 32B parameters
+if [ "$RUN_HUGE_TESTS" = true ]; then
+ add_test "llama-mtmd-cli" "ggml-org/Qwen2.5-VL-72B-Instruct-GGUF:Q4_K_M"
+ add_test "llama-mtmd-cli" "ggml-org/Llama-4-Scout-17B-16E-Instruct-GGUF:IQ1_S"
fi
# these models always give the wrong answer, not sure why