if chkhsh == "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2":
# ref: https://huggingface.co/THUDM/glm-4-9b-hf
res = "glm4"
+ if chkhsh == "0e9433cbbb161f89e264eb32e8e64bfe69e834973ffca5d41d3948a604a3e2a3":
+ # ref: https://huggingface.co/mistral-community/pixtral-12b
+ res = "pixtral"
if res is None:
logger.warning("\n")
"MistralForCausalLM",
"MixtralForCausalLM",
"Idefics3ForConditionalGeneration",
- "SmolVLMForConditionalGeneration")
+ "SmolVLMForConditionalGeneration",
+ "LlavaForConditionalGeneration")
class LlamaModel(TextModel):
model_arch = gguf.MODEL_ARCH.LLAMA
undo_permute = True
# fix for SmolVLM2, missing `num_attention_heads` in config.json
if self.hparams["architectures"][0] == "SmolVLMForConditionalGeneration":
self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)
+ # fix for Pixtral, missing `num_attention_heads` in config.json
+ if self.hparams["architectures"][0] == "LlavaForConditionalGeneration" \
+ and self.hparams.get("model_type") == "mistral":
+ self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)
def set_vocab(self):
try:
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
n_head = self.hparams["num_attention_heads"]
n_kv_head = self.hparams.get("num_key_value_heads")
- is_vision_tensor = "vision_tower" in name or "vision_model" in name or "model.connector" in name
+ is_vision_tensor = "vision_tower" in name \
+ or "vision_model" in name \
+ or "model.connector" in name \
+ or "multi_modal_projector" in name
if is_vision_tensor:
return [] # skip vision tensors
elif name.startswith("model.text_model"):
name = name.replace("text_model.", "") # for SmolVLM
+ elif name.startswith("language_model."):
+        name = name.replace("language_model.", "") # for llava-style models (e.g. pixtral)
if self.undo_permute:
if name.endswith(("q_proj.weight", "q_proj.bias")):
raise ValueError(f"Unprocessed experts: {experts}")
+@ModelBase.register("LlavaForConditionalGeneration")
+class LlavaVisionModel(VisionModel):
+ img_break_tok_id = -1
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ if self.hparams["model_type"] == "pixtral":
+ # fix missing config.json values
+ self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 16)
+ self.hparams["num_hidden_layers"] = self.hparams.get("num_hidden_layers", 24)
+ self.hparams["intermediate_size"] = self.hparams.get("intermediate_size", 4096)
+ self.hparams["hidden_size"] = self.hparams.get("hidden_size", 1024)
+ self.hparams["layer_norm_eps"] = self.hparams.get("layer_norm_eps", 1e-5)
+ self.img_break_tok_id = 12 # see tokenizer_config.json
+ else:
+ raise ValueError(f"Unsupported model type: {self.hparams['model_type']}")
+
+ def set_gguf_parameters(self):
+ super().set_gguf_parameters()
+ hparams = self.hparams
+ if hparams["model_type"] == "pixtral":
+ self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.PIXTRAL)
+            # default values below are taken from HF transformers code
+ self.gguf_writer.add_vision_attention_layernorm_eps(hparams["layer_norm_eps"])
+ self.gguf_writer.add_vision_use_silu(True)
+
+ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+ del bid # unused
+ n_head = self.hparams["num_attention_heads"]
+ n_kv_head = n_head
+
+ if name.startswith("multi_modal_projector.") or name.startswith("vision_tower."):
+ # process vision tensors
+ if name.endswith(("q_proj.weight", "q_proj.bias")):
+ data_torch = LlamaModel.permute(data_torch, n_head, n_head)
+ if name.endswith(("k_proj.weight", "k_proj.bias")):
+ data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head)
+ return [(self.map_tensor_name(name), data_torch)]
+
+ if self.img_break_tok_id > 0 and "embed_tokens.weight" in name:
+ logger.info(f"Extracting [IMG_BREAK] token embedding from {name}")
+ # for pixtral model, we need to extract the [IMG_BREAK] token embedding
+ img_break_embd = data_torch[self.img_break_tok_id]
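+            # img_break_embd has shape [hidden_size]: a single row of the token embedding matrix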
+ name = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK]
+ return [(self.map_tensor_name(name), img_break_embd)]
+
+ return [] # skip other tensors
+
+
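+# note: LlamaModel.permute (used above for q/k) reorders the projection rows between
+# HF's and llama.cpp's RoPE conventions; a minimal sketch, assuming the usual
+# convert_hf_to_gguf.py implementation:
+#
+#     def permute(weights, n_head, n_head_kv):
+#         if n_head_kv is not None and n_head != n_head_kv:
+#             n_head = n_head_kv
+#         return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
+#                 .swapaxes(1, 2)
+#                 .reshape(weights.shape))
+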
@ModelBase.register("Idefics3ForConditionalGeneration", "SmolVLMForConditionalGeneration")
class SmolVLMModel(VisionModel):
def __init__(self, *args, **kwargs):
patch_merge_type mm_patch_merge_type = PATCH_MERGE_FLAT;
- float eps;
+ float eps = 1e-6;
+ float rope_theta = 0.0;
std::vector<int32_t> image_grid_pinpoints;
int32_t image_crop_resolution;
struct ggml_tensor * ln_1_b = nullptr;
// ff
- struct ggml_tensor * ff_i_w = nullptr;
- struct ggml_tensor * ff_i_b = nullptr;
-
- struct ggml_tensor * ff_o_w = nullptr;
- struct ggml_tensor * ff_o_b = nullptr;
+ struct ggml_tensor * ff_i_w = nullptr; // legacy naming
+ struct ggml_tensor * ff_i_b = nullptr; // legacy naming
+ struct ggml_tensor * ff_o_w = nullptr; // legacy naming
+ struct ggml_tensor * ff_o_b = nullptr; // legacy naming
+
+ struct ggml_tensor * ff_up_w = nullptr;
+ struct ggml_tensor * ff_up_b = nullptr;
+ struct ggml_tensor * ff_gate_w = nullptr;
+ struct ggml_tensor * ff_gate_b = nullptr;
+ struct ggml_tensor * ff_down_w = nullptr;
+ struct ggml_tensor * ff_down_b = nullptr;
// layernorm 2
struct ggml_tensor * ln_2_w = nullptr;
// gemma3
struct ggml_tensor * mm_input_proj_w = nullptr;
struct ggml_tensor * mm_soft_emb_norm_w = nullptr;
+
+ // pixtral
+ struct ggml_tensor * token_embd_img_break = nullptr;
};
struct clip_ctx {
ggml_backend_t backend_cpu;
ggml_backend_buffer_ptr buf;
+ int max_nodes = 8192;
ggml_backend_sched_ptr sched;
clip_image_size load_image_size;
return gf;
}
+// implementation of the 2D RoPE without adding a new op in ggml
+static ggml_tensor * build_rope_2d(
+ ggml_cgraph * gf,
+ ggml_context * ctx0,
+ ggml_tensor * cur,
+ ggml_tensor * pos_h,
+ ggml_tensor * pos_w,
+ const float freq_base
+) {
+ ggml_tensor * tmp;
+ const int64_t n_dim = cur->ne[0];
+ const int64_t n_head = cur->ne[1];
+ const int64_t n_pos = cur->ne[2];
+
+    // for example, if cur has shape (n_dim=8, n_head, n_pos) and freq_base=1e4,
+    // the full RoPE uses 4 inv_freq values: 1e-0, 1e-1, 1e-2, 1e-3
+    // the first half of cur uses the even ones: 1e-0, 1e-2
+    // the second half of cur uses the odd ones: 1e-1, 1e-3
+    //
+    // for the first half, the trick is to rotate only n_dim/2 dims, which selects
+    // exactly the even inv_freq, because base^(-2(2i)/n_dim) == base^(-2i/(n_dim/2))
+    // for the second half, we scale by freq_scale_odd = base^(-2/n_dim), which
+    // shifts each even inv_freq to the next odd one: replace (2i) with (2i+1)
+ const float freq_scale_odd = std::pow(freq_base, (float)-2/n_dim);
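+    // e.g. freq_base=1e4, n_dim=8: freq_scale_odd = 1e-1, mapping the even inv_freq {1e-0, 1e-2} onto the odd ones {1e-1, 1e-3}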
+
+ // first half
+ {
+ cur = ggml_rope_ext_inplace(
+ ctx0,
+ cur,
+ pos_h, // positions
+ nullptr, // freq factors
+ n_dim/2, // n_dims
+ 0, 0, freq_base,
+ 1.0f, 0.0f, 1.0f, 0.0f, 0.0f
+ );
+ }
+
+ // second half
+ {
+ tmp = ggml_view_3d(ctx0, cur,
+ n_dim/2, n_head, n_pos,
+ ggml_row_size(cur->type, n_dim),
+ ggml_row_size(cur->type, n_dim*n_head),
+ n_dim/2 * ggml_element_size(cur));
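+        // tmp views the second half of each row of cur (byte offset n_dim/2 * element_size)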
+ tmp = ggml_rope_ext_inplace(
+ ctx0,
+ tmp,
+ pos_w, // positions
+ nullptr, // freq factors
+ n_dim/2, // n_dims
+ 0, 0, freq_base,
+ freq_scale_odd,
+ 0.0f, 1.0f, 0.0f, 0.0f
+ );
+ // calculate inplace (modify cur directly)
+ ggml_build_forward_expand(gf, tmp);
+ }
+
+ return cur;
+}
+
+static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_image_f32_batch & imgs) {
+ const auto & model = ctx->vision_model;
+ const auto & hparams = model.hparams;
+
+ GGML_ASSERT(ctx->proj_type == PROJECTOR_TYPE_PIXTRAL);
+ GGML_ASSERT(imgs.entries.size() == 1); // batch_size == 1
+
+ int image_size_width = imgs.entries[0]->nx;
+ int image_size_height = imgs.entries[0]->ny;
+
+ const int patch_size = hparams.patch_size;
+ const int n_patches_x = image_size_width / patch_size;
+ const int n_patches_y = image_size_height / patch_size;
+ const int num_patches = n_patches_x * n_patches_y;
+ const int hidden_size = hparams.hidden_size;
+ const int n_head = hparams.n_head;
+ const int d_head = hidden_size / n_head;
+ const int n_layer = hparams.n_layer;
+ const float eps = hparams.eps;
+
+ struct ggml_init_params params = {
+ /*.mem_size =*/ ctx->buf_compute_meta.size(),
+ /*.mem_buffer =*/ ctx->buf_compute_meta.data(),
+ /*.no_alloc =*/ true,
+ };
+
+ ggml_context_ptr ctx0_ptr(ggml_init(params));
+ auto ctx0 = ctx0_ptr.get();
+
+ struct ggml_cgraph * gf = ggml_new_graph(ctx0);
+
+ // input raw
+ struct ggml_tensor * inp_raw = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, image_size_width, image_size_height, 3);
+ ggml_set_name(inp_raw, "inp_raw");
+ ggml_set_input(inp_raw);
+
+ // 2D input positions
+ struct ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_patches);
+ ggml_set_name(pos_h, "pos_h");
+ ggml_set_input(pos_h);
+ struct ggml_tensor * pos_w = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_patches);
+ ggml_set_name(pos_w, "pos_w");
+ ggml_set_input(pos_w);
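+
+    // the positions are the patch-grid coordinates of each patch, row-major;
+    // e.g. for a 2x3 grid (n_patches_y=2, n_patches_x=3):
+    //   pos_h = [0, 0, 0, 1, 1, 1]  (row index)
+    //   pos_w = [0, 1, 2, 0, 1, 2]  (column index)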
+
+ struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
+ inp = ggml_reshape_2d(ctx0, inp, num_patches, hidden_size);
+ inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp));
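+    // inp now has shape [hidden_size, num_patches], i.e. one embedding per patch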
+
+ struct ggml_tensor * embeddings = inp;
+
+ // pre-layer norm
+ embeddings = ggml_mul(ctx0, ggml_rms_norm(ctx0, embeddings, eps), model.pre_ln_w);
+
+ // loop over layers
+ for (int il = 0; il < n_layer; il++) {
+ struct ggml_tensor * cur = embeddings;
+
+ // pre-attention norm
+ cur = ggml_mul(ctx0, ggml_rms_norm(ctx0, cur, eps), model.layers[il].ln_1_w);
+
+ // self-attention
+ {
+ struct ggml_tensor * Q = ggml_mul_mat(ctx0, model.layers[il].q_w, cur);
+
+ Q = ggml_reshape_3d(ctx0, Q, d_head, n_head, num_patches);
+ Q = build_rope_2d(gf, ctx0, Q, pos_h, pos_w, hparams.rope_theta);
+ Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
+
+ struct ggml_tensor * K = ggml_mul_mat(ctx0, model.layers[il].k_w, cur);
+
+ K = ggml_reshape_3d(ctx0, K, d_head, n_head, num_patches);
+ K = build_rope_2d(gf, ctx0, K, pos_h, pos_w, hparams.rope_theta);
+ K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
+
+ struct ggml_tensor * V = ggml_mul_mat(ctx0, model.layers[il].v_w, cur);
+
+ V = ggml_reshape_3d(ctx0, V, d_head, n_head, num_patches);
+ V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
+
+ struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
+ KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f);
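+            // KQ: [num_patches, num_patches, n_head]; ggml_soft_max_ext applies the 1/sqrt(d_head) scale before the softmax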
+
+ struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
+ KQV = ggml_reshape_3d(ctx0, KQV, d_head, num_patches, n_head);
+ KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+
+ cur = ggml_cont_2d(ctx0, KQV, hidden_size, num_patches);
+
+ cur = ggml_mul_mat(ctx0, model.layers[il].o_w, cur);
+ }
+
+        // re-add the layer input, i.e. the residual connection
+ cur = ggml_add(ctx0, cur, embeddings);
+
+ embeddings = cur; // embeddings = residual, cur = hidden_states
+
+ // pre-ffn norm
+ cur = ggml_mul(ctx0, ggml_rms_norm(ctx0, cur, eps), model.layers[il].ln_2_w);
+
+ // feed-forward
+ {
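+            // SwiGLU-style MLP: down( silu(gate(x)) * up(x) )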
+ ggml_tensor * gate_proj = ggml_mul_mat(ctx0, model.layers[il].ff_gate_w, cur);
+ ggml_tensor * up_proj = ggml_mul_mat(ctx0, model.layers[il].ff_up_w, cur);
+ gate_proj = ggml_silu(ctx0, gate_proj); // pixtral uses silu
+ cur = ggml_mul(ctx0, up_proj, gate_proj);
+ cur = ggml_mul_mat(ctx0, model.layers[il].ff_down_w, cur);
+ }
+
+ // residual 2
+ cur = ggml_add(ctx0, embeddings, cur);
+
+ embeddings = cur;
+ }
+
+ // LlavaMultiModalProjector (with GELU activation)
+ {
+ embeddings = ggml_mul_mat(ctx0, model.mm_1_w, embeddings);
+ embeddings = ggml_add(ctx0, embeddings, model.mm_1_b);
+
+ embeddings = ggml_gelu(ctx0, embeddings);
+ embeddings = ggml_mul_mat(ctx0, model.mm_2_w, embeddings);
+ embeddings = ggml_add(ctx0, embeddings, model.mm_2_b);
+ }
+
+ // arrangement of the [IMG_BREAK] token
+ {
+ // not efficient, but works
+ // the trick is to view the embeddings as a 3D tensor with shape [hidden_size, n_patches_per_row, n_rows]
+ // and then concatenate the [IMG_BREAK] token to the end of each row, aka n_patches_per_row dimension
+ // after the concatenation, we have a tensor with shape [hidden_size, n_patches_per_row + 1, n_rows]
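+        //
+        // e.g. a 2x2 grid of patches [p00 p01 / p10 p11] becomes
+        //   [p00 p01 BREAK p10 p11 BREAK]
+        // and the final 2D view keeps the first num_patches + n_patches_y - 1 = 5
+        // tokens, dropping the trailing BREAK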
+
+ const int n_embd_text = embeddings->ne[0];
+ const int n_tokens_output = num_patches + n_patches_y - 1; // one [IMG_BREAK] per row, except the last row
+
+ ggml_tensor * cur = ggml_reshape_3d(ctx0, embeddings, n_embd_text, n_patches_x, n_patches_y);
+ ggml_tensor * tok = ggml_new_tensor_3d(ctx0, embeddings->type, n_embd_text, 1, n_patches_y);
+        tok = ggml_scale(ctx0, tok, 0.0); // multiply by zero to clear the uninitialized tensor
+ tok = ggml_add(ctx0, tok, model.token_embd_img_break);
+ cur = ggml_concat(ctx0, cur, tok, 1);
+ embeddings = ggml_view_2d(ctx0, cur,
+ n_embd_text, n_tokens_output,
+ ggml_row_size(cur->type, n_embd_text), 0);
+ }
+
+ // build the graph
+ ggml_build_forward_expand(gf, embeddings);
+
+ return gf;
+}
+
static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_image_f32_batch & imgs, struct clip_image_size load_image_size, bool is_inf = false) {
if (!ctx->has_vision_encoder) {
LOG_ERR("This gguf file seems to have no vision encoder\n");
{
res = clip_image_build_graph_siglip(ctx, imgs);
} break;
+ case PROJECTOR_TYPE_PIXTRAL:
+ {
+ res = clip_image_build_graph_pixtral(ctx, imgs);
+ } break;
default:
{
// TODO: we should have one build_* function per model
{
get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor, false);
} break;
+ case PROJECTOR_TYPE_PIXTRAL:
+ {
+ hparams.rope_theta = 10000.0f;
+ } break;
default:
break;
}
layer.o_w = get_tensor(string_format(TN_ATTN_OUTPUT, "v", il, "weight"));
layer.ln_1_w = get_tensor(string_format(TN_LN_1, "v", il, "weight"), false);
layer.ln_2_w = get_tensor(string_format(TN_LN_2, "v", il, "weight"), false);
- layer.ff_i_w = get_tensor(string_format(TN_FFN_DOWN, "v", il, "weight"));
- layer.ff_o_w = get_tensor(string_format(TN_FFN_UP, "v", il, "weight"));
layer.k_b = get_tensor(string_format(TN_ATTN_K, "v", il, "bias"), false);
layer.q_b = get_tensor(string_format(TN_ATTN_Q, "v", il, "bias"), false);
layer.v_b = get_tensor(string_format(TN_ATTN_V, "v", il, "bias"), false);
layer.o_b = get_tensor(string_format(TN_ATTN_OUTPUT, "v", il, "bias"), false);
layer.ln_1_b = get_tensor(string_format(TN_LN_1, "v", il, "bias"), false);
layer.ln_2_b = get_tensor(string_format(TN_LN_2, "v", il, "bias"), false);
- layer.ff_i_b = get_tensor(string_format(TN_FFN_DOWN, "v", il, "bias"), false);
- layer.ff_o_b = get_tensor(string_format(TN_FFN_UP, "v", il, "bias"), false);
+
+ // new naming
+ layer.ff_up_w = get_tensor(string_format(TN_FFN_UP, "v", il, "weight"));
+ layer.ff_up_b = get_tensor(string_format(TN_FFN_UP, "v", il, "bias"), false);
+ layer.ff_gate_w = get_tensor(string_format(TN_FFN_GATE, "v", il, "weight"), false);
+ layer.ff_gate_b = get_tensor(string_format(TN_FFN_GATE, "v", il, "bias"), false);
+ layer.ff_down_w = get_tensor(string_format(TN_FFN_DOWN, "v", il, "weight"));
+ layer.ff_down_b = get_tensor(string_format(TN_FFN_DOWN, "v", il, "bias"), false);
+
+        // legacy naming (note: "in" and "out" are reversed: ff_i maps to the down projection, ff_o to the up projection)
+ layer.ff_i_w = layer.ff_down_w;
+ layer.ff_o_w = layer.ff_up_w;
+ layer.ff_i_b = layer.ff_down_b;
+ layer.ff_o_b = layer.ff_up_b;
}
switch (ctx_clip.proj_type) {
{
vision_model.projection = get_tensor(TN_MM_PROJECTOR);
} break;
+ case PROJECTOR_TYPE_PIXTRAL:
+ {
+ vision_model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight"));
+ vision_model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias"));
+ vision_model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight"));
+ vision_model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias"));
+ // [IMG_BREAK] token embedding
+ vision_model.token_embd_img_break = get_tensor(TN_TOK_IMG_BREAK);
+ } break;
default:
GGML_ASSERT(false && "unknown projector type");
}
}
void alloc_compute_meta() {
- ctx_clip.buf_compute_meta.resize(GGML_DEFAULT_GRAPH_SIZE * ggml_tensor_overhead() + ggml_graph_overhead());
+ ctx_clip.buf_compute_meta.resize(ctx_clip.max_nodes * ggml_tensor_overhead() + ggml_graph_overhead());
        // create a fake batch at the maximum image size to build a worst-case graph
clip_image_f32_batch batch;
clip_image_f32_ptr img(clip_image_f32_init());
clip_image_size image_size;
- image_size.width = clip_get_image_size(&ctx_clip);
- image_size.height = clip_get_image_size(&ctx_clip);
- int n_patches = clip_get_image_size(&ctx_clip) / image_size.width;
- img->nx = n_patches;
- img->ny = n_patches;
- img->buf.resize(n_patches * image_size.width * image_size.height * 3);
+ image_size.width = ctx_clip.vision_model.hparams.image_size;
+ image_size.height = ctx_clip.vision_model.hparams.image_size;
+ img->nx = image_size.width;
+ img->ny = image_size.height;
+ img->buf.resize(image_size.width * image_size.height * 3);
batch.entries.push_back(std::move(img));
ggml_cgraph * gf = clip_image_build_graph(&ctx_clip, batch, image_size, false);
}
}
+    // calculate the size of the **resized** image while preserving the aspect ratio
+    // the calculated size is rounded up to the nearest multiple of align_size
+    // if either dimension exceeds max_dimension, the image is scaled down so that both fit
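+    //
+    // e.g. inp=800x600, align_size=16, max_dimension=1024:
+    //      scale = 1.0, so the size is only padded: 800 stays 800, 600 rounds up to 608
+    // e.g. inp=2048x1024: scale = 0.5 -> 1024x512 (already multiples of 16)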
+ static clip_image_size calc_size_preserved_ratio(const clip_image_size & inp_size, const int align_size, const int max_dimension) {
+ if (inp_size.width <= 0 || inp_size.height <= 0 || align_size <= 0 || max_dimension <= 0) {
+ return {0, 0};
+ }
+
+ float scale = std::min(1.0f, std::min(static_cast<float>(max_dimension) / inp_size.width,
+ static_cast<float>(max_dimension) / inp_size.height));
+
+ float target_width_f = static_cast<float>(inp_size.width) * scale;
+ float target_height_f = static_cast<float>(inp_size.height) * scale;
+
+ int aligned_width = GGML_PAD((int)target_width_f, align_size);
+ int aligned_height = GGML_PAD((int)target_height_f, align_size);
+
+ return {aligned_width, aligned_height};
+ }
+
private:
static inline int clip(int x, int lower, int upper) {
return std::max(lower, std::min(x, upper));
res_imgs->entries.push_back(std::move(img_f32));
return true;
}
-
- if (ctx->has_glm_projector
+ else if (ctx->has_glm_projector
|| ctx->proj_type == PROJECTOR_TYPE_GEMMA3
|| ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
clip_image_u8 resized_image;
res_imgs->entries.push_back(std::move(img_f32));
return true;
}
+ else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) {
+ clip_image_u8 resized_image;
+ auto new_size = image_manipulation::calc_size_preserved_ratio(original_size, params.patch_size, params.image_size);
+ image_manipulation::bilinear_resize(*img, resized_image, new_size.width, new_size.height);
+ clip_image_f32_ptr img_f32(clip_image_f32_init());
+ normalize_image_u8_to_f32(resized_image, *img_f32, ctx->image_mean, ctx->image_std);
+ res_imgs->entries.push_back(std::move(img_f32));
+ return true;
+ }
// the logic below is to pad the shorter side to the longer side with a background color: rgb(122, 116, 104)
// see https://github.com/haotian-liu/LLaVA/blob/e854a2bf85118c504f6f16bf5c3c7c92f8fa8c6b/llava/conversation.py#L113-L156
n_patches = 256;
} else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
n_patches /= ctx->vision_model.hparams.proj_scale_factor;
+ } else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) {
+ int n_patches_x = img->nx / params.patch_size;
+ int n_patches_y = img->ny / params.patch_size;
+ n_patches = n_patches_y*n_patches_x + n_patches_y - 1; // + one [IMG_BREAK] per row, except the last row
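+        // e.g. a 512x512 image with patch_size=16 -> 32x32 patches -> 1024 + 31 = 1055 output tokens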
}
return n_patches;
struct ggml_tensor * inp_raw = ggml_graph_get_tensor(gf, "inp_raw");
float * data = (float *)malloc(ggml_nbytes(inp_raw));
+ // TODO @ngxson : this whole code block is ugly, will need to be refactored
for (size_t i = 0; i < imgs.entries.size(); i++) {
const int nx = imgs.entries[i]->nx;
const int ny = imgs.entries[i]->ny;
- if (!(ctx->has_minicpmv_projector | ctx->has_qwen2vl_merger)) {
+
+ if (ctx->has_glm_projector
+ || ctx->has_llava_projector
+ || ctx->proj_type == PROJECTOR_TYPE_GEMMA3
+ || ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
GGML_ASSERT(nx == image_size && ny == image_size);
}
else if (ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
// do nothing
}
+ else if (ctx->proj_type == PROJECTOR_TYPE_PIXTRAL) {
+ // set the 2D positions
+            int n_patches_per_row = image_size_width / patch_size; // number of patches along the image width
+ std::vector<int> pos_data(num_positions);
+ struct ggml_tensor * pos;
+ // dimension H
+ pos = ggml_graph_get_tensor(gf, "pos_h");
+ for (int i = 0; i < num_positions; i++) {
+                pos_data[i] = i / n_patches_per_row; // row index
+ }
+ ggml_backend_tensor_set(pos, pos_data.data(), 0, ggml_nbytes(pos));
+ // dimension W
+ pos = ggml_graph_get_tensor(gf, "pos_w");
+ for (int i = 0; i < num_positions; i++) {
+                pos_data[i] = i % n_patches_per_row; // column index
+ }
+ ggml_backend_tensor_set(pos, pos_data.data(), 0, ggml_nbytes(pos));
+ }
else {
struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
case PROJECTOR_TYPE_LDPV2:
return ctx->vision_model.mm_model_peg_0_b->ne[0];
case PROJECTOR_TYPE_MLP:
+ case PROJECTOR_TYPE_PIXTRAL:
return ctx->vision_model.mm_2_b->ne[0];
case PROJECTOR_TYPE_MLP_NORM:
return ctx->vision_model.mm_3_b->ne[0];