};
struct clip_ctx {
- bool has_text_encoder = false;
- bool has_vision_encoder = false;
bool has_llava_projector = false;
- bool has_minicpmv_projector = false;
- bool has_glm_projector = false;
- bool has_qwen2vl_merger = false;
- int minicpmv_version = 2;
+ int minicpmv_version = 0;
struct clip_vision_model vision_model;
projector_type proj_type = PROJECTOR_TYPE_MLP;
}
};
-static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_image_f32_batch & imgs) {
+static ggml_cgraph * clip_image_build_graph_siglip(clip_ctx * ctx, const clip_image_f32 & img) {
const auto & model = ctx->vision_model;
const auto & hparams = model.hparams;
- const int image_size = hparams.image_size;
- int image_size_width = image_size;
- int image_size_height = image_size;
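+ // dimensions come from the preprocessed image itself rather than a fixed hparams.image_size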
+ int image_size_width = img.nx;
+ int image_size_height = img.ny;
- const int patch_size = hparams.patch_size;
- const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size));
- const int hidden_size = hparams.hidden_size;
- const int n_head = hparams.n_head;
- const int d_head = hidden_size / n_head;
- const int n_layer = hparams.n_layer;
- const float eps = hparams.eps;
-
- GGML_ASSERT(imgs.entries.size() == 1); // batch_size == 1
+ const int patch_size = hparams.patch_size;
+ const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size));
+ const int hidden_size = hparams.hidden_size;
+ const int n_head = hparams.n_head;
+ const int d_head = hidden_size / n_head;
+ const int n_layer = hparams.n_layer;
+ const float eps = hparams.eps;
struct ggml_init_params params = {
/*.mem_size =*/ ctx->buf_compute_meta.size(),
return cur;
}
-static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_image_f32_batch & imgs) {
+static ggml_cgraph * clip_image_build_graph_pixtral(clip_ctx * ctx, const clip_image_f32 & img) {
const auto & model = ctx->vision_model;
const auto & hparams = model.hparams;
GGML_ASSERT(ctx->proj_type == PROJECTOR_TYPE_PIXTRAL);
- GGML_ASSERT(imgs.entries.size() == 1); // batch_size == 1
- int image_size_width = imgs.entries[0]->nx;
- int image_size_height = imgs.entries[0]->ny;
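+ // the patch grid follows the actual input resolution (non-square images are supported)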
+ int image_size_width = img.nx;
+ int image_size_height = img.ny;
const int patch_size = hparams.patch_size;
const int n_patches_x = image_size_width / patch_size;
}
static ggml_cgraph * clip_image_build_graph_legacy(clip_ctx * ctx, const clip_image_f32_batch & imgs, struct clip_image_size load_image_size, bool is_inf = false) {
- if (!ctx->has_vision_encoder) {
- LOG_ERR("This gguf file seems to have no vision encoder\n");
- return nullptr;
- }
-
const auto & model = ctx->vision_model;
const auto & hparams = model.hparams;
const int image_size = hparams.image_size;
int image_size_width = image_size;
int image_size_height = image_size;
- if (ctx->has_minicpmv_projector) {
+
+ if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV) {
LOG_DBG("%s: %d %d\n", __func__, load_image_size.width, load_image_size.height);
image_size_width = load_image_size.width;
image_size_height = load_image_size.height;
image_size_height = imgs.entries[0]->ny;
}
}
- else if (ctx->has_qwen2vl_merger) {
+
+ else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) {
// use the image's native resolution when the image is available
if (is_inf) {
// if (imgs->data->nx && imgs->data->ny) {
image_size_height = imgs.entries[0]->ny;
}
}
+
const int patch_size = hparams.patch_size;
const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size));
const int patches_w = image_size_width / patch_size;
const int patches_h = image_size_height / patch_size;
const int num_positions = num_patches + (model.class_embedding ? 1 : 0);
- const int num_position_ids = ctx->has_qwen2vl_merger ? num_positions * 4 : num_positions;
+ const int num_position_ids = ctx->proj_type == PROJECTOR_TYPE_QWEN2VL ? num_positions * 4 : num_positions;
const int hidden_size = hparams.hidden_size;
const int n_head = hparams.n_head;
const int d_head = hidden_size / n_head;
const int batch_size = imgs.entries.size();
- if (ctx->has_llava_projector || ctx->has_minicpmv_projector || ctx->has_glm_projector) {
+ if (ctx->has_llava_projector
+ || ctx->proj_type == PROJECTOR_TYPE_MINICPMV
+ || ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) {
GGML_ASSERT(batch_size == 1);
}
struct ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
- if (ctx->has_qwen2vl_merger) {
- GGML_ASSERT(image_size_width % (patch_size * 2) == 0);
+ if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) {
+ GGML_ASSERT(image_size_width % (patch_size * 2) == 0);
GGML_ASSERT(image_size_height % (patch_size * 2) == 0);
auto inp_1 = ggml_conv_2d(ctx0, model.patch_embeddings_1, inp_raw, patch_size, patch_size, 0, 0, 1, 1);
struct ggml_tensor * embeddings = inp;
struct ggml_tensor * pos_embed = nullptr;
- if (ctx->has_llava_projector) {
- // concat class_embeddings and patch_embeddings
- if (model.class_embedding) {
- embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size);
- ggml_set_name(embeddings, "embeddings");
- ggml_set_input(embeddings);
- embeddings = ggml_acc(ctx0, embeddings, model.class_embedding,
- embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], 0);
- embeddings = ggml_acc(ctx0, embeddings, inp,
- embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], model.class_embedding->nb[1]);
- }
+ // concat class_embeddings and patch_embeddings
+ if (model.class_embedding) {
+ embeddings = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, hidden_size, num_positions, batch_size);
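+ // scaling the freshly created (uninitialized) tensor by 0 zero-initializes it in-graph,
+ // replacing the old host-side zero-buffer upload in clip_image_batch_encode()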
+ embeddings = ggml_scale(ctx0, embeddings, 0.0f); // set to all zeros
+ embeddings = ggml_acc(ctx0, embeddings, model.class_embedding,
+ embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], 0);
+ embeddings = ggml_acc(ctx0, embeddings, inp,
+ embeddings->nb[1], embeddings->nb[2], embeddings->nb[3], model.class_embedding->nb[1]);
}
struct ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_position_ids);
ggml_set_name(positions, "positions");
ggml_set_input(positions);
- if (!ctx->has_qwen2vl_merger) { // qwen2vl use rope position embedding
+ if (ctx->proj_type != PROJECTOR_TYPE_QWEN2VL) { // qwen2vl does NOT use learned position embeddings
embeddings =
ggml_add(ctx0, embeddings, ggml_get_rows(ctx0, model.position_embeddings, positions));
}
- if (ctx->has_minicpmv_projector) {
+ if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV) {
int pos_w = image_size_width/patch_size;
int pos_h = image_size_height/patch_size;
if (ctx->minicpmv_version == 2) {
ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].q_w, cur), model.layers[il].q_b);
Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_positions, batch_size);
- if (ctx->has_qwen2vl_merger) {
+ if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) {
Q = ggml_rope_multi(
ctx0, Q, positions, nullptr,
d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].k_w, cur), model.layers[il].k_b);
K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size);
- if (ctx->has_qwen2vl_merger) {
+ if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) {
K = ggml_rope_multi(
ctx0, K, positions, nullptr,
d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1);
}
}
// minicpmv projector
- else if (ctx->has_minicpmv_projector)
- {
- if (ctx->proj_type == PROJECTOR_TYPE_RESAMPLER) {
- struct ggml_tensor * q = model.mm_model_query;
- { // layernorm
- q = ggml_norm(ctx0, q, eps);
- q = ggml_add(ctx0, ggml_mul(ctx0, q, model.mm_model_ln_q_w), model.mm_model_ln_q_b);
+ else if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV) {
+ struct ggml_tensor * q = model.mm_model_query;
+ { // layernorm
+ q = ggml_norm(ctx0, q, eps);
+ q = ggml_add(ctx0, ggml_mul(ctx0, q, model.mm_model_ln_q_w), model.mm_model_ln_q_b);
+ }
+ struct ggml_tensor * v = ggml_mul_mat(ctx0, model.mm_model_kv_proj, embeddings);
+ { // layernorm
+ v = ggml_norm(ctx0, v, eps);
+ v = ggml_add(ctx0, ggml_mul(ctx0, v, model.mm_model_ln_kv_w), model.mm_model_ln_kv_b);
+ }
+ struct ggml_tensor * k;
+ { // position
+ // q = ggml_add(ctx0, q, model.mm_model_pos_embed);
+ k = ggml_add(ctx0, v, pos_embed);
+ }
+
+ { // attention
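+ // d_head is fixed at 128; hidden_size and num_query depend on the minicpmv version below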
+ int hidden_size = 4096;
+ const int d_head = 128;
+ int n_head = hidden_size/d_head;
+ int num_query = 96;
+ if (ctx->minicpmv_version == 2) {
+ hidden_size = 4096;
+ n_head = hidden_size/d_head;
+ num_query = 96;
}
- struct ggml_tensor * v = ggml_mul_mat(ctx0, model.mm_model_kv_proj, embeddings);
- { // layernorm
- v = ggml_norm(ctx0, v, eps);
- v = ggml_add(ctx0, ggml_mul(ctx0, v, model.mm_model_ln_kv_w), model.mm_model_ln_kv_b);
+ else if (ctx->minicpmv_version == 3) {
+ hidden_size = 3584;
+ n_head = hidden_size/d_head;
+ num_query = 64;
}
- struct ggml_tensor * k;
- { // position
- // q = ggml_add(ctx0, q, model.mm_model_pos_embed);
- k = ggml_add(ctx0, v, pos_embed);
+ else if (ctx->minicpmv_version == 4) {
+ hidden_size = 3584;
+ n_head = hidden_size/d_head;
+ num_query = 64;
}
- { // attention
- int hidden_size = 4096;
- const int d_head = 128;
- int n_head = hidden_size/d_head;
- int num_query = 96;
- if (ctx->minicpmv_version == 2) {
- hidden_size = 4096;
- n_head = hidden_size/d_head;
- num_query = 96;
- }
- else if (ctx->minicpmv_version == 3) {
- hidden_size = 3584;
- n_head = hidden_size/d_head;
- num_query = 64;
- }
- else if (ctx->minicpmv_version == 4) {
- hidden_size = 3584;
- n_head = hidden_size/d_head;
- num_query = 64;
- }
+ struct ggml_tensor * Q = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_q_w, q), model.mm_model_attn_q_b);
+ struct ggml_tensor * K = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_k_w, k), model.mm_model_attn_k_b);
+ struct ggml_tensor * V = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_v_w, v), model.mm_model_attn_v_b);
+ // permute
+ Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_query, batch_size);
+ Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
+ Q = ggml_reshape_3d(ctx0, Q, d_head, num_query, n_head * batch_size);
+ K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size);
+ K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
+ K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size);
+ V = ggml_reshape_4d(ctx0, V, d_head, n_head, num_positions, batch_size);
+ V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
+ V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size);
+ struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
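+ // fused scaled softmax: 1/sqrt(d_head) scaling, no mask, no ALiBi bias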
+ KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f);
+ struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
+ KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_query, n_head, batch_size);
+ KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
+ KQV = ggml_cont_3d(ctx0, KQV, hidden_size, num_query, batch_size);
- struct ggml_tensor * Q = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_q_w, q), model.mm_model_attn_q_b);
- struct ggml_tensor * K = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_k_w, k), model.mm_model_attn_k_b);
- struct ggml_tensor * V = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_v_w, v), model.mm_model_attn_v_b);
- // permute
- Q = ggml_reshape_4d(ctx0, Q, d_head, n_head, num_query, batch_size);
- Q = ggml_cont(ctx0, ggml_permute(ctx0, Q, 0, 2, 1, 3));
- Q = ggml_reshape_3d(ctx0, Q, d_head, num_query, n_head * batch_size);
- K = ggml_reshape_4d(ctx0, K, d_head, n_head, num_positions, batch_size);
- K = ggml_cont(ctx0, ggml_permute(ctx0, K, 0, 2, 1, 3));
- K = ggml_reshape_3d(ctx0, K, d_head, num_positions, n_head * batch_size);
- V = ggml_reshape_4d(ctx0, V, d_head, n_head, num_positions, batch_size);
- V = ggml_cont(ctx0, ggml_permute(ctx0, V, 1, 2, 0, 3));
- V = ggml_reshape_3d(ctx0, V, num_positions, d_head, n_head * batch_size);
- struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
- KQ = ggml_soft_max_ext(ctx0, KQ, nullptr, 1.0f / sqrtf((float)d_head), 0.0f);
- struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ);
- KQV = ggml_reshape_4d(ctx0, KQV, d_head, num_query, n_head, batch_size);
- KQV = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
- KQV = ggml_cont_3d(ctx0, KQV, hidden_size, num_query, batch_size);
-
- embeddings = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_o_w, KQV), model.mm_model_attn_o_b);
- }
- { // layernorm
- embeddings = ggml_norm(ctx0, embeddings, eps);
- embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_model_ln_post_w), model.mm_model_ln_post_b);
- }
- embeddings = ggml_mul_mat(ctx0, model.mm_model_proj, embeddings);
+ embeddings = ggml_add(ctx0, ggml_mul_mat(ctx0, model.mm_model_attn_o_w, KQV), model.mm_model_attn_o_b);
}
- else {
- GGML_ASSERT(false);
+ { // layernorm
+ embeddings = ggml_norm(ctx0, embeddings, eps);
+ embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_model_ln_post_w), model.mm_model_ln_post_b);
}
+ embeddings = ggml_mul_mat(ctx0, model.mm_model_proj, embeddings);
}
+
// glm projector
- else if (ctx->has_glm_projector) {
- if (ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) {
- size_t gridsz = (size_t)sqrt(embeddings->ne[1]);
- embeddings = ggml_cont(ctx0, ggml_permute(ctx0,embeddings,1,0,2,3));
- embeddings = ggml_reshape_3d(ctx0, embeddings, gridsz, gridsz, embeddings->ne[1]);
- embeddings = ggml_conv_2d(ctx0, model.mm_model_adapter_conv_w, embeddings, 2, 2, 0, 0, 1, 1);
- embeddings = ggml_reshape_3d(ctx0, embeddings,embeddings->ne[0]*embeddings->ne[1] , embeddings->ne[2], batch_size);
- embeddings = ggml_cont(ctx0, ggml_permute(ctx0,embeddings, 1, 0, 2, 3));
- embeddings = ggml_add(ctx0, embeddings, model.mm_model_adapter_conv_b);
- //GLU
- {
- embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_0_w, embeddings);
- embeddings = ggml_norm(ctx0, embeddings, eps);
- embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_model_ln_q_w), model.mm_model_ln_q_b);
- embeddings = ggml_gelu_inplace(ctx0, embeddings);
- struct ggml_tensor * x = embeddings;
- embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, embeddings);
- x = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w,x);
- embeddings = ggml_silu_inplace(ctx0, embeddings);
- embeddings = ggml_mul(ctx0, embeddings,x);
- embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_3_w, embeddings);
- }
- } else {
- GGML_ABORT("fatal error");
+ else if (ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) {
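+ // reshape the token sequence back into a 2D grid, downsample it with a stride-2 conv, then apply a gated MLP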
+ size_t gridsz = (size_t)sqrt(embeddings->ne[1]);
+ embeddings = ggml_cont(ctx0, ggml_permute(ctx0, embeddings, 1, 0, 2, 3));
+ embeddings = ggml_reshape_3d(ctx0, embeddings, gridsz, gridsz, embeddings->ne[1]);
+ embeddings = ggml_conv_2d(ctx0, model.mm_model_adapter_conv_w, embeddings, 2, 2, 0, 0, 1, 1);
+ embeddings = ggml_reshape_3d(ctx0, embeddings, embeddings->ne[0] * embeddings->ne[1], embeddings->ne[2], batch_size);
+ embeddings = ggml_cont(ctx0, ggml_permute(ctx0, embeddings, 1, 0, 2, 3));
+ embeddings = ggml_add(ctx0, embeddings, model.mm_model_adapter_conv_b);
+ // GLU
+ {
+ embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_0_w, embeddings);
+ embeddings = ggml_norm(ctx0, embeddings, eps);
+ embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.mm_model_ln_q_w), model.mm_model_ln_q_b);
+ embeddings = ggml_gelu_inplace(ctx0, embeddings);
+ struct ggml_tensor * x = embeddings;
+ embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_2_w, embeddings);
+ x = ggml_mul_mat(ctx0, model.mm_model_mlp_1_w, x);
+ embeddings = ggml_silu_inplace(ctx0, embeddings);
+ embeddings = ggml_mul(ctx0, embeddings, x);
+ embeddings = ggml_mul_mat(ctx0, model.mm_model_mlp_3_w, embeddings);
}
}
- else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
+
+ else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) {
embeddings = ggml_reshape_3d(ctx0, embeddings, hidden_size * 4, num_positions / 4, batch_size);
embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
case PROJECTOR_TYPE_GEMMA3:
case PROJECTOR_TYPE_IDEFICS3:
{
- res = clip_image_build_graph_siglip(ctx, imgs);
+ GGML_ASSERT(imgs.entries.size() == 1);
+ res = clip_image_build_graph_siglip(ctx, *imgs.entries[0]);
} break;
case PROJECTOR_TYPE_PIXTRAL:
{
- res = clip_image_build_graph_pixtral(ctx, imgs);
+ GGML_ASSERT(imgs.entries.size() == 1);
+ res = clip_image_build_graph_pixtral(ctx, *imgs.entries[0]);
} break;
default:
{
auto & hparams = ctx_clip.vision_model.hparams;
// projector type
+ std::string proj_type;
{
- std::string proj_type;
get_string(KEY_PROJ_TYPE, proj_type, false);
if (!proj_type.empty()) {
ctx_clip.proj_type = clip_projector_type_from_string(proj_type);
// other hparams
{
- get_bool(KEY_HAS_TEXT_ENC, ctx_clip.has_text_encoder, false);
- get_bool(KEY_HAS_VIS_ENC, ctx_clip.has_vision_encoder, false);
- GGML_ASSERT(ctx_clip.has_vision_encoder);
- GGML_ASSERT(!ctx_clip.has_text_encoder);
-
- // legacy keys, use KEY_PROJ_TYPE instead
- get_bool(KEY_HAS_LLAVA_PROJ, ctx_clip.has_llava_projector, false);
- get_bool(KEY_HAS_MINICPMV_PROJ, ctx_clip.has_minicpmv_projector, false);
get_i32(KEY_MINICPMV_VERSION, ctx_clip.minicpmv_version, false);
- get_bool(KEY_HAS_GLM_PROJ, ctx_clip.has_glm_projector, false);
- get_bool(KEY_HAS_QWEN2VL_MERGER, ctx_clip.has_qwen2vl_merger, false);
- // !!! do NOT extend the list above, use KEY_PROJ_TYPE instead
get_bool(KEY_USE_GELU, ctx_clip.use_gelu, false);
get_bool(KEY_USE_SILU, ctx_clip.use_silu, false);
- get_u32(string_format(KEY_N_EMBD, "vision"), hparams.hidden_size);
- get_u32(string_format(KEY_N_HEAD, "vision"), hparams.n_head);
- get_u32(string_format(KEY_N_FF, "vision"), hparams.n_intermediate);
- get_u32(string_format(KEY_N_BLOCK, "vision"), hparams.n_layer);
- get_u32(string_format(KEY_PROJ_DIM, "vision"), hparams.projection_dim);
- get_f32(string_format(KEY_LAYER_NORM_EPS, "vision"), hparams.eps);
- get_u32(KEY_IMAGE_SIZE, hparams.image_size);
- get_u32(KEY_PATCH_SIZE, hparams.patch_size);
- get_u32(KEY_IMAGE_CROP_RESOLUTION, hparams.image_crop_resolution, false);
+ get_u32(KEY_N_EMBD, hparams.hidden_size);
+ get_u32(KEY_N_HEAD, hparams.n_head);
+ get_u32(KEY_N_FF, hparams.n_intermediate);
+ get_u32(KEY_N_BLOCK, hparams.n_layer);
+ get_u32(KEY_PROJ_DIM, hparams.projection_dim);
+ get_f32(KEY_LAYER_NORM_EPS, hparams.eps);
+ get_u32(KEY_IMAGE_SIZE, hparams.image_size);
+ get_u32(KEY_PATCH_SIZE, hparams.patch_size);
+ get_u32(KEY_IMAGE_CROP_RESOLUTION, hparams.image_crop_resolution, false);
get_arr_int(KEY_IMAGE_GRID_PINPOINTS, hparams.image_grid_pinpoints, false);
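+ // derive the legacy llava flag from the projector type (replaces the removed KEY_HAS_LLAVA_PROJ)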
+ ctx_clip.has_llava_projector = ctx_clip.proj_type == PROJECTOR_TYPE_MLP
+ || ctx_clip.proj_type == PROJECTOR_TYPE_MLP_NORM
+ || ctx_clip.proj_type == PROJECTOR_TYPE_LDP
+ || ctx_clip.proj_type == PROJECTOR_TYPE_LDPV2;
+
{
std::string mm_patch_merge_type;
get_string(KEY_MM_PATCH_MERGE_TYPE, mm_patch_merge_type, false);
for (auto & layer : vision_feature_layer) {
hparams.vision_feature_layer.insert(layer);
}
+
// Calculate the deepest feature layer based on hparams and projector type
- ctx_clip.max_feature_layer = get_deepest_feature_layer(&ctx_clip);
+ // NOTE: This is only used by build_graph_legacy()
+ {
+ // Get the index of the second to last layer; this is the default for models that have a llava projector
+ int n_layer = hparams.n_layer - 1;
+ int deepest_feature_layer = -1;
+
+ if (ctx_clip.proj_type == PROJECTOR_TYPE_MINICPMV
+ || ctx_clip.proj_type == PROJECTOR_TYPE_GLM_EDGE
+ || ctx_clip.proj_type == PROJECTOR_TYPE_QWEN2VL) {
+ n_layer += 1;
+ }
+
+ // If we set explicit vision feature layers, only go up to the deepest one
+ // NOTE: only used by granite-vision models for now
+ for (const auto & feature_layer : hparams.vision_feature_layer) {
+ if (feature_layer > deepest_feature_layer) {
+ deepest_feature_layer = feature_layer;
+ }
+ }
+ ctx_clip.max_feature_layer = deepest_feature_layer < 0 ? n_layer : deepest_feature_layer;
+ }
+
+ // model-specific params
+ switch (ctx_clip.proj_type) {
+ case PROJECTOR_TYPE_MINICPMV:
+ {
+ if (ctx_clip.minicpmv_version == 0) {
+ ctx_clip.minicpmv_version = 2; // default to 2 if not set
+ }
+ } break;
+ case PROJECTOR_TYPE_IDEFICS3:
+ {
+ get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor, false);
+ } break;
+ case PROJECTOR_TYPE_PIXTRAL:
+ {
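+ // base frequency for pixtral's 2D RoPE over patch positions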
+ hparams.rope_theta = 10000.0f;
+ } break;
+ default:
+ break;
+ }
- LOG_INF("%s: text_encoder: %d\n", __func__, ctx_clip.has_text_encoder);
- LOG_INF("%s: vision_encoder: %d\n", __func__, ctx_clip.has_vision_encoder);
- LOG_INF("%s: llava_projector: %d\n", __func__, ctx_clip.has_llava_projector);
- LOG_INF("%s: minicpmv_projector: %d\n", __func__, ctx_clip.has_minicpmv_projector);
+ LOG_INF("%s: projector: %s\n", __func__, proj_type.c_str());
+ LOG_INF("%s: has_llava_proj: %d\n", __func__, ctx_clip.has_llava_projector);
LOG_INF("%s: minicpmv_version: %d\n", __func__, ctx_clip.minicpmv_version);
- LOG_INF("%s: glm_projector: %d\n", __func__, ctx_clip.has_glm_projector);
LOG_INF("%s: model size: %.2f MiB\n", __func__, model_size / 1024.0 / 1024.0);
LOG_INF("%s: metadata size: %.2f MiB\n", __func__, ggml_get_mem_size(ctx_meta.get()) / 1024.0 / 1024.0);
}
-
- // model-specific params
- switch (ctx_clip.proj_type) {
- case PROJECTOR_TYPE_IDEFICS3:
- {
- get_u32(KEY_PROJ_SCALE_FACTOR, hparams.proj_scale_factor, false);
- } break;
- case PROJECTOR_TYPE_PIXTRAL:
- {
- hparams.rope_theta = 10000.0f;
- } break;
- default:
- break;
- }
}
void load_tensors() {
vision_model.patch_bias = get_tensor(TN_PATCH_BIAS, false);
vision_model.patch_embeddings_0 = get_tensor(TN_PATCH_EMBD, false);
vision_model.patch_embeddings_1 = get_tensor(TN_PATCH_EMBD_1, false);
- if (vision_model.patch_embeddings_1 == nullptr) {
- ctx_clip.has_qwen2vl_merger = false;
- }
vision_model.position_embeddings = get_tensor(string_format(TN_POS_EMBD, "v"), false);
vision_model.mm_model_peg_0_w = get_tensor(string_format(TN_MVLM_PROJ_PEG, 0, "weight"));
vision_model.mm_model_peg_0_b = get_tensor(string_format(TN_MVLM_PROJ_PEG, 0, "bias"));
} break;
- case PROJECTOR_TYPE_RESAMPLER:
+ case PROJECTOR_TYPE_MINICPMV:
{
// vision_model.mm_model_pos_embed = get_tensor(new_clip->ctx_data, TN_MINICPMV_POS_EMBD);
vision_model.mm_model_pos_embed_k = get_tensor(TN_MINICPMV_POS_EMBD_K);
vision_model.mm_model_mlp_2_w = get_tensor(string_format(TN_GLM_ADAPTER_GATE,"weight"));
vision_model.mm_model_mlp_3_w = get_tensor(string_format(TN_GLM_ADAPTER_D_4H_2_H,"weight"));
} break;
- case PROJECTOR_TYPE_MERGER:
+ case PROJECTOR_TYPE_QWEN2VL:
{
vision_model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight"));
vision_model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias"));
// returns the normalized float tensor for llava-1.5; for spatial_unpad with anyres processing (llava-1.6), it returns the normalized image patch tensors as a vector
// res_imgs memory is being allocated here, previous allocations will be freed if found
bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, struct clip_image_f32_batch * res_imgs) {
- if (!ctx->has_vision_encoder) {
- LOG_ERR("%s: This gguf file seems to have no vision encoder\n", __func__);
- return false;
- }
-
clip_image_size original_size{img->nx, img->ny};
bool pad_to_square = true;
auto & params = ctx->vision_model.hparams;
}
return true;
}
- else if (ctx->has_qwen2vl_merger) {
+ else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) {
clip_image_u8 resized;
auto patch_size = clip_get_patch_size(ctx) * 2;
int nx = ceil((float)img->nx / patch_size) * patch_size;
res_imgs->entries.push_back(std::move(img_f32));
return true;
}
- else if (ctx->has_glm_projector
+ else if (ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE
|| ctx->proj_type == PROJECTOR_TYPE_GEMMA3
|| ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
clip_image_u8 resized_image;
if (ctx->proj_type == PROJECTOR_TYPE_LDP || ctx->proj_type == PROJECTOR_TYPE_LDPV2 || ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) {
n_patches /= 4;
- } else if (ctx->proj_type == PROJECTOR_TYPE_RESAMPLER) {
+ } else if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV) {
if (ctx->minicpmv_version == 2) {
n_patches = 96;
}
else if (ctx->minicpmv_version == 4) {
n_patches = 64;
}
- } else if (ctx->proj_type == PROJECTOR_TYPE_MERGER) {
+ else {
+ GGML_ABORT("Unknown minicpmv version");
+ }
+ } else if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) {
int patch_size = params.patch_size * 2;
int x_patch = img->nx / patch_size + (int)(img->nx % patch_size > 0);
int y_patch = img->ny / patch_size + (int)(img->ny % patch_size > 0);
}
bool clip_image_encode(struct clip_ctx * ctx, const int n_threads, clip_image_f32 * img, float * vec) {
- if (!ctx->has_vision_encoder) {
- LOG_ERR("%s: This gguf file seems to have no vision encoder\n", __func__);
- return false;
- }
-
clip_image_f32_batch imgs;
clip_image_f32_ptr img_copy(clip_image_f32_init());
*img_copy = *img;
bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_image_f32_batch * imgs_c_ptr, float * vec) {
const clip_image_f32_batch & imgs = *imgs_c_ptr;
-
- if (!ctx->has_vision_encoder) {
- LOG_ERR("%s: This gguf file seems to have no vision encoder\n", __func__);
- return false;
- }
-
int batch_size = imgs.entries.size();
- if (ctx->has_llava_projector) {
- GGML_ASSERT(batch_size == 1); // TODO: support multiple images
- }
- if (ctx->has_minicpmv_projector) {
- GGML_ASSERT(batch_size == 1);
- }
- if (ctx->has_glm_projector) {
+
+ if (ctx->has_llava_projector
+ || ctx->proj_type == PROJECTOR_TYPE_MINICPMV
+ || ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE) {
GGML_ASSERT(batch_size == 1);
}
ggml_backend_sched_alloc_graph(ctx->sched.get(), gf);
// set inputs
- const auto & model = ctx->vision_model;
+ const auto & model = ctx->vision_model;
const auto & hparams = model.hparams;
- // TODO @ngxson : this is ugly, need to refactor later
- bool support_dynamic_size = ctx->has_minicpmv_projector
- || ctx->has_qwen2vl_merger
- || ctx->proj_type == PROJECTOR_TYPE_PIXTRAL;
+ const int image_size_width = imgs.entries[0]->nx;
+ const int image_size_height = imgs.entries[0]->ny;
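+ // the resolution is read from the (already preprocessed) input image; fixed-size models
+ // were resized to hparams.image_size during preprocessing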
- const int image_size = hparams.image_size;
- int image_size_width = image_size;
- int image_size_height = image_size;
- if (support_dynamic_size) {
- image_size_width = imgs.entries[0]->nx;
- image_size_height = imgs.entries[0]->ny;
- }
const int patch_size = hparams.patch_size;
const int num_patches = ((image_size_width / patch_size) * (image_size_height / patch_size));
const int num_positions = num_patches + (model.class_embedding ? 1 : 0);
for (size_t i = 0; i < imgs.entries.size(); i++) {
const int nx = imgs.entries[i]->nx;
const int ny = imgs.entries[i]->ny;
-
- if (ctx->has_glm_projector
- || ctx->has_llava_projector
- || ctx->proj_type == PROJECTOR_TYPE_GEMMA3
- || ctx->proj_type == PROJECTOR_TYPE_IDEFICS3) {
- GGML_ASSERT(nx == image_size && ny == image_size);
- }
-
const int n = nx * ny;
for (int b = 0; b < batch_size; b++) {
}
ggml_backend_tensor_set(inp_raw, data, 0, ggml_nbytes(inp_raw));
}
- if (ctx->has_minicpmv_projector) {
+
+ if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV) {
{
// inspired from siglip:
// -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit
// -> https://huggingface.co/HuggingFaceM4/siglip-so400m-14-980-flash-attn2-navit/blob/d66538faeba44480d0bfaa42145eef26f9423199/modeling_siglip.py#L316
struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
- int* positions_data = (int*)malloc(ggml_nbytes(positions));
+ std::vector<int> pos_data(ggml_nelements(positions));
+ int * data = pos_data.data();
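+ // sized in elements rather than bytes; the vector replaces the old malloc/free pair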
int bucket_coords_h[1024];
int bucket_coords_w[1024];
for (int i = 0; i < pos_h; i++){
}
for (int i = 0, id = 0; i < pos_h; i++){
for (int j = 0; j < pos_w; j++){
- positions_data[id++] = bucket_coords_h[i]*70 + bucket_coords_w[j];
+ data[id++] = bucket_coords_h[i]*70 + bucket_coords_w[j];
}
}
- ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
- free(positions_data);
+ ggml_backend_tensor_set(positions, data, 0, ggml_nbytes(positions));
}
{
else if (ctx->minicpmv_version == 4) {
embed_dim = 3584;
}
+ else {
+ GGML_ABORT("Unknown minicpmv version");
+ }
+
+ // TODO @ngxson : this is very inefficient, can we do this using ggml_sin and ggml_cos?
auto pos_embed_t = get_2d_sincos_pos_embed(embed_dim, std::make_pair(pos_w, pos_h));
- float * pos_embed_data = (float *)malloc(ggml_nbytes(pos_embed));
- for(int i=0;i < pos_w * pos_h; ++i){
- for(int j=0; j < embed_dim; ++j){
- pos_embed_data[i * embed_dim + j] = pos_embed_t[i][j];
+ std::vector<float> pos_data(ggml_nelements(pos_embed));
+ float * data = pos_data.data();
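+ // flatten the [pos][embed_dim] sincos table into the contiguous buffer ggml expects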
+ for(int i = 0; i < pos_w * pos_h; ++i){
+ for(int j = 0; j < embed_dim; ++j){
+ data[i * embed_dim + j] = pos_embed_t[i][j];
}
}
- ggml_backend_tensor_set(pos_embed, pos_embed_data, 0, ggml_nbytes(pos_embed));
- free(pos_embed_data);
+ ggml_backend_tensor_set(pos_embed, data, 0, ggml_nbytes(pos_embed));
}
}
else {
- if (model.class_embedding) {
- struct ggml_tensor * embeddings = ggml_graph_get_tensor(gf, "embeddings");
- void* zero_mem = malloc(ggml_nbytes(embeddings));
- memset(zero_mem, 0, ggml_nbytes(embeddings));
- ggml_backend_tensor_set(embeddings, zero_mem, 0, ggml_nbytes(embeddings));
- free(zero_mem);
- }
-
+ // non-minicpmv models
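+ // (class-embedding zero-init now happens in-graph via ggml_scale(embeddings, 0.0f))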
- if (ctx->has_qwen2vl_merger) {
+ if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL) {
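+ // M-RoPE consumes 4 position ids per patch, matching num_position_ids = num_positions * 4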
struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
const int pw = image_size_width / patch_size;
ggml_backend_tensor_set(pos, pos_data.data(), 0, ggml_nbytes(pos));
}
else {
+ // llava and other models
struct ggml_tensor * positions = ggml_graph_get_tensor(gf, "positions");
int* positions_data = (int*)malloc(ggml_nbytes(positions));
ggml_backend_tensor_set(positions, positions_data, 0, ggml_nbytes(positions));
free(positions_data);
- if (!ctx->has_glm_projector) {
+ if (ctx->proj_type != PROJECTOR_TYPE_GLM_EDGE) {
struct ggml_tensor * patches = ggml_graph_get_tensor(gf, "patches");
// The patches vector is used to get rows to index into the embeds with;
// we should skip dim 0 only if we have CLS to avoid going out of bounds
return ctx->vision_model.mm_2_b->ne[0];
case PROJECTOR_TYPE_MLP_NORM:
return ctx->vision_model.mm_3_b->ne[0];
- case PROJECTOR_TYPE_RESAMPLER:
+ case PROJECTOR_TYPE_MINICPMV:
if (ctx->minicpmv_version == 2) {
return 4096;
} else if (ctx->minicpmv_version == 3) {
} else if (ctx->minicpmv_version == 4) {
return 3584;
}
- break; // Should not happen if version is valid
+ GGML_ABORT("Unknown minicpmv version");
case PROJECTOR_TYPE_GLM_EDGE:
return ctx->vision_model.mm_model_mlp_3_w->ne[1];
- case PROJECTOR_TYPE_MERGER:
+ case PROJECTOR_TYPE_QWEN2VL:
return ctx->vision_model.mm_1_b->ne[0];
case PROJECTOR_TYPE_GEMMA3:
return ctx->vision_model.mm_input_proj_w->ne[0];
case PROJECTOR_TYPE_IDEFICS3:
return ctx->vision_model.projection->ne[1];
default:
- break; // Fall through to throw
+ GGML_ABORT("Unknown projector type");
}
-
- std::string proj_type = PROJECTOR_TYPE_NAMES[ctx->proj_type];
- throw std::runtime_error(string_format("%s: don't support projector with: %s currently\n", __func__, proj_type.c_str()));
}
int clip_is_minicpmv(const struct clip_ctx * ctx) {
- if (ctx->has_minicpmv_projector) {
+ if (ctx->proj_type == PROJECTOR_TYPE_MINICPMV) {
return ctx->minicpmv_version;
}
return 0;
}
bool clip_is_glm(const struct clip_ctx * ctx) {
- return ctx->has_glm_projector;
+ return ctx->proj_type == PROJECTOR_TYPE_GLM_EDGE;
}
bool clip_is_qwen2vl(const struct clip_ctx * ctx) {
- return ctx->has_qwen2vl_merger;
+ return ctx->proj_type == PROJECTOR_TYPE_QWEN2VL;
}
bool clip_is_llava(const struct clip_ctx * ctx) {
return ctx->has_llava_projector;
}
bool clip_is_gemma3(const struct clip_ctx * ctx) {
return ctx->proj_type == PROJECTOR_TYPE_GEMMA3;
}
-// Determine the number of encoder layers to iterate over
-int get_deepest_feature_layer(const struct clip_ctx * ctx) {
- // Get the index of the second to last layer; this is the
- // default for models that have a llava projector
- const auto & hparams = ctx->vision_model.hparams;
- int n_layer = hparams.n_layer - 1;
- int deepest_feature_layer = -1;
-
- // Handle other projectors; incrementing here indicates that we
- // should use the last encoder layer for the vision features.
- if (ctx->has_minicpmv_projector || ctx->has_glm_projector || ctx->has_qwen2vl_merger) {
- n_layer += 1;
- }
-
- // If we set explicit vision feature layers, only go up to the deepest one
- for (const auto & feature_layer : hparams.vision_feature_layer) {
- if (feature_layer > deepest_feature_layer) {
- deepest_feature_layer = feature_layer;
- }
- }
- return deepest_feature_layer < 0 ? n_layer : deepest_feature_layer;
-}
-
bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec) {
clip_image_f32 clip_img;
clip_img.buf.resize(h * w * 3);